mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
Add comprehensive ANN benchmarking suite (#279)
Extend benchmarks-ann/ with results database (SQLite with per-query detail and continuous writes), dataset subfolder organization, --subset-size and --warmup options. Supports systematic comparison across flat, rescore, IVF, and DiskANN index types.
This commit is contained in:
parent
a248ecd061
commit
8544081a67
26 changed files with 2127 additions and 292 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -31,3 +31,6 @@ poetry.lock
|
||||||
|
|
||||||
memstat.c
|
memstat.c
|
||||||
memstat.*
|
memstat.*
|
||||||
|
|
||||||
|
|
||||||
|
.DS_Store
|
||||||
6
benchmarks-ann/.gitignore
vendored
6
benchmarks-ann/.gitignore
vendored
|
|
@ -1,2 +1,8 @@
|
||||||
*.db
|
*.db
|
||||||
|
*.db-shm
|
||||||
|
*.db-wal
|
||||||
|
*.parquet
|
||||||
runs/
|
runs/
|
||||||
|
|
||||||
|
viewer/
|
||||||
|
searcher/
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
BENCH = python bench.py
|
BENCH = python bench.py
|
||||||
BASE_DB = seed/base.db
|
BASE_DB = cohere1m/base.db
|
||||||
EXT = ../dist/vec0
|
EXT = ../dist/vec0
|
||||||
|
|
||||||
# --- Baseline (brute-force) configs ---
|
# --- Baseline (brute-force) configs ---
|
||||||
|
|
@ -33,7 +33,7 @@ ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) $(DISKANN_CONFIGS)
|
||||||
|
|
||||||
# --- Data preparation ---
|
# --- Data preparation ---
|
||||||
seed:
|
seed:
|
||||||
$(MAKE) -C seed
|
$(MAKE) -C cohere1m
|
||||||
|
|
||||||
ground-truth: seed
|
ground-truth: seed
|
||||||
python ground_truth.py --subset-size 10000
|
python ground_truth.py --subset-size 10000
|
||||||
|
|
@ -42,43 +42,43 @@ ground-truth: seed
|
||||||
|
|
||||||
# --- Quick smoke test ---
|
# --- Quick smoke test ---
|
||||||
bench-smoke: seed
|
bench-smoke: seed
|
||||||
$(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \
|
$(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \
|
||||||
"brute-float:type=baseline,variant=float" \
|
"brute-float:type=baseline,variant=float" \
|
||||||
"ivf-quick:type=ivf,nlist=16,nprobe=4" \
|
"ivf-quick:type=ivf,nlist=16,nprobe=4" \
|
||||||
"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"
|
"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"
|
||||||
|
|
||||||
bench-rescore: seed
|
bench-rescore: seed
|
||||||
$(BENCH) --subset-size 10000 -k 10 -o runs/rescore \
|
$(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs \
|
||||||
$(RESCORE_CONFIGS)
|
$(RESCORE_CONFIGS)
|
||||||
|
|
||||||
|
|
||||||
# --- Standard sizes ---
|
# --- Standard sizes ---
|
||||||
bench-10k: seed
|
bench-10k: seed
|
||||||
$(BENCH) --subset-size 10000 -k 10 -o runs/10k $(ALL_CONFIGS)
|
$(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS)
|
||||||
|
|
||||||
bench-50k: seed
|
bench-50k: seed
|
||||||
$(BENCH) --subset-size 50000 -k 10 -o runs/50k $(ALL_CONFIGS)
|
$(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS)
|
||||||
|
|
||||||
bench-100k: seed
|
bench-100k: seed
|
||||||
$(BENCH) --subset-size 100000 -k 10 -o runs/100k $(ALL_CONFIGS)
|
$(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS)
|
||||||
|
|
||||||
bench-all: bench-10k bench-50k bench-100k
|
bench-all: bench-10k bench-50k bench-100k
|
||||||
|
|
||||||
# --- IVF across sizes ---
|
# --- IVF across sizes ---
|
||||||
bench-ivf: seed
|
bench-ivf: seed
|
||||||
$(BENCH) --subset-size 10000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
$(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS)
|
||||||
$(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
$(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS)
|
||||||
$(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
$(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS)
|
||||||
|
|
||||||
# --- DiskANN across sizes ---
|
# --- DiskANN across sizes ---
|
||||||
bench-diskann: seed
|
bench-diskann: seed
|
||||||
$(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
|
$(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS)
|
||||||
$(BENCH) --subset-size 50000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
|
$(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS)
|
||||||
$(BENCH) --subset-size 100000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
|
$(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS)
|
||||||
|
|
||||||
# --- Report ---
|
# --- Report ---
|
||||||
report:
|
report:
|
||||||
@echo "Use: sqlite3 runs/<dir>/results.db 'SELECT * FROM bench_results ORDER BY recall DESC'"
|
@echo "Use: sqlite3 runs/cohere1m/<size>/results.db 'SELECT run_id, config_name, status, recall FROM runs JOIN run_results USING(run_id)'"
|
||||||
|
|
||||||
# --- Cleanup ---
|
# --- Cleanup ---
|
||||||
clean:
|
clean:
|
||||||
|
|
|
||||||
|
|
@ -1,81 +1,111 @@
|
||||||
# KNN Benchmarks for sqlite-vec
|
# KNN Benchmarks for sqlite-vec
|
||||||
|
|
||||||
Benchmarking infrastructure for vec0 KNN configurations. Includes brute-force
|
Benchmarking infrastructure for vec0 KNN configurations. Includes brute-force
|
||||||
baselines (float, int8, bit); index-specific branches add their own types
|
baselines (float, int8, bit), rescore, IVF, and DiskANN index types.
|
||||||
via the `INDEX_REGISTRY` in `bench.py`.
|
|
||||||
|
## Datasets
|
||||||
|
|
||||||
|
Each dataset is a subdirectory containing a `Makefile` and `build_base_db.py`
|
||||||
|
that produce a `base.db`. The benchmark runner auto-discovers any subdirectory
|
||||||
|
with a `base.db` file.
|
||||||
|
|
||||||
|
```
|
||||||
|
cohere1m/ # Cohere 768d cosine, 1M vectors
|
||||||
|
Makefile # downloads parquets from Zilliz, builds base.db
|
||||||
|
build_base_db.py
|
||||||
|
base.db # (generated)
|
||||||
|
|
||||||
|
cohere10m/ # Cohere 768d cosine, 10M vectors (10 train shards)
|
||||||
|
Makefile # make -j12 download to fetch all shards in parallel
|
||||||
|
build_base_db.py
|
||||||
|
base.db # (generated)
|
||||||
|
```
|
||||||
|
|
||||||
|
Every `base.db` has the same schema:
|
||||||
|
|
||||||
|
| Table | Columns | Description |
|
||||||
|
|-------|---------|-------------|
|
||||||
|
| `train` | `id INTEGER PRIMARY KEY, vector BLOB` | Indexed vectors (f32 blobs) |
|
||||||
|
| `query_vectors` | `id INTEGER PRIMARY KEY, vector BLOB` | Query vectors for KNN evaluation |
|
||||||
|
| `neighbors` | `query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT` | Ground-truth nearest neighbors |
|
||||||
|
|
||||||
|
To add a new dataset, create a directory with a `Makefile` that builds `base.db`
|
||||||
|
with the tables above. It will be available via `--dataset <dirname>` automatically.
|
||||||
|
|
||||||
|
### Building datasets
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Cohere 1M
|
||||||
|
cd cohere1m && make download && make && cd ..
|
||||||
|
|
||||||
|
# Cohere 10M (parallel download recommended — 10 train shards + test + neighbors)
|
||||||
|
cd cohere10m && make -j12 download && make && cd ..
|
||||||
|
```
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
- Built `dist/vec0` extension (run `make` from repo root)
|
- Built `dist/vec0` extension (run `make loadable` from repo root)
|
||||||
- Python 3.10+
|
- Python 3.10+
|
||||||
- `uv` (for seed data prep): `pip install uv`
|
- `uv`
|
||||||
|
|
||||||
## Quick start
|
## Quick start
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 1. Download dataset and build seed DB (~3 GB download, ~5 min)
|
# 1. Build a dataset
|
||||||
make seed
|
cd cohere1m && make && cd ..
|
||||||
|
|
||||||
# 2. Run a quick smoke test (5k vectors, ~1 min)
|
# 2. Quick smoke test (5k vectors)
|
||||||
make bench-smoke
|
make bench-smoke
|
||||||
|
|
||||||
# 3. Run full benchmark at 10k
|
# 3. Full benchmark at 10k
|
||||||
make bench-10k
|
make bench-10k
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### Direct invocation
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python bench.py --subset-size 10000 \
|
uv run python bench.py --subset-size 10000 -k 10 -n 50 --dataset cohere1m \
|
||||||
"brute-float:type=baseline,variant=float" \
|
"brute-float:type=baseline,variant=float" \
|
||||||
"brute-int8:type=baseline,variant=int8" \
|
"rescore-bit-os8:type=rescore,quantizer=bit,oversample=8"
|
||||||
"brute-bit:type=baseline,variant=bit"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Config format
|
### Config format
|
||||||
|
|
||||||
`name:type=<index_type>,key=val,key=val`
|
`name:type=<index_type>,key=val,key=val`
|
||||||
|
|
||||||
| Index type | Keys | Branch |
|
| Index type | Keys |
|
||||||
|-----------|------|--------|
|
|-----------|------|
|
||||||
| `baseline` | `variant` (float/int8/bit), `oversample` | this branch |
|
| `baseline` | `variant` (float/int8/bit), `oversample` |
|
||||||
|
| `rescore` | `quantizer` (bit/int8), `oversample` |
|
||||||
Index branches register additional types in `INDEX_REGISTRY`. See the
|
| `ivf` | `nlist`, `nprobe` |
|
||||||
docstring in `bench.py` for the extension API.
|
| `diskann` | `R`, `L`, `quantizer` (binary/int8), `buffer_threshold` |
|
||||||
|
|
||||||
### Make targets
|
### Make targets
|
||||||
|
|
||||||
| Target | Description |
|
| Target | Description |
|
||||||
|--------|-------------|
|
|--------|-------------|
|
||||||
| `make seed` | Download COHERE 1M dataset |
|
| `make seed` | Download and build default dataset |
|
||||||
| `make ground-truth` | Pre-compute ground truth for 10k/50k/100k |
|
| `make bench-smoke` | Quick 5k test (3 configs) |
|
||||||
| `make bench-smoke` | Quick 5k baseline test |
|
|
||||||
| `make bench-10k` | All configs at 10k vectors |
|
| `make bench-10k` | All configs at 10k vectors |
|
||||||
| `make bench-50k` | All configs at 50k vectors |
|
| `make bench-50k` | All configs at 50k vectors |
|
||||||
| `make bench-100k` | All configs at 100k vectors |
|
| `make bench-100k` | All configs at 100k vectors |
|
||||||
| `make bench-all` | 10k + 50k + 100k |
|
| `make bench-all` | 10k + 50k + 100k |
|
||||||
|
| `make bench-ivf` | Baselines + IVF across 10k/50k/100k |
|
||||||
|
| `make bench-diskann` | Baselines + DiskANN across 10k/50k/100k |
|
||||||
|
|
||||||
|
## Results DB
|
||||||
|
|
||||||
|
Each run writes to `runs/<dataset>/<subset_size>/results.db` (SQLite, WAL mode).
|
||||||
|
Progress is written continuously — query from another terminal to monitor:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sqlite3 runs/cohere1m/10000/results.db "SELECT run_id, config_name, status FROM runs"
|
||||||
|
```
|
||||||
|
|
||||||
|
See `results_schema.sql` for the full schema (tables: `runs`, `run_results`,
|
||||||
|
`insert_batches`, `queries`).
|
||||||
|
|
||||||
## Adding an index type
|
## Adding an index type
|
||||||
|
|
||||||
In your index branch, add an entry to `INDEX_REGISTRY` in `bench.py` and
|
Add an entry to `INDEX_REGISTRY` in `bench.py` and append configs to
|
||||||
append your configs to `ALL_CONFIGS` in the `Makefile`. See the existing
|
`ALL_CONFIGS` in the `Makefile`. See existing entries for the pattern.
|
||||||
`baseline` entry and the comments in both files for the pattern.
|
|
||||||
|
|
||||||
## Results
|
|
||||||
|
|
||||||
Results are stored in `runs/<dir>/results.db` using the schema in `schema.sql`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
sqlite3 runs/10k/results.db "
|
|
||||||
SELECT config_name, recall, mean_ms, qps
|
|
||||||
FROM bench_results
|
|
||||||
ORDER BY recall DESC
|
|
||||||
"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dataset
|
|
||||||
|
|
||||||
[Zilliz COHERE Medium 1M](https://zilliz.com/learn/datasets-for-vector-database-benchmarks):
|
|
||||||
768 dimensions, cosine distance, 1M train vectors + 10k query vectors with precomputed neighbors.
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
27
benchmarks-ann/datasets/cohere10m/Makefile
Normal file
27
benchmarks-ann/datasets/cohere10m/Makefile
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
BASE_URL = https://assets.zilliz.com/benchmark/cohere_large_10m
|
||||||
|
|
||||||
|
TRAIN_PARQUETS = $(shell printf 'train-%02d-of-10.parquet ' 0 1 2 3 4 5 6 7 8 9)
|
||||||
|
OTHER_PARQUETS = test.parquet neighbors.parquet
|
||||||
|
PARQUETS = $(TRAIN_PARQUETS) $(OTHER_PARQUETS)
|
||||||
|
|
||||||
|
.PHONY: all download clean
|
||||||
|
|
||||||
|
all: base.db
|
||||||
|
|
||||||
|
# Use: make -j12 download
|
||||||
|
download: $(PARQUETS)
|
||||||
|
|
||||||
|
train-%-of-10.parquet:
|
||||||
|
curl -L -o $@ $(BASE_URL)/$@
|
||||||
|
|
||||||
|
test.parquet:
|
||||||
|
curl -L -o $@ $(BASE_URL)/test.parquet
|
||||||
|
|
||||||
|
neighbors.parquet:
|
||||||
|
curl -L -o $@ $(BASE_URL)/neighbors.parquet
|
||||||
|
|
||||||
|
base.db: $(PARQUETS) build_base_db.py
|
||||||
|
uv run --with pandas --with pyarrow python build_base_db.py
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f base.db
|
||||||
134
benchmarks-ann/datasets/cohere10m/build_base_db.py
Normal file
134
benchmarks-ann/datasets/cohere10m/build_base_db.py
Normal file
|
|
@ -0,0 +1,134 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Build base.db from downloaded parquet files (10M dataset, 10 train shards).
|
||||||
|
|
||||||
|
Reads train-00-of-10.parquet .. train-09-of-10.parquet, test.parquet,
|
||||||
|
neighbors.parquet and creates a SQLite database with tables:
|
||||||
|
train, query_vectors, neighbors.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run --with pandas --with pyarrow python build_base_db.py
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
TRAIN_SHARDS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def float_list_to_blob(floats):
|
||||||
|
"""Pack a list of floats into a little-endian f32 blob."""
|
||||||
|
return struct.pack(f"<{len(floats)}f", *floats)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
db_path = os.path.join(script_dir, "base.db")
|
||||||
|
|
||||||
|
train_paths = [
|
||||||
|
os.path.join(script_dir, f"train-{i:02d}-of-{TRAIN_SHARDS}.parquet")
|
||||||
|
for i in range(TRAIN_SHARDS)
|
||||||
|
]
|
||||||
|
test_path = os.path.join(script_dir, "test.parquet")
|
||||||
|
neighbors_path = os.path.join(script_dir, "neighbors.parquet")
|
||||||
|
|
||||||
|
for p in train_paths + [test_path, neighbors_path]:
|
||||||
|
if not os.path.exists(p):
|
||||||
|
print(f"ERROR: {p} not found. Run 'make download' first.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if os.path.exists(db_path):
|
||||||
|
os.remove(db_path)
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA page_size=4096")
|
||||||
|
|
||||||
|
# --- query_vectors (from test.parquet) ---
|
||||||
|
print("Loading test.parquet (query vectors)...")
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
df_test = pd.read_parquet(test_path)
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE query_vectors (id INTEGER PRIMARY KEY, vector BLOB)"
|
||||||
|
)
|
||||||
|
rows = []
|
||||||
|
for _, row in df_test.iterrows():
|
||||||
|
rows.append((int(row["id"]), float_list_to_blob(row["emb"])))
|
||||||
|
conn.executemany("INSERT INTO query_vectors (id, vector) VALUES (?, ?)", rows)
|
||||||
|
conn.commit()
|
||||||
|
print(f" {len(rows)} query vectors in {time.perf_counter() - t0:.1f}s")
|
||||||
|
|
||||||
|
# --- neighbors (from neighbors.parquet) ---
|
||||||
|
print("Loading neighbors.parquet...")
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
df_neighbors = pd.read_parquet(neighbors_path)
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE neighbors ("
|
||||||
|
" query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
|
||||||
|
" UNIQUE(query_vector_id, rank))"
|
||||||
|
)
|
||||||
|
rows = []
|
||||||
|
for _, row in df_neighbors.iterrows():
|
||||||
|
qid = int(row["id"])
|
||||||
|
nids = row["neighbors_id"]
|
||||||
|
if isinstance(nids, str):
|
||||||
|
nids = json.loads(nids)
|
||||||
|
for rank, nid in enumerate(nids):
|
||||||
|
rows.append((qid, rank, str(int(nid))))
|
||||||
|
conn.executemany(
|
||||||
|
"INSERT INTO neighbors (query_vector_id, rank, neighbors_id) VALUES (?, ?, ?)",
|
||||||
|
rows,
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
print(f" {len(rows)} neighbor rows in {time.perf_counter() - t0:.1f}s")
|
||||||
|
|
||||||
|
# --- train (from 10 shard parquets) ---
|
||||||
|
print(f"Loading {TRAIN_SHARDS} train shards (10M vectors, this will take a while)...")
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE train (id INTEGER PRIMARY KEY, vector BLOB)"
|
||||||
|
)
|
||||||
|
|
||||||
|
global_t0 = time.perf_counter()
|
||||||
|
total_inserted = 0
|
||||||
|
batch_size = 10000
|
||||||
|
|
||||||
|
for shard_idx, train_path in enumerate(train_paths):
|
||||||
|
print(f" Shard {shard_idx + 1}/{TRAIN_SHARDS}: {os.path.basename(train_path)}")
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
df = pd.read_parquet(train_path)
|
||||||
|
shard_len = len(df)
|
||||||
|
|
||||||
|
for start in range(0, shard_len, batch_size):
|
||||||
|
chunk = df.iloc[start : start + batch_size]
|
||||||
|
rows = []
|
||||||
|
for _, row in chunk.iterrows():
|
||||||
|
rows.append((int(row["id"]), float_list_to_blob(row["emb"])))
|
||||||
|
conn.executemany("INSERT INTO train (id, vector) VALUES (?, ?)", rows)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
total_inserted += len(rows)
|
||||||
|
if total_inserted % 100000 < batch_size:
|
||||||
|
elapsed = time.perf_counter() - global_t0
|
||||||
|
rate = total_inserted / elapsed if elapsed > 0 else 0
|
||||||
|
print(
|
||||||
|
f" {total_inserted:>10} {elapsed:.0f}s {rate:.0f} rows/s",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
shard_elapsed = time.perf_counter() - t0
|
||||||
|
print(f" shard done: {shard_len} rows in {shard_elapsed:.1f}s")
|
||||||
|
|
||||||
|
elapsed = time.perf_counter() - global_t0
|
||||||
|
print(f" {total_inserted} train vectors in {elapsed:.1f}s")
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
size_mb = os.path.getsize(db_path) / (1024 * 1024)
|
||||||
|
print(f"\nDone: {db_path} ({size_mb:.0f} MB)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
30
benchmarks-ann/datasets/nyt-1024/Makefile
Normal file
30
benchmarks-ann/datasets/nyt-1024/Makefile
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
MODEL ?= mixedbread-ai/mxbai-embed-large-v1
|
||||||
|
K ?= 100
|
||||||
|
BATCH_SIZE ?= 256
|
||||||
|
DATA_DIR ?= ../nyt/data
|
||||||
|
|
||||||
|
all: base.db
|
||||||
|
|
||||||
|
# Reuse data from ../nyt
|
||||||
|
$(DATA_DIR):
|
||||||
|
$(MAKE) -C ../nyt data
|
||||||
|
|
||||||
|
contents.db: $(DATA_DIR)
|
||||||
|
uv run ../nyt-768/build-contents.py --data-dir $(DATA_DIR) -o $@
|
||||||
|
|
||||||
|
base.db: contents.db queries.txt
|
||||||
|
uv run build-base.py \
|
||||||
|
--contents-db contents.db \
|
||||||
|
--model $(MODEL) \
|
||||||
|
--queries-file queries.txt \
|
||||||
|
--batch-size $(BATCH_SIZE) \
|
||||||
|
--k $(K) \
|
||||||
|
-o $@
|
||||||
|
|
||||||
|
queries.txt:
|
||||||
|
cp ../nyt/queries.txt $@
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f base.db contents.db
|
||||||
|
|
||||||
|
.PHONY: all clean
|
||||||
163
benchmarks-ann/datasets/nyt-1024/build-base.py
Normal file
163
benchmarks-ann/datasets/nyt-1024/build-base.py
Normal file
|
|
@ -0,0 +1,163 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "sentence-transformers",
|
||||||
|
# "torch<=2.7",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sqlite3
|
||||||
|
from array import array
|
||||||
|
from itertools import batched
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--contents-db", "-c", default=None,
|
||||||
|
help="Path to contents.db (source of headlines and IDs)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", "-m", default="mixedbread-ai/mxbai-embed-large-v1",
|
||||||
|
help="HuggingFace model ID (default: mixedbread-ai/mxbai-embed-large-v1)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries-file", "-q", default="queries.txt",
|
||||||
|
help="Path to the queries file (default: queries.txt)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o", required=True,
|
||||||
|
help="Path to the output base.db",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size", "-b", type=int, default=256,
|
||||||
|
help="Batch size for embedding (default: 256)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--k", "-k", type=int, default=100,
|
||||||
|
help="Number of nearest neighbors (default: 100)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--limit", "-l", type=int, default=0,
|
||||||
|
help="Limit number of headlines to embed (0 = all)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0",
|
||||||
|
help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-neighbors", action="store_true",
|
||||||
|
help="Skip the brute-force KNN neighbor computation",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
import os
|
||||||
|
vec_path = os.path.expanduser(args.vec_path)
|
||||||
|
|
||||||
|
print(f"Loading model {args.model}...")
|
||||||
|
model = SentenceTransformer(args.model)
|
||||||
|
|
||||||
|
# Read headlines from contents.db
|
||||||
|
src = sqlite3.connect(args.contents_db)
|
||||||
|
limit_clause = f" LIMIT {args.limit}" if args.limit > 0 else ""
|
||||||
|
headlines = src.execute(
|
||||||
|
f"SELECT id, headline FROM contents ORDER BY id{limit_clause}"
|
||||||
|
).fetchall()
|
||||||
|
src.close()
|
||||||
|
print(f"Loaded {len(headlines)} headlines from {args.contents_db}")
|
||||||
|
|
||||||
|
# Read queries
|
||||||
|
with open(args.queries_file) as f:
|
||||||
|
queries = [line.strip() for line in f if line.strip()]
|
||||||
|
print(f"Loaded {len(queries)} queries from {args.queries_file}")
|
||||||
|
|
||||||
|
# Create output database
|
||||||
|
db = sqlite3.connect(args.output)
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension(vec_path)
|
||||||
|
db.enable_load_extension(False)
|
||||||
|
|
||||||
|
db.execute("CREATE TABLE IF NOT EXISTS train(id INTEGER PRIMARY KEY, vector BLOB)")
|
||||||
|
db.execute("CREATE TABLE IF NOT EXISTS query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
|
||||||
|
db.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS neighbors("
|
||||||
|
" query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
|
||||||
|
" UNIQUE(query_vector_id, rank))"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 1: Embed headlines -> train table
|
||||||
|
print("Embedding headlines...")
|
||||||
|
for batch in tqdm(
|
||||||
|
batched(headlines, args.batch_size),
|
||||||
|
total=(len(headlines) + args.batch_size - 1) // args.batch_size,
|
||||||
|
):
|
||||||
|
ids = [r[0] for r in batch]
|
||||||
|
texts = [r[1] for r in batch]
|
||||||
|
embeddings = model.encode(texts, normalize_embeddings=True)
|
||||||
|
|
||||||
|
params = [
|
||||||
|
(int(rid), array("f", emb.tolist()).tobytes())
|
||||||
|
for rid, emb in zip(ids, embeddings)
|
||||||
|
]
|
||||||
|
db.executemany("INSERT INTO train VALUES (?, ?)", params)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
del headlines
|
||||||
|
n = db.execute("SELECT count(*) FROM train").fetchone()[0]
|
||||||
|
print(f"Embedded {n} headlines")
|
||||||
|
|
||||||
|
# Step 2: Embed queries -> query_vectors table
|
||||||
|
print("Embedding queries...")
|
||||||
|
query_embeddings = model.encode(queries, normalize_embeddings=True)
|
||||||
|
query_params = []
|
||||||
|
for i, emb in enumerate(query_embeddings, 1):
|
||||||
|
blob = array("f", emb.tolist()).tobytes()
|
||||||
|
query_params.append((i, blob))
|
||||||
|
db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params)
|
||||||
|
db.commit()
|
||||||
|
print(f"Embedded {len(queries)} queries")
|
||||||
|
|
||||||
|
if args.skip_neighbors:
|
||||||
|
db.close()
|
||||||
|
print(f"Done (skipped neighbors). Wrote {args.output}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 3: Brute-force KNN via sqlite-vec -> neighbors table
|
||||||
|
n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0]
|
||||||
|
print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...")
|
||||||
|
for query_id, query_blob in tqdm(
|
||||||
|
db.execute("SELECT id, vector FROM query_vectors").fetchall()
|
||||||
|
):
|
||||||
|
results = db.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
train.id,
|
||||||
|
vec_distance_cosine(train.vector, ?) AS distance
|
||||||
|
FROM train
|
||||||
|
WHERE distance IS NOT NULL
|
||||||
|
ORDER BY distance ASC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(query_blob, args.k),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
params = [
|
||||||
|
(query_id, rank, str(rid))
|
||||||
|
for rank, (rid, _dist) in enumerate(results)
|
||||||
|
]
|
||||||
|
db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
print(f"Done. Wrote {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
100
benchmarks-ann/datasets/nyt-1024/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-1024/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
latest news on climate change policy
|
||||||
|
presidential election results and analysis
|
||||||
|
stock market crash causes
|
||||||
|
coronavirus vaccine development updates
|
||||||
|
artificial intelligence breakthrough in healthcare
|
||||||
|
supreme court ruling on abortion rights
|
||||||
|
tech companies layoff announcements
|
||||||
|
earthquake damages in California
|
||||||
|
cybersecurity breach at major corporation
|
||||||
|
space exploration mission to Mars
|
||||||
|
immigration reform legislation debate
|
||||||
|
renewable energy investment trends
|
||||||
|
healthcare costs rising across America
|
||||||
|
protests against police brutality
|
||||||
|
wildfires destroy homes in the West
|
||||||
|
Olympic games highlights and records
|
||||||
|
celebrity scandal rocks Hollywood
|
||||||
|
breakthrough cancer treatment discovered
|
||||||
|
housing market bubble concerns
|
||||||
|
federal reserve interest rate decision
|
||||||
|
school shooting tragedy response
|
||||||
|
diplomatic tensions between superpowers
|
||||||
|
drone strike kills terrorist leader
|
||||||
|
social media platform faces regulation
|
||||||
|
archaeological discovery reveals ancient civilization
|
||||||
|
unemployment rate hits record low
|
||||||
|
autonomous vehicles testing expansion
|
||||||
|
streaming service launches original content
|
||||||
|
opioid crisis intervention programs
|
||||||
|
trade war tariffs impact economy
|
||||||
|
infrastructure bill passes Congress
|
||||||
|
data privacy concerns grow
|
||||||
|
minimum wage increase proposal
|
||||||
|
college admissions scandal exposed
|
||||||
|
NFL player protest during anthem
|
||||||
|
cryptocurrency regulation debate
|
||||||
|
pandemic lockdown restrictions eased
|
||||||
|
mass shooting gun control debate
|
||||||
|
tax reform legislation impact
|
||||||
|
ransomware attack cripples pipeline
|
||||||
|
climate activists stage demonstration
|
||||||
|
sports team wins championship
|
||||||
|
banking system collapse fears
|
||||||
|
pharmaceutical company fraud charges
|
||||||
|
genetic engineering ethical concerns
|
||||||
|
border wall funding controversy
|
||||||
|
impeachment proceedings begin
|
||||||
|
nuclear weapons treaty violation
|
||||||
|
artificial meat alternative launch
|
||||||
|
student loan debt forgiveness
|
||||||
|
venture capital funding decline
|
||||||
|
facial recognition ban proposed
|
||||||
|
election interference investigation
|
||||||
|
pandemic preparedness failures
|
||||||
|
police reform measures announced
|
||||||
|
wildfire prevention strategies
|
||||||
|
ocean pollution crisis worsens
|
||||||
|
manufacturing jobs returning
|
||||||
|
pension fund shortfall concerns
|
||||||
|
antitrust investigation launched
|
||||||
|
voting rights protection act
|
||||||
|
mental health awareness campaign
|
||||||
|
homeless population increasing
|
||||||
|
space debris collision risk
|
||||||
|
drug cartel violence escalates
|
||||||
|
renewable energy jobs growth
|
||||||
|
infrastructure deterioration report
|
||||||
|
vaccine mandate legal challenge
|
||||||
|
cryptocurrency market volatility
|
||||||
|
autonomous drone delivery service
|
||||||
|
deep fake technology dangers
|
||||||
|
Arctic ice melting accelerates
|
||||||
|
income inequality gap widens
|
||||||
|
election fraud claims disputed
|
||||||
|
corporate merger blocked
|
||||||
|
medical breakthrough extends life
|
||||||
|
transportation strike disrupts city
|
||||||
|
racial justice protests spread
|
||||||
|
carbon emissions reduction goals
|
||||||
|
financial crisis warning signs
|
||||||
|
cyberbullying prevention efforts
|
||||||
|
asteroid near miss with Earth
|
||||||
|
gene therapy approval granted
|
||||||
|
labor union organizing drive
|
||||||
|
surveillance technology expansion
|
||||||
|
education funding cuts proposed
|
||||||
|
disaster relief efforts underway
|
||||||
|
housing affordability crisis
|
||||||
|
clean water access shortage
|
||||||
|
artificial intelligence job displacement
|
||||||
|
trade agreement negotiations
|
||||||
|
prison reform initiative launched
|
||||||
|
species extinction accelerates
|
||||||
|
political corruption scandal
|
||||||
|
terrorism threat level raised
|
||||||
|
food safety contamination outbreak
|
||||||
|
ai model release
|
||||||
|
affordability interest rates
|
||||||
|
peanut allergies in newbons
|
||||||
|
breaking bad walter white
|
||||||
29
benchmarks-ann/datasets/nyt-384/Makefile
Normal file
29
benchmarks-ann/datasets/nyt-384/Makefile
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
MODEL ?= mixedbread-ai/mxbai-embed-xsmall-v1
|
||||||
|
K ?= 100
|
||||||
|
BATCH_SIZE ?= 512
|
||||||
|
DATA_DIR ?= ../nyt/data
|
||||||
|
|
||||||
|
all: base.db
|
||||||
|
|
||||||
|
$(DATA_DIR):
|
||||||
|
$(MAKE) -C ../nyt data
|
||||||
|
|
||||||
|
contents.db: $(DATA_DIR)
|
||||||
|
uv run ../nyt-768/build-contents.py --data-dir $(DATA_DIR) -o $@
|
||||||
|
|
||||||
|
base.db: contents.db queries.txt
|
||||||
|
uv run ../nyt-1024/build-base.py \
|
||||||
|
--contents-db contents.db \
|
||||||
|
--model $(MODEL) \
|
||||||
|
--queries-file queries.txt \
|
||||||
|
--batch-size $(BATCH_SIZE) \
|
||||||
|
--k $(K) \
|
||||||
|
-o $@
|
||||||
|
|
||||||
|
queries.txt:
|
||||||
|
cp ../nyt/queries.txt $@
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f base.db contents.db
|
||||||
|
|
||||||
|
.PHONY: all clean
|
||||||
100
benchmarks-ann/datasets/nyt-384/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-384/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
latest news on climate change policy
|
||||||
|
presidential election results and analysis
|
||||||
|
stock market crash causes
|
||||||
|
coronavirus vaccine development updates
|
||||||
|
artificial intelligence breakthrough in healthcare
|
||||||
|
supreme court ruling on abortion rights
|
||||||
|
tech companies layoff announcements
|
||||||
|
earthquake damages in California
|
||||||
|
cybersecurity breach at major corporation
|
||||||
|
space exploration mission to Mars
|
||||||
|
immigration reform legislation debate
|
||||||
|
renewable energy investment trends
|
||||||
|
healthcare costs rising across America
|
||||||
|
protests against police brutality
|
||||||
|
wildfires destroy homes in the West
|
||||||
|
Olympic games highlights and records
|
||||||
|
celebrity scandal rocks Hollywood
|
||||||
|
breakthrough cancer treatment discovered
|
||||||
|
housing market bubble concerns
|
||||||
|
federal reserve interest rate decision
|
||||||
|
school shooting tragedy response
|
||||||
|
diplomatic tensions between superpowers
|
||||||
|
drone strike kills terrorist leader
|
||||||
|
social media platform faces regulation
|
||||||
|
archaeological discovery reveals ancient civilization
|
||||||
|
unemployment rate hits record low
|
||||||
|
autonomous vehicles testing expansion
|
||||||
|
streaming service launches original content
|
||||||
|
opioid crisis intervention programs
|
||||||
|
trade war tariffs impact economy
|
||||||
|
infrastructure bill passes Congress
|
||||||
|
data privacy concerns grow
|
||||||
|
minimum wage increase proposal
|
||||||
|
college admissions scandal exposed
|
||||||
|
NFL player protest during anthem
|
||||||
|
cryptocurrency regulation debate
|
||||||
|
pandemic lockdown restrictions eased
|
||||||
|
mass shooting gun control debate
|
||||||
|
tax reform legislation impact
|
||||||
|
ransomware attack cripples pipeline
|
||||||
|
climate activists stage demonstration
|
||||||
|
sports team wins championship
|
||||||
|
banking system collapse fears
|
||||||
|
pharmaceutical company fraud charges
|
||||||
|
genetic engineering ethical concerns
|
||||||
|
border wall funding controversy
|
||||||
|
impeachment proceedings begin
|
||||||
|
nuclear weapons treaty violation
|
||||||
|
artificial meat alternative launch
|
||||||
|
student loan debt forgiveness
|
||||||
|
venture capital funding decline
|
||||||
|
facial recognition ban proposed
|
||||||
|
election interference investigation
|
||||||
|
pandemic preparedness failures
|
||||||
|
police reform measures announced
|
||||||
|
wildfire prevention strategies
|
||||||
|
ocean pollution crisis worsens
|
||||||
|
manufacturing jobs returning
|
||||||
|
pension fund shortfall concerns
|
||||||
|
antitrust investigation launched
|
||||||
|
voting rights protection act
|
||||||
|
mental health awareness campaign
|
||||||
|
homeless population increasing
|
||||||
|
space debris collision risk
|
||||||
|
drug cartel violence escalates
|
||||||
|
renewable energy jobs growth
|
||||||
|
infrastructure deterioration report
|
||||||
|
vaccine mandate legal challenge
|
||||||
|
cryptocurrency market volatility
|
||||||
|
autonomous drone delivery service
|
||||||
|
deep fake technology dangers
|
||||||
|
Arctic ice melting accelerates
|
||||||
|
income inequality gap widens
|
||||||
|
election fraud claims disputed
|
||||||
|
corporate merger blocked
|
||||||
|
medical breakthrough extends life
|
||||||
|
transportation strike disrupts city
|
||||||
|
racial justice protests spread
|
||||||
|
carbon emissions reduction goals
|
||||||
|
financial crisis warning signs
|
||||||
|
cyberbullying prevention efforts
|
||||||
|
asteroid near miss with Earth
|
||||||
|
gene therapy approval granted
|
||||||
|
labor union organizing drive
|
||||||
|
surveillance technology expansion
|
||||||
|
education funding cuts proposed
|
||||||
|
disaster relief efforts underway
|
||||||
|
housing affordability crisis
|
||||||
|
clean water access shortage
|
||||||
|
artificial intelligence job displacement
|
||||||
|
trade agreement negotiations
|
||||||
|
prison reform initiative launched
|
||||||
|
species extinction accelerates
|
||||||
|
political corruption scandal
|
||||||
|
terrorism threat level raised
|
||||||
|
food safety contamination outbreak
|
||||||
|
ai model release
|
||||||
|
affordability interest rates
|
||||||
|
peanut allergies in newbons
|
||||||
|
breaking bad walter white
|
||||||
37
benchmarks-ann/datasets/nyt-768/Makefile
Normal file
37
benchmarks-ann/datasets/nyt-768/Makefile
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
MODEL ?= bge-base-en-v1.5-768
|
||||||
|
K ?= 100
|
||||||
|
BATCH_SIZE ?= 512
|
||||||
|
DATA_DIR ?= ../nyt/data
|
||||||
|
|
||||||
|
all: base.db
|
||||||
|
|
||||||
|
# Reuse data from ../nyt
|
||||||
|
$(DATA_DIR):
|
||||||
|
$(MAKE) -C ../nyt data
|
||||||
|
|
||||||
|
# Distill model (separate step, may take a while)
|
||||||
|
$(MODEL):
|
||||||
|
uv run distill-model.py
|
||||||
|
|
||||||
|
contents.db: $(DATA_DIR)
|
||||||
|
uv run build-contents.py --data-dir $(DATA_DIR) -o $@
|
||||||
|
|
||||||
|
base.db: contents.db queries.txt $(MODEL)
|
||||||
|
uv run ../nyt/build-base.py \
|
||||||
|
--contents-db contents.db \
|
||||||
|
--model $(MODEL) \
|
||||||
|
--queries-file queries.txt \
|
||||||
|
--batch-size $(BATCH_SIZE) \
|
||||||
|
--k $(K) \
|
||||||
|
-o $@
|
||||||
|
|
||||||
|
queries.txt:
|
||||||
|
cp ../nyt/queries.txt $@
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f base.db contents.db
|
||||||
|
|
||||||
|
clean-all: clean
|
||||||
|
rm -rf $(MODEL)
|
||||||
|
|
||||||
|
.PHONY: all clean clean-all
|
||||||
64
benchmarks-ann/datasets/nyt-768/build-contents.py
Normal file
64
benchmarks-ann/datasets/nyt-768/build-contents.py
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "duckdb",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sqlite3
|
||||||
|
import duckdb
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Load NYT headline CSVs into a SQLite contents database (most recent 1M, deduplicated)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-dir", "-d", default="../nyt/data",
|
||||||
|
help="Directory containing NYT CSV files (default: ../nyt/data)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--limit", "-l", type=int, default=1_000_000,
|
||||||
|
help="Maximum number of headlines to keep (default: 1000000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o", required=True,
|
||||||
|
help="Path to the output SQLite database",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
glob_pattern = f"{args.data_dir}/new_york_times_stories_*.csv"
|
||||||
|
|
||||||
|
con = duckdb.connect()
|
||||||
|
rows = con.execute(
|
||||||
|
f"""
|
||||||
|
WITH deduped AS (
|
||||||
|
SELECT
|
||||||
|
headline,
|
||||||
|
max(pub_date) AS pub_date
|
||||||
|
FROM read_csv('{glob_pattern}', auto_detect=true, union_by_name=true)
|
||||||
|
WHERE headline IS NOT NULL AND trim(headline) != ''
|
||||||
|
GROUP BY headline
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
row_number() OVER (ORDER BY pub_date DESC) AS id,
|
||||||
|
headline
|
||||||
|
FROM deduped
|
||||||
|
ORDER BY pub_date DESC
|
||||||
|
LIMIT {args.limit}
|
||||||
|
"""
|
||||||
|
).fetchall()
|
||||||
|
con.close()
|
||||||
|
|
||||||
|
db = sqlite3.connect(args.output)
|
||||||
|
db.execute("CREATE TABLE contents(id INTEGER PRIMARY KEY, headline TEXT)")
|
||||||
|
db.executemany("INSERT INTO contents VALUES (?, ?)", rows)
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
print(f"Wrote {len(rows)} headlines to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
13
benchmarks-ann/datasets/nyt-768/distill-model.py
Normal file
13
benchmarks-ann/datasets/nyt-768/distill-model.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "model2vec[distill]",
|
||||||
|
# "torch<=2.7",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
from model2vec.distill import distill
|
||||||
|
|
||||||
|
model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=768)
|
||||||
|
model.save_pretrained("bge-base-en-v1.5-768")
|
||||||
|
print("Saved distilled model to bge-base-en-v1.5-768/")
|
||||||
100
benchmarks-ann/datasets/nyt-768/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-768/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
latest news on climate change policy
|
||||||
|
presidential election results and analysis
|
||||||
|
stock market crash causes
|
||||||
|
coronavirus vaccine development updates
|
||||||
|
artificial intelligence breakthrough in healthcare
|
||||||
|
supreme court ruling on abortion rights
|
||||||
|
tech companies layoff announcements
|
||||||
|
earthquake damages in California
|
||||||
|
cybersecurity breach at major corporation
|
||||||
|
space exploration mission to Mars
|
||||||
|
immigration reform legislation debate
|
||||||
|
renewable energy investment trends
|
||||||
|
healthcare costs rising across America
|
||||||
|
protests against police brutality
|
||||||
|
wildfires destroy homes in the West
|
||||||
|
Olympic games highlights and records
|
||||||
|
celebrity scandal rocks Hollywood
|
||||||
|
breakthrough cancer treatment discovered
|
||||||
|
housing market bubble concerns
|
||||||
|
federal reserve interest rate decision
|
||||||
|
school shooting tragedy response
|
||||||
|
diplomatic tensions between superpowers
|
||||||
|
drone strike kills terrorist leader
|
||||||
|
social media platform faces regulation
|
||||||
|
archaeological discovery reveals ancient civilization
|
||||||
|
unemployment rate hits record low
|
||||||
|
autonomous vehicles testing expansion
|
||||||
|
streaming service launches original content
|
||||||
|
opioid crisis intervention programs
|
||||||
|
trade war tariffs impact economy
|
||||||
|
infrastructure bill passes Congress
|
||||||
|
data privacy concerns grow
|
||||||
|
minimum wage increase proposal
|
||||||
|
college admissions scandal exposed
|
||||||
|
NFL player protest during anthem
|
||||||
|
cryptocurrency regulation debate
|
||||||
|
pandemic lockdown restrictions eased
|
||||||
|
mass shooting gun control debate
|
||||||
|
tax reform legislation impact
|
||||||
|
ransomware attack cripples pipeline
|
||||||
|
climate activists stage demonstration
|
||||||
|
sports team wins championship
|
||||||
|
banking system collapse fears
|
||||||
|
pharmaceutical company fraud charges
|
||||||
|
genetic engineering ethical concerns
|
||||||
|
border wall funding controversy
|
||||||
|
impeachment proceedings begin
|
||||||
|
nuclear weapons treaty violation
|
||||||
|
artificial meat alternative launch
|
||||||
|
student loan debt forgiveness
|
||||||
|
venture capital funding decline
|
||||||
|
facial recognition ban proposed
|
||||||
|
election interference investigation
|
||||||
|
pandemic preparedness failures
|
||||||
|
police reform measures announced
|
||||||
|
wildfire prevention strategies
|
||||||
|
ocean pollution crisis worsens
|
||||||
|
manufacturing jobs returning
|
||||||
|
pension fund shortfall concerns
|
||||||
|
antitrust investigation launched
|
||||||
|
voting rights protection act
|
||||||
|
mental health awareness campaign
|
||||||
|
homeless population increasing
|
||||||
|
space debris collision risk
|
||||||
|
drug cartel violence escalates
|
||||||
|
renewable energy jobs growth
|
||||||
|
infrastructure deterioration report
|
||||||
|
vaccine mandate legal challenge
|
||||||
|
cryptocurrency market volatility
|
||||||
|
autonomous drone delivery service
|
||||||
|
deep fake technology dangers
|
||||||
|
Arctic ice melting accelerates
|
||||||
|
income inequality gap widens
|
||||||
|
election fraud claims disputed
|
||||||
|
corporate merger blocked
|
||||||
|
medical breakthrough extends life
|
||||||
|
transportation strike disrupts city
|
||||||
|
racial justice protests spread
|
||||||
|
carbon emissions reduction goals
|
||||||
|
financial crisis warning signs
|
||||||
|
cyberbullying prevention efforts
|
||||||
|
asteroid near miss with Earth
|
||||||
|
gene therapy approval granted
|
||||||
|
labor union organizing drive
|
||||||
|
surveillance technology expansion
|
||||||
|
education funding cuts proposed
|
||||||
|
disaster relief efforts underway
|
||||||
|
housing affordability crisis
|
||||||
|
clean water access shortage
|
||||||
|
artificial intelligence job displacement
|
||||||
|
trade agreement negotiations
|
||||||
|
prison reform initiative launched
|
||||||
|
species extinction accelerates
|
||||||
|
political corruption scandal
|
||||||
|
terrorism threat level raised
|
||||||
|
food safety contamination outbreak
|
||||||
|
ai model release
|
||||||
|
affordability interest rates
|
||||||
|
peanut allergies in newbons
|
||||||
|
breaking bad walter white
|
||||||
1
benchmarks-ann/datasets/nyt/.gitignore
vendored
Normal file
1
benchmarks-ann/datasets/nyt/.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
data/
|
||||||
30
benchmarks-ann/datasets/nyt/Makefile
Normal file
30
benchmarks-ann/datasets/nyt/Makefile
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
MODEL ?= minishlab/potion-base-8M
|
||||||
|
K ?= 100
|
||||||
|
BATCH_SIZE ?= 512
|
||||||
|
DATA_DIR ?= data
|
||||||
|
|
||||||
|
all: base.db contents.db
|
||||||
|
|
||||||
|
# Download NYT headlines CSVs from Kaggle (requires `kaggle` CLI + API token)
|
||||||
|
$(DATA_DIR):
|
||||||
|
kaggle datasets download -d johnbandy/new-york-times-headlines -p $(DATA_DIR) --unzip
|
||||||
|
|
||||||
|
contents.db: $(DATA_DIR)
|
||||||
|
uv run build-contents.py --data-dir $(DATA_DIR) -o $@
|
||||||
|
|
||||||
|
base.db: contents.db queries.txt
|
||||||
|
uv run build-base.py \
|
||||||
|
--contents-db contents.db \
|
||||||
|
--model $(MODEL) \
|
||||||
|
--queries-file queries.txt \
|
||||||
|
--batch-size $(BATCH_SIZE) \
|
||||||
|
--k $(K) \
|
||||||
|
-o $@
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f base.db contents.db
|
||||||
|
|
||||||
|
clean-all: clean
|
||||||
|
rm -rf $(DATA_DIR)
|
||||||
|
|
||||||
|
.PHONY: all clean clean-all
|
||||||
165
benchmarks-ann/datasets/nyt/build-base.py
Normal file
165
benchmarks-ann/datasets/nyt/build-base.py
Normal file
|
|
@ -0,0 +1,165 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "model2vec",
|
||||||
|
# "torch<=2.7",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sqlite3
|
||||||
|
from array import array
|
||||||
|
from itertools import batched
|
||||||
|
|
||||||
|
from model2vec import StaticModel
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--contents-db", "-c", default=None,
|
||||||
|
help="Path to contents.db (source of headlines and IDs)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", "-m", default="minishlab/potion-base-8M",
|
||||||
|
help="HuggingFace model ID or local path (default: minishlab/potion-base-8M)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries-file", "-q", default="queries.txt",
|
||||||
|
help="Path to the queries file (default: queries.txt)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o", required=True,
|
||||||
|
help="Path to the output base.db",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size", "-b", type=int, default=512,
|
||||||
|
help="Batch size for embedding (default: 512)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--k", "-k", type=int, default=100,
|
||||||
|
help="Number of nearest neighbors (default: 100)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0",
|
||||||
|
help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rebuild-neighbors", action="store_true",
|
||||||
|
help="Only rebuild the neighbors table (skip embedding steps)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
import os
|
||||||
|
vec_path = os.path.expanduser(args.vec_path)
|
||||||
|
|
||||||
|
if args.rebuild_neighbors:
|
||||||
|
# Skip embedding, just open existing DB and rebuild neighbors
|
||||||
|
db = sqlite3.connect(args.output)
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension(vec_path)
|
||||||
|
db.enable_load_extension(False)
|
||||||
|
db.execute("DROP TABLE IF EXISTS neighbors")
|
||||||
|
db.execute(
|
||||||
|
"CREATE TABLE neighbors("
|
||||||
|
" query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
|
||||||
|
" UNIQUE(query_vector_id, rank))"
|
||||||
|
)
|
||||||
|
print(f"Rebuilding neighbors in {args.output}...")
|
||||||
|
else:
|
||||||
|
print(f"Loading model {args.model}...")
|
||||||
|
model = StaticModel.from_pretrained(args.model)
|
||||||
|
|
||||||
|
# Read headlines from contents.db
|
||||||
|
src = sqlite3.connect(args.contents_db)
|
||||||
|
headlines = src.execute("SELECT id, headline FROM contents ORDER BY id").fetchall()
|
||||||
|
src.close()
|
||||||
|
print(f"Loaded {len(headlines)} headlines from {args.contents_db}")
|
||||||
|
|
||||||
|
# Read queries
|
||||||
|
with open(args.queries_file) as f:
|
||||||
|
queries = [line.strip() for line in f if line.strip()]
|
||||||
|
print(f"Loaded {len(queries)} queries from {args.queries_file}")
|
||||||
|
|
||||||
|
# Create output database
|
||||||
|
db = sqlite3.connect(args.output)
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension(vec_path)
|
||||||
|
db.enable_load_extension(False)
|
||||||
|
|
||||||
|
db.execute("CREATE TABLE train(id INTEGER PRIMARY KEY, vector BLOB)")
|
||||||
|
db.execute("CREATE TABLE query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
|
||||||
|
db.execute(
|
||||||
|
"CREATE TABLE neighbors("
|
||||||
|
" query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
|
||||||
|
" UNIQUE(query_vector_id, rank))"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 1: Embed headlines -> train table
|
||||||
|
print("Embedding headlines...")
|
||||||
|
for batch in tqdm(
|
||||||
|
batched(headlines, args.batch_size),
|
||||||
|
total=(len(headlines) + args.batch_size - 1) // args.batch_size,
|
||||||
|
):
|
||||||
|
ids = [r[0] for r in batch]
|
||||||
|
texts = [r[1] for r in batch]
|
||||||
|
embeddings = model.encode(texts)
|
||||||
|
|
||||||
|
params = [
|
||||||
|
(int(rid), array("f", emb.tolist()).tobytes())
|
||||||
|
for rid, emb in zip(ids, embeddings)
|
||||||
|
]
|
||||||
|
db.executemany("INSERT INTO train VALUES (?, ?)", params)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
del headlines
|
||||||
|
n = db.execute("SELECT count(*) FROM train").fetchone()[0]
|
||||||
|
print(f"Embedded {n} headlines")
|
||||||
|
|
||||||
|
# Step 2: Embed queries -> query_vectors table
|
||||||
|
print("Embedding queries...")
|
||||||
|
query_embeddings = model.encode(queries)
|
||||||
|
query_params = []
|
||||||
|
for i, emb in enumerate(query_embeddings, 1):
|
||||||
|
blob = array("f", emb.tolist()).tobytes()
|
||||||
|
query_params.append((i, blob))
|
||||||
|
db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params)
|
||||||
|
db.commit()
|
||||||
|
print(f"Embedded {len(queries)} queries")
|
||||||
|
|
||||||
|
# Step 3: Brute-force KNN via sqlite-vec -> neighbors table
|
||||||
|
n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0]
|
||||||
|
print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...")
|
||||||
|
for query_id, query_blob in tqdm(
|
||||||
|
db.execute("SELECT id, vector FROM query_vectors").fetchall()
|
||||||
|
):
|
||||||
|
results = db.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
train.id,
|
||||||
|
vec_distance_cosine(train.vector, ?) AS distance
|
||||||
|
FROM train
|
||||||
|
WHERE distance IS NOT NULL
|
||||||
|
ORDER BY distance ASC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(query_blob, args.k),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
params = [
|
||||||
|
(query_id, rank, str(rid))
|
||||||
|
for rank, (rid, _dist) in enumerate(results)
|
||||||
|
]
|
||||||
|
db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
print(f"Done. Wrote {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
52
benchmarks-ann/datasets/nyt/build-contents.py
Normal file
52
benchmarks-ann/datasets/nyt/build-contents.py
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "duckdb",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import duckdb
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Load NYT headline CSVs into a SQLite contents database via DuckDB",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-dir", "-d", default="data",
|
||||||
|
help="Directory containing NYT CSV files (default: data)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o", required=True,
|
||||||
|
help="Path to the output SQLite database",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
glob_pattern = os.path.join(args.data_dir, "new_york_times_stories_*.csv")
|
||||||
|
|
||||||
|
con = duckdb.connect()
|
||||||
|
rows = con.execute(
|
||||||
|
f"""
|
||||||
|
SELECT
|
||||||
|
row_number() OVER () AS id,
|
||||||
|
headline
|
||||||
|
FROM read_csv('{glob_pattern}', auto_detect=true, union_by_name=true)
|
||||||
|
WHERE headline IS NOT NULL AND headline != ''
|
||||||
|
"""
|
||||||
|
).fetchall()
|
||||||
|
con.close()
|
||||||
|
|
||||||
|
db = sqlite3.connect(args.output)
|
||||||
|
db.execute("CREATE TABLE contents(id INTEGER PRIMARY KEY, headline TEXT)")
|
||||||
|
db.executemany("INSERT INTO contents VALUES (?, ?)", rows)
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
print(f"Wrote {len(rows)} headlines to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
100
benchmarks-ann/datasets/nyt/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
latest news on climate change policy
|
||||||
|
presidential election results and analysis
|
||||||
|
stock market crash causes
|
||||||
|
coronavirus vaccine development updates
|
||||||
|
artificial intelligence breakthrough in healthcare
|
||||||
|
supreme court ruling on abortion rights
|
||||||
|
tech companies layoff announcements
|
||||||
|
earthquake damages in California
|
||||||
|
cybersecurity breach at major corporation
|
||||||
|
space exploration mission to Mars
|
||||||
|
immigration reform legislation debate
|
||||||
|
renewable energy investment trends
|
||||||
|
healthcare costs rising across America
|
||||||
|
protests against police brutality
|
||||||
|
wildfires destroy homes in the West
|
||||||
|
Olympic games highlights and records
|
||||||
|
celebrity scandal rocks Hollywood
|
||||||
|
breakthrough cancer treatment discovered
|
||||||
|
housing market bubble concerns
|
||||||
|
federal reserve interest rate decision
|
||||||
|
school shooting tragedy response
|
||||||
|
diplomatic tensions between superpowers
|
||||||
|
drone strike kills terrorist leader
|
||||||
|
social media platform faces regulation
|
||||||
|
archaeological discovery reveals ancient civilization
|
||||||
|
unemployment rate hits record low
|
||||||
|
autonomous vehicles testing expansion
|
||||||
|
streaming service launches original content
|
||||||
|
opioid crisis intervention programs
|
||||||
|
trade war tariffs impact economy
|
||||||
|
infrastructure bill passes Congress
|
||||||
|
data privacy concerns grow
|
||||||
|
minimum wage increase proposal
|
||||||
|
college admissions scandal exposed
|
||||||
|
NFL player protest during anthem
|
||||||
|
cryptocurrency regulation debate
|
||||||
|
pandemic lockdown restrictions eased
|
||||||
|
mass shooting gun control debate
|
||||||
|
tax reform legislation impact
|
||||||
|
ransomware attack cripples pipeline
|
||||||
|
climate activists stage demonstration
|
||||||
|
sports team wins championship
|
||||||
|
banking system collapse fears
|
||||||
|
pharmaceutical company fraud charges
|
||||||
|
genetic engineering ethical concerns
|
||||||
|
border wall funding controversy
|
||||||
|
impeachment proceedings begin
|
||||||
|
nuclear weapons treaty violation
|
||||||
|
artificial meat alternative launch
|
||||||
|
student loan debt forgiveness
|
||||||
|
venture capital funding decline
|
||||||
|
facial recognition ban proposed
|
||||||
|
election interference investigation
|
||||||
|
pandemic preparedness failures
|
||||||
|
police reform measures announced
|
||||||
|
wildfire prevention strategies
|
||||||
|
ocean pollution crisis worsens
|
||||||
|
manufacturing jobs returning
|
||||||
|
pension fund shortfall concerns
|
||||||
|
antitrust investigation launched
|
||||||
|
voting rights protection act
|
||||||
|
mental health awareness campaign
|
||||||
|
homeless population increasing
|
||||||
|
space debris collision risk
|
||||||
|
drug cartel violence escalates
|
||||||
|
renewable energy jobs growth
|
||||||
|
infrastructure deterioration report
|
||||||
|
vaccine mandate legal challenge
|
||||||
|
cryptocurrency market volatility
|
||||||
|
autonomous drone delivery service
|
||||||
|
deep fake technology dangers
|
||||||
|
Arctic ice melting accelerates
|
||||||
|
income inequality gap widens
|
||||||
|
election fraud claims disputed
|
||||||
|
corporate merger blocked
|
||||||
|
medical breakthrough extends life
|
||||||
|
transportation strike disrupts city
|
||||||
|
racial justice protests spread
|
||||||
|
carbon emissions reduction goals
|
||||||
|
financial crisis warning signs
|
||||||
|
cyberbullying prevention efforts
|
||||||
|
asteroid near miss with Earth
|
||||||
|
gene therapy approval granted
|
||||||
|
labor union organizing drive
|
||||||
|
surveillance technology expansion
|
||||||
|
education funding cuts proposed
|
||||||
|
disaster relief efforts underway
|
||||||
|
housing affordability crisis
|
||||||
|
clean water access shortage
|
||||||
|
artificial intelligence job displacement
|
||||||
|
trade agreement negotiations
|
||||||
|
prison reform initiative launched
|
||||||
|
species extinction accelerates
|
||||||
|
political corruption scandal
|
||||||
|
terrorism threat level raised
|
||||||
|
food safety contamination outbreak
|
||||||
|
ai model release
|
||||||
|
affordability interest rates
|
||||||
|
peanut allergies in newbons
|
||||||
|
breaking bad walter white
|
||||||
101
benchmarks-ann/faiss_kmeans.py
Normal file
101
benchmarks-ann/faiss_kmeans.py
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Compute k-means centroids using FAISS and save to a centroids DB.
|
||||||
|
|
||||||
|
Reads the first N vectors from a base.db, runs FAISS k-means, and writes
|
||||||
|
the centroids to an output SQLite DB as float32 blobs.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python faiss_kmeans.py --base-db datasets/cohere10m/base.db --ntrain 100000 \
|
||||||
|
--nclusters 8192 -o centroids.db
|
||||||
|
|
||||||
|
Output schema:
|
||||||
|
CREATE TABLE centroids (
|
||||||
|
centroid_id INTEGER PRIMARY KEY,
|
||||||
|
centroid BLOB NOT NULL -- float32[D]
|
||||||
|
);
|
||||||
|
CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT);
|
||||||
|
-- ntrain, nclusters, dimensions, elapsed_s
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="FAISS k-means centroid computation")
|
||||||
|
parser.add_argument("--base-db", required=True, help="path to base.db with train table")
|
||||||
|
parser.add_argument("--ntrain", type=int, required=True, help="number of vectors to train on")
|
||||||
|
parser.add_argument("--nclusters", type=int, required=True, help="number of clusters (nlist)")
|
||||||
|
parser.add_argument("--niter", type=int, default=20, help="k-means iterations (default 20)")
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="random seed")
|
||||||
|
parser.add_argument("-o", "--output", required=True, help="output centroids DB path")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load vectors
|
||||||
|
print(f"Loading {args.ntrain} vectors from {args.base_db}...")
|
||||||
|
conn = sqlite3.connect(args.base_db)
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT vector FROM train ORDER BY id LIMIT ?", (args.ntrain,)
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Parse float32 blobs to numpy
|
||||||
|
first_blob = rows[0][0]
|
||||||
|
D = len(first_blob) // 4 # float32
|
||||||
|
print(f" Dimensions: {D}, loaded {len(rows)} vectors")
|
||||||
|
|
||||||
|
vectors = np.zeros((len(rows), D), dtype=np.float32)
|
||||||
|
for i, (blob,) in enumerate(rows):
|
||||||
|
vectors[i] = np.frombuffer(blob, dtype=np.float32)
|
||||||
|
|
||||||
|
# Normalize for cosine distance (FAISS k-means on L2 of unit vectors ≈ cosine)
|
||||||
|
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
vectors /= norms
|
||||||
|
|
||||||
|
# Run FAISS k-means
|
||||||
|
print(f"Running k-means: {args.nclusters} clusters, {args.niter} iterations...")
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
kmeans = faiss.Kmeans(
|
||||||
|
D, args.nclusters,
|
||||||
|
niter=args.niter,
|
||||||
|
seed=args.seed,
|
||||||
|
verbose=True,
|
||||||
|
gpu=False,
|
||||||
|
)
|
||||||
|
kmeans.train(vectors)
|
||||||
|
elapsed = time.perf_counter() - t0
|
||||||
|
print(f" Done in {elapsed:.1f}s")
|
||||||
|
|
||||||
|
centroids = kmeans.centroids # (nclusters, D) float32
|
||||||
|
|
||||||
|
# Write output DB
|
||||||
|
if os.path.exists(args.output):
|
||||||
|
os.remove(args.output)
|
||||||
|
out = sqlite3.connect(args.output)
|
||||||
|
out.execute("CREATE TABLE centroids (centroid_id INTEGER PRIMARY KEY, centroid BLOB NOT NULL)")
|
||||||
|
out.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)")
|
||||||
|
|
||||||
|
for i in range(args.nclusters):
|
||||||
|
blob = centroids[i].tobytes()
|
||||||
|
out.execute("INSERT INTO centroids (centroid_id, centroid) VALUES (?, ?)", (i, blob))
|
||||||
|
|
||||||
|
out.execute("INSERT INTO meta VALUES ('ntrain', ?)", (str(args.ntrain),))
|
||||||
|
out.execute("INSERT INTO meta VALUES ('nclusters', ?)", (str(args.nclusters),))
|
||||||
|
out.execute("INSERT INTO meta VALUES ('dimensions', ?)", (str(D),))
|
||||||
|
out.execute("INSERT INTO meta VALUES ('niter', ?)", (str(args.niter),))
|
||||||
|
out.execute("INSERT INTO meta VALUES ('elapsed_s', ?)", (str(round(elapsed, 3)),))
|
||||||
|
out.execute("INSERT INTO meta VALUES ('seed', ?)", (str(args.seed),))
|
||||||
|
out.commit()
|
||||||
|
out.close()
|
||||||
|
|
||||||
|
print(f"Wrote {args.nclusters} centroids to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
76
benchmarks-ann/results_schema.sql
Normal file
76
benchmarks-ann/results_schema.sql
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
-- Comprehensive results schema for vec0 KNN benchmark runs.
-- Created in WAL mode: PRAGMA journal_mode=WAL
-- NOTE(review): REFERENCES clauses below are only enforced when the writing
-- connection sets PRAGMA foreign_keys=ON — confirm the bench harness does.

-- One row per benchmark run: a single (config, dataset, subset, k) combination.
-- status is updated in place as the run progresses, so it can be polled live.
CREATE TABLE IF NOT EXISTS runs (
  run_id INTEGER PRIMARY KEY AUTOINCREMENT,
  config_name TEXT NOT NULL,            -- benchmark config label, e.g. "brute-float"
  index_type TEXT NOT NULL,             -- index family (flat, rescore, IVF, DiskANN, ...)
  params TEXT NOT NULL,                 -- JSON: {"R":48,"L":128,"quantizer":"binary"}
  dataset TEXT NOT NULL,                -- "cohere1m"
  subset_size INTEGER NOT NULL,         -- number of base vectors indexed for this run
  k INTEGER NOT NULL,                   -- neighbors requested per query
  n_queries INTEGER NOT NULL,           -- number of query vectors executed
  phase TEXT NOT NULL DEFAULT 'both',
  -- 'build', 'query', or 'both'
  status TEXT NOT NULL DEFAULT 'pending',
  -- pending → inserting → training → querying → done | built | error
  created_at_ns INTEGER NOT NULL        -- time.time_ns()
);
|
||||||
|
|
||||||
|
-- One-to-one companion of runs: build/query timings, produced artifacts, the
-- exact SQL used at each step, and denormalized query-latency aggregates.
-- All *_ns columns are time.time_ns() wall-clock values; *_ms are milliseconds.
CREATE TABLE IF NOT EXISTS run_results (
  run_id INTEGER PRIMARY KEY REFERENCES runs(run_id),
  insert_started_ns INTEGER,
  insert_ended_ns INTEGER,
  insert_duration_ns INTEGER,
  train_started_ns INTEGER,   -- NULL if no training
  train_ended_ns INTEGER,
  train_duration_ns INTEGER,
  build_duration_ns INTEGER,  -- insert + train
  db_file_size_bytes INTEGER, -- size of the built index database file
  db_file_path TEXT,          -- where the built index database was written
  create_sql TEXT,            -- CREATE VIRTUAL TABLE ...
  insert_sql TEXT,            -- INSERT INTO vec_items ...
  train_sql TEXT,             -- NULL if no training step
  query_sql TEXT,             -- SELECT ... WHERE embedding MATCH ...
  k INTEGER,                  -- denormalized from runs for easy filtering
  query_mean_ms REAL,         -- denormalized aggregates
  query_median_ms REAL,
  query_p99_ms REAL,
  query_total_ms REAL,
  qps REAL,                   -- queries per second over the whole query phase
  recall REAL                 -- mean recall@k across all queries in the run
);
|
||||||
|
|
||||||
|
-- Per-batch insert timings, written continuously during the build phase so
-- insert throughput can be observed while a run is still in progress.
CREATE TABLE IF NOT EXISTS insert_batches (
  batch_id INTEGER PRIMARY KEY AUTOINCREMENT,
  run_id INTEGER NOT NULL REFERENCES runs(run_id),
  batch_lo INTEGER NOT NULL,        -- start index (inclusive)
  batch_hi INTEGER NOT NULL,        -- end index (exclusive)
  rows_in_batch INTEGER NOT NULL,   -- batch_hi - batch_lo for a full batch
  started_ns INTEGER NOT NULL,
  ended_ns INTEGER NOT NULL,
  duration_ns INTEGER NOT NULL,
  cumulative_rows INTEGER NOT NULL, -- total rows inserted so far
  rate_rows_per_s REAL NOT NULL     -- cumulative rate
);
|
||||||
|
|
||||||
|
-- Per-query detail: individual latency plus the full result set, the ground
-- truth it was scored against, and the resulting recall@k.
-- NOTE(review): the UNIQUE constraint presumably lets re-runs upsert rather
-- than duplicate rows — confirm against the writer in bench.py.
CREATE TABLE IF NOT EXISTS queries (
  query_id INTEGER PRIMARY KEY AUTOINCREMENT,
  run_id INTEGER NOT NULL REFERENCES runs(run_id),
  k INTEGER NOT NULL,
  query_vector_id INTEGER NOT NULL, -- id of the query vector in the dataset
  started_ns INTEGER NOT NULL,
  ended_ns INTEGER NOT NULL,
  duration_ms REAL NOT NULL,
  result_ids TEXT NOT NULL,         -- JSON array
  result_distances TEXT NOT NULL,   -- JSON array
  ground_truth_ids TEXT NOT NULL,   -- JSON array
  recall REAL NOT NULL,
  UNIQUE(run_id, k, query_vector_id)
);
|
||||||
|
|
||||||
|
-- Lookup indexes for the common access patterns: filtering runs by config,
-- index type, or status, and fetching all batches/queries for a single run.
CREATE INDEX IF NOT EXISTS idx_runs_config ON runs(config_name);
CREATE INDEX IF NOT EXISTS idx_runs_type ON runs(index_type);
CREATE INDEX IF NOT EXISTS idx_runs_status ON runs(status);
CREATE INDEX IF NOT EXISTS idx_batches_run ON insert_batches(run_id);
CREATE INDEX IF NOT EXISTS idx_queries_run ON queries(run_id);
|
||||||
Loading…
Add table
Add a link
Reference in a new issue