mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
Add DiskANN index for vec0 virtual table
Add DiskANN graph-based index: builds a Vamana graph with configurable R (max degree) and L (search list size, separate for insert/query), supports int8 quantization with rescore, lazy reverse-edge replacement, pre-quantized query optimization, and insert buffer reuse. Includes shadow table management, delete support, KNN integration, compile flag (SQLITE_VEC_ENABLE_DISKANN), release-demo workflow, fuzz targets, and tests. Fixes rescore int8 quantization bug.
This commit is contained in:
parent
e2c38f387c
commit
575371d751
23 changed files with 6550 additions and 135 deletions
|
|
@ -6,18 +6,16 @@ across different vec0 configurations.
|
|||
|
||||
Config format: name:type=<index_type>,key=val,key=val
|
||||
|
||||
Baseline (brute-force) keys:
|
||||
type=baseline, variant=float|int8|bit, oversample=8
|
||||
|
||||
Index-specific types can be registered via INDEX_REGISTRY (see below).
|
||||
Available types: none, vec0-flat, rescore, ivf, diskann
|
||||
|
||||
Usage:
|
||||
python bench.py --subset-size 10000 \
|
||||
"brute-float:type=baseline,variant=float" \
|
||||
"brute-int8:type=baseline,variant=int8" \
|
||||
"brute-bit:type=baseline,variant=bit"
|
||||
"raw:type=none" \
|
||||
"flat:type=vec0-flat,variant=float" \
|
||||
"flat-int8:type=vec0-flat,variant=int8"
|
||||
"""
|
||||
import argparse
|
||||
from datetime import datetime, timezone
|
||||
import os
|
||||
import sqlite3
|
||||
import statistics
|
||||
|
|
@ -56,11 +54,118 @@ INDEX_REGISTRY = {}
|
|||
|
||||
|
||||
# ============================================================================
|
||||
# Baseline implementation
|
||||
# "none" — regular table, no vec0, manual KNN via vec_distance_cosine()
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _baseline_create_table_sql(params):
|
||||
def _none_create_table_sql(params):
|
||||
variant = params["variant"]
|
||||
if variant == "int8":
|
||||
return (
|
||||
"CREATE TABLE vec_items ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" embedding BLOB NOT NULL,"
|
||||
" embedding_int8 BLOB NOT NULL)"
|
||||
)
|
||||
elif variant == "bit":
|
||||
return (
|
||||
"CREATE TABLE vec_items ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" embedding BLOB NOT NULL,"
|
||||
" embedding_bq BLOB NOT NULL)"
|
||||
)
|
||||
return (
|
||||
"CREATE TABLE vec_items ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" embedding BLOB NOT NULL)"
|
||||
)
|
||||
|
||||
|
||||
def _none_insert_sql(params):
|
||||
variant = params["variant"]
|
||||
if variant == "int8":
|
||||
return (
|
||||
"INSERT INTO vec_items(id, embedding, embedding_int8) "
|
||||
"SELECT id, vector, vec_quantize_int8(vector, 'unit') "
|
||||
"FROM base.train WHERE id >= :lo AND id < :hi"
|
||||
)
|
||||
elif variant == "bit":
|
||||
return (
|
||||
"INSERT INTO vec_items(id, embedding, embedding_bq) "
|
||||
"SELECT id, vector, vec_quantize_binary(vector) "
|
||||
"FROM base.train WHERE id >= :lo AND id < :hi"
|
||||
)
|
||||
return (
|
||||
"INSERT INTO vec_items(id, embedding) "
|
||||
"SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi"
|
||||
)
|
||||
|
||||
|
||||
def _none_run_query(conn, params, query, k):
|
||||
variant = params["variant"]
|
||||
oversample = params.get("oversample", 8)
|
||||
|
||||
if variant == "int8":
|
||||
q_int8 = conn.execute(
|
||||
"SELECT vec_quantize_int8(:query, 'unit')", {"query": query}
|
||||
).fetchone()[0]
|
||||
return conn.execute(
|
||||
"WITH coarse AS ("
|
||||
" SELECT id, embedding FROM ("
|
||||
" SELECT id, embedding, vec_distance_cosine(vec_int8(:q_int8), vec_int8(embedding_int8)) as dist "
|
||||
" FROM vec_items ORDER BY dist LIMIT :oversample_k"
|
||||
" )"
|
||||
") "
|
||||
"SELECT id, vec_distance_cosine(:query, embedding) as distance "
|
||||
"FROM coarse ORDER BY 2 LIMIT :k",
|
||||
{"q_int8": q_int8, "query": query, "k": k, "oversample_k": k * oversample},
|
||||
).fetchall()
|
||||
elif variant == "bit":
|
||||
q_bit = conn.execute(
|
||||
"SELECT vec_quantize_binary(:query)", {"query": query}
|
||||
).fetchone()[0]
|
||||
return conn.execute(
|
||||
"WITH coarse AS ("
|
||||
" SELECT id, embedding FROM ("
|
||||
" SELECT id, embedding, vec_distance_hamming(vec_bit(:q_bit), vec_bit(embedding_bq)) as dist "
|
||||
" FROM vec_items ORDER BY dist LIMIT :oversample_k"
|
||||
" )"
|
||||
") "
|
||||
"SELECT id, vec_distance_cosine(:query, embedding) as distance "
|
||||
"FROM coarse ORDER BY 2 LIMIT :k",
|
||||
{"q_bit": q_bit, "query": query, "k": k, "oversample_k": k * oversample},
|
||||
).fetchall()
|
||||
|
||||
return conn.execute(
|
||||
"SELECT id, vec_distance_cosine(:query, embedding) as distance "
|
||||
"FROM vec_items ORDER BY 2 LIMIT :k",
|
||||
{"query": query, "k": k},
|
||||
).fetchall()
|
||||
|
||||
|
||||
def _none_describe(params):
|
||||
v = params["variant"]
|
||||
if v in ("int8", "bit"):
|
||||
return f"none {v} (os={params['oversample']})"
|
||||
return f"none float"
|
||||
|
||||
|
||||
INDEX_REGISTRY["none"] = {
|
||||
"defaults": {"variant": "float", "oversample": 8},
|
||||
"create_table_sql": _none_create_table_sql,
|
||||
"insert_sql": _none_insert_sql,
|
||||
"post_insert_hook": None,
|
||||
"run_query": _none_run_query,
|
||||
"describe": _none_describe,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# vec0-flat — vec0 virtual table with brute-force MATCH
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _vec0flat_create_table_sql(params):
|
||||
variant = params["variant"]
|
||||
extra = ""
|
||||
if variant == "int8":
|
||||
|
|
@ -76,7 +181,7 @@ def _baseline_create_table_sql(params):
|
|||
)
|
||||
|
||||
|
||||
def _baseline_insert_sql(params):
|
||||
def _vec0flat_insert_sql(params):
|
||||
variant = params["variant"]
|
||||
if variant == "int8":
|
||||
return (
|
||||
|
|
@ -93,7 +198,7 @@ def _baseline_insert_sql(params):
|
|||
return None # use default
|
||||
|
||||
|
||||
def _baseline_run_query(conn, params, query, k):
|
||||
def _vec0flat_run_query(conn, params, query, k):
|
||||
variant = params["variant"]
|
||||
oversample = params.get("oversample", 8)
|
||||
|
||||
|
|
@ -123,20 +228,20 @@ def _baseline_run_query(conn, params, query, k):
|
|||
return None # use default MATCH
|
||||
|
||||
|
||||
def _baseline_describe(params):
|
||||
def _vec0flat_describe(params):
|
||||
v = params["variant"]
|
||||
if v in ("int8", "bit"):
|
||||
return f"baseline {v} (os={params['oversample']})"
|
||||
return f"baseline {v}"
|
||||
return f"vec0-flat {v} (os={params['oversample']})"
|
||||
return f"vec0-flat {v}"
|
||||
|
||||
|
||||
INDEX_REGISTRY["baseline"] = {
|
||||
INDEX_REGISTRY["vec0-flat"] = {
|
||||
"defaults": {"variant": "float", "oversample": 8},
|
||||
"create_table_sql": _baseline_create_table_sql,
|
||||
"insert_sql": _baseline_insert_sql,
|
||||
"create_table_sql": _vec0flat_create_table_sql,
|
||||
"insert_sql": _vec0flat_insert_sql,
|
||||
"post_insert_hook": None,
|
||||
"run_query": _baseline_run_query,
|
||||
"describe": _baseline_describe,
|
||||
"run_query": _vec0flat_run_query,
|
||||
"describe": _vec0flat_describe,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -215,12 +320,64 @@ INDEX_REGISTRY["ivf"] = {
|
|||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DiskANN implementation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _diskann_create_table_sql(params):
|
||||
bt = params["buffer_threshold"]
|
||||
extra = f", buffer_threshold={bt}" if bt > 0 else ""
|
||||
return (
|
||||
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||
f" id integer primary key,"
|
||||
f" embedding float[768] distance_metric=cosine"
|
||||
f" INDEXED BY diskann("
|
||||
f" neighbor_quantizer={params['quantizer']},"
|
||||
f" n_neighbors={params['R']},"
|
||||
f" search_list_size={params['L']}"
|
||||
f" {extra}"
|
||||
f" )"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
def _diskann_pre_query_hook(conn, params):
|
||||
L_search = params.get("L_search")
|
||||
if L_search:
|
||||
conn.execute(
|
||||
"INSERT INTO vec_items(id) VALUES (?)",
|
||||
(f"search_list_size_search={L_search}",),
|
||||
)
|
||||
conn.commit()
|
||||
print(f" Set search_list_size_search={L_search}")
|
||||
|
||||
|
||||
def _diskann_describe(params):
|
||||
desc = f"diskann q={params['quantizer']:<6} R={params['R']:<3} L={params['L']}"
|
||||
L_search = params.get("L_search")
|
||||
if L_search:
|
||||
desc += f" L_search={L_search}"
|
||||
return desc
|
||||
|
||||
|
||||
INDEX_REGISTRY["diskann"] = {
|
||||
"defaults": {"R": 72, "L": 128, "quantizer": "binary", "buffer_threshold": 0},
|
||||
"create_table_sql": _diskann_create_table_sql,
|
||||
"insert_sql": None,
|
||||
"post_insert_hook": None,
|
||||
"pre_query_hook": _diskann_pre_query_hook,
|
||||
"run_query": None,
|
||||
"describe": _diskann_describe,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config parsing
|
||||
# ============================================================================
|
||||
|
||||
INT_KEYS = {
|
||||
"R", "L", "buffer_threshold", "nlist", "nprobe", "oversample",
|
||||
"R", "L", "L_search", "buffer_threshold", "nlist", "nprobe", "oversample",
|
||||
"n_trees", "search_k",
|
||||
}
|
||||
|
||||
|
|
@ -238,7 +395,7 @@ def parse_config(spec):
|
|||
k, v = kv.split("=", 1)
|
||||
raw[k.strip()] = v.strip()
|
||||
|
||||
index_type = raw.pop("type", "baseline")
|
||||
index_type = raw.pop("type", "vec0-flat")
|
||||
if index_type not in INDEX_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Unknown index type: {index_type}. "
|
||||
|
|
@ -289,7 +446,7 @@ def insert_loop(conn, sql, subset_size, label=""):
|
|||
return time.perf_counter() - t0
|
||||
|
||||
|
||||
def open_bench_db(db_path, ext_path, base_db):
|
||||
def create_bench_db(db_path, ext_path, base_db):
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
|
@ -300,6 +457,19 @@ def open_bench_db(db_path, ext_path, base_db):
|
|||
return conn
|
||||
|
||||
|
||||
def open_existing_bench_db(db_path, ext_path, base_db):
|
||||
if not os.path.exists(db_path):
|
||||
raise FileNotFoundError(
|
||||
f"Index DB not found: {db_path}\n"
|
||||
f"Build it first with: --phase build"
|
||||
)
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.enable_load_extension(True)
|
||||
conn.load_extension(ext_path)
|
||||
conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
|
||||
return conn
|
||||
|
||||
|
||||
DEFAULT_INSERT_SQL = (
|
||||
"INSERT INTO vec_items(id, embedding) "
|
||||
"SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi"
|
||||
|
|
@ -313,7 +483,7 @@ DEFAULT_INSERT_SQL = (
|
|||
|
||||
def build_index(base_db, ext_path, name, params, subset_size, out_dir):
|
||||
db_path = os.path.join(out_dir, f"{name}.{subset_size}.db")
|
||||
conn = open_bench_db(db_path, ext_path, base_db)
|
||||
conn = create_bench_db(db_path, ext_path, base_db)
|
||||
|
||||
reg = INDEX_REGISTRY[params["index_type"]]
|
||||
|
||||
|
|
@ -364,12 +534,16 @@ def _default_match_query(conn, query, k):
|
|||
).fetchall()
|
||||
|
||||
|
||||
def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50):
|
||||
def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50,
|
||||
pre_query_hook=None):
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.enable_load_extension(True)
|
||||
conn.load_extension(ext_path)
|
||||
conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
|
||||
|
||||
if pre_query_hook:
|
||||
pre_query_hook(conn, params)
|
||||
|
||||
query_vectors = load_query_vectors(base_db, n)
|
||||
|
||||
reg = INDEX_REGISTRY[params["index_type"]]
|
||||
|
|
@ -431,6 +605,34 @@ def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50):
|
|||
# ============================================================================
|
||||
|
||||
|
||||
def open_results_db(results_path):
|
||||
db = sqlite3.connect(results_path)
|
||||
db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read())
|
||||
# Migrate existing DBs that predate the runs table
|
||||
cols = {r[1] for r in db.execute("PRAGMA table_info(runs)").fetchall()}
|
||||
if "phase" not in cols:
|
||||
db.execute("ALTER TABLE runs ADD COLUMN phase TEXT NOT NULL DEFAULT 'both'")
|
||||
db.commit()
|
||||
return db
|
||||
|
||||
|
||||
def create_run(db, config_name, index_type, subset_size, phase, k=None, n=None):
|
||||
cur = db.execute(
|
||||
"INSERT INTO runs (config_name, index_type, subset_size, phase, status, k, n) "
|
||||
"VALUES (?, ?, ?, ?, 'pending', ?, ?)",
|
||||
(config_name, index_type, subset_size, phase, k, n),
|
||||
)
|
||||
db.commit()
|
||||
return cur.lastrowid
|
||||
|
||||
|
||||
def update_run(db, run_id, **kwargs):
|
||||
sets = ", ".join(f"{k} = ?" for k in kwargs)
|
||||
vals = list(kwargs.values()) + [run_id]
|
||||
db.execute(f"UPDATE runs SET {sets} WHERE run_id = ?", vals)
|
||||
db.commit()
|
||||
|
||||
|
||||
def save_results(results_path, rows):
|
||||
db = sqlite3.connect(results_path)
|
||||
db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read())
|
||||
|
|
@ -500,6 +702,8 @@ def main():
|
|||
parser.add_argument("--subset-size", type=int, required=True)
|
||||
parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
|
||||
parser.add_argument("-n", type=int, default=50, help="number of queries (default 50)")
|
||||
parser.add_argument("--phase", choices=["build", "query", "both"], default="both",
|
||||
help="build=build only, query=query existing index, both=default")
|
||||
parser.add_argument("--base-db", default=BASE_DB)
|
||||
parser.add_argument("--ext", default=EXT_PATH)
|
||||
parser.add_argument("-o", "--out-dir", default="runs")
|
||||
|
|
@ -508,55 +712,164 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
results_db = args.results_db or os.path.join(args.out_dir, "results.db")
|
||||
results_db_path = args.results_db or os.path.join(args.out_dir, "results.db")
|
||||
configs = [parse_config(c) for c in args.configs]
|
||||
results_db = open_results_db(results_db_path)
|
||||
|
||||
all_results = []
|
||||
for i, (name, params) in enumerate(configs, 1):
|
||||
reg = INDEX_REGISTRY[params["index_type"]]
|
||||
desc = reg["describe"](params)
|
||||
print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()})")
|
||||
print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()}) [phase={args.phase}]")
|
||||
|
||||
build = build_index(
|
||||
args.base_db, args.ext, name, params, args.subset_size, args.out_dir
|
||||
)
|
||||
train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
|
||||
print(
|
||||
f" Build: {build['insert_time_s']}s insert{train_str} "
|
||||
f"{build['file_size_mb']} MB"
|
||||
)
|
||||
db_path = os.path.join(args.out_dir, f"{name}.{args.subset_size}.db")
|
||||
|
||||
print(f" Measuring KNN (k={args.k}, n={args.n})...")
|
||||
knn = measure_knn(
|
||||
build["db_path"], args.ext, args.base_db,
|
||||
params, args.subset_size, k=args.k, n=args.n,
|
||||
)
|
||||
print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
|
||||
if args.phase == "build":
|
||||
run_id = create_run(results_db, name, params["index_type"],
|
||||
args.subset_size, "build")
|
||||
update_run(results_db, run_id, status="inserting")
|
||||
|
||||
all_results.append({
|
||||
"name": name,
|
||||
"n_vectors": args.subset_size,
|
||||
"index_type": params["index_type"],
|
||||
"config_desc": desc,
|
||||
"db_path": build["db_path"],
|
||||
"insert_time_s": build["insert_time_s"],
|
||||
"train_time_s": build["train_time_s"],
|
||||
"total_time_s": build["total_time_s"],
|
||||
"insert_per_vec_ms": build["insert_per_vec_ms"],
|
||||
"rows": build["rows"],
|
||||
"file_size_mb": build["file_size_mb"],
|
||||
"k": args.k,
|
||||
"n_queries": args.n,
|
||||
"mean_ms": knn["mean_ms"],
|
||||
"median_ms": knn["median_ms"],
|
||||
"p99_ms": knn["p99_ms"],
|
||||
"total_ms": knn["total_ms"],
|
||||
"recall": knn["recall"],
|
||||
})
|
||||
build = build_index(
|
||||
args.base_db, args.ext, name, params, args.subset_size, args.out_dir
|
||||
)
|
||||
train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
|
||||
print(
|
||||
f" Build: {build['insert_time_s']}s insert{train_str} "
|
||||
f"{build['file_size_mb']} MB"
|
||||
)
|
||||
update_run(results_db, run_id,
|
||||
status="built",
|
||||
db_path=build["db_path"],
|
||||
insert_time_s=build["insert_time_s"],
|
||||
train_time_s=build["train_time_s"],
|
||||
total_build_time_s=build["total_time_s"],
|
||||
rows=build["rows"],
|
||||
file_size_mb=build["file_size_mb"],
|
||||
finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
|
||||
print(f" Index DB: {build['db_path']}")
|
||||
|
||||
print_report(all_results)
|
||||
save_results(results_db, all_results)
|
||||
print(f"\nResults saved to {results_db}")
|
||||
elif args.phase == "query":
|
||||
if not os.path.exists(db_path):
|
||||
raise FileNotFoundError(
|
||||
f"Index DB not found: {db_path}\n"
|
||||
f"Build it first with: --phase build"
|
||||
)
|
||||
|
||||
run_id = create_run(results_db, name, params["index_type"],
|
||||
args.subset_size, "query", k=args.k, n=args.n)
|
||||
update_run(results_db, run_id, status="querying")
|
||||
|
||||
pre_hook = reg.get("pre_query_hook")
|
||||
print(f" Measuring KNN (k={args.k}, n={args.n})...")
|
||||
knn = measure_knn(
|
||||
db_path, args.ext, args.base_db,
|
||||
params, args.subset_size, k=args.k, n=args.n,
|
||||
pre_query_hook=pre_hook,
|
||||
)
|
||||
print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
|
||||
|
||||
qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0
|
||||
update_run(results_db, run_id,
|
||||
status="done",
|
||||
db_path=db_path,
|
||||
mean_ms=knn["mean_ms"],
|
||||
median_ms=knn["median_ms"],
|
||||
p99_ms=knn["p99_ms"],
|
||||
total_query_ms=knn["total_ms"],
|
||||
qps=qps,
|
||||
recall=knn["recall"],
|
||||
finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
file_size_mb = os.path.getsize(db_path) / (1024 * 1024)
|
||||
all_results.append({
|
||||
"name": name,
|
||||
"n_vectors": args.subset_size,
|
||||
"index_type": params["index_type"],
|
||||
"config_desc": desc,
|
||||
"db_path": db_path,
|
||||
"insert_time_s": 0,
|
||||
"train_time_s": 0,
|
||||
"total_time_s": 0,
|
||||
"insert_per_vec_ms": 0,
|
||||
"rows": 0,
|
||||
"file_size_mb": file_size_mb,
|
||||
"k": args.k,
|
||||
"n_queries": args.n,
|
||||
"mean_ms": knn["mean_ms"],
|
||||
"median_ms": knn["median_ms"],
|
||||
"p99_ms": knn["p99_ms"],
|
||||
"total_ms": knn["total_ms"],
|
||||
"recall": knn["recall"],
|
||||
})
|
||||
|
||||
else: # both
|
||||
run_id = create_run(results_db, name, params["index_type"],
|
||||
args.subset_size, "both", k=args.k, n=args.n)
|
||||
update_run(results_db, run_id, status="inserting")
|
||||
|
||||
build = build_index(
|
||||
args.base_db, args.ext, name, params, args.subset_size, args.out_dir
|
||||
)
|
||||
train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
|
||||
print(
|
||||
f" Build: {build['insert_time_s']}s insert{train_str} "
|
||||
f"{build['file_size_mb']} MB"
|
||||
)
|
||||
update_run(results_db, run_id, status="querying",
|
||||
db_path=build["db_path"],
|
||||
insert_time_s=build["insert_time_s"],
|
||||
train_time_s=build["train_time_s"],
|
||||
total_build_time_s=build["total_time_s"],
|
||||
rows=build["rows"],
|
||||
file_size_mb=build["file_size_mb"])
|
||||
|
||||
print(f" Measuring KNN (k={args.k}, n={args.n})...")
|
||||
knn = measure_knn(
|
||||
build["db_path"], args.ext, args.base_db,
|
||||
params, args.subset_size, k=args.k, n=args.n,
|
||||
)
|
||||
print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
|
||||
|
||||
qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0
|
||||
update_run(results_db, run_id,
|
||||
status="done",
|
||||
mean_ms=knn["mean_ms"],
|
||||
median_ms=knn["median_ms"],
|
||||
p99_ms=knn["p99_ms"],
|
||||
total_query_ms=knn["total_ms"],
|
||||
qps=qps,
|
||||
recall=knn["recall"],
|
||||
finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
all_results.append({
|
||||
"name": name,
|
||||
"n_vectors": args.subset_size,
|
||||
"index_type": params["index_type"],
|
||||
"config_desc": desc,
|
||||
"db_path": build["db_path"],
|
||||
"insert_time_s": build["insert_time_s"],
|
||||
"train_time_s": build["train_time_s"],
|
||||
"total_time_s": build["total_time_s"],
|
||||
"insert_per_vec_ms": build["insert_per_vec_ms"],
|
||||
"rows": build["rows"],
|
||||
"file_size_mb": build["file_size_mb"],
|
||||
"k": args.k,
|
||||
"n_queries": args.n,
|
||||
"mean_ms": knn["mean_ms"],
|
||||
"median_ms": knn["median_ms"],
|
||||
"p99_ms": knn["p99_ms"],
|
||||
"total_ms": knn["total_ms"],
|
||||
"recall": knn["recall"],
|
||||
})
|
||||
|
||||
results_db.close()
|
||||
|
||||
if all_results:
|
||||
print_report(all_results)
|
||||
save_results(results_db_path, all_results)
|
||||
print(f"\nResults saved to {results_db_path}")
|
||||
elif args.phase == "build":
|
||||
print(f"\nBuild complete. Results tracked in {results_db_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue