mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
441 lines
15 KiB
Python
441 lines
15 KiB
Python
|
|
#!/usr/bin/env python3
"""CPU profiling for sqlite-vec KNN configurations using the macOS `sample` tool.

Builds dist/sqlite3 (with -g3), generates a SQL workload (bulk inserts followed
by repeated KNN queries) for each config, profiles the sqlite3 process with
`sample`, and prints the top-N hottest functions by self (exclusive) CPU
samples.

Usage:
    cd benchmarks-ann
    uv run profile.py --subset-size 50000 -n 50 \\
        "baseline-int8:type=baseline,variant=int8,oversample=8" \\
        "rescore-int8:type=rescore,quantizer=int8,oversample=8"
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import os
|
||
|
|
import re
|
||
|
|
import shutil
|
||
|
|
import subprocess
|
||
|
|
import sys
|
||
|
|
import tempfile
|
||
|
|
|
||
|
|
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||
|
|
_PROJECT_ROOT = os.path.join(_SCRIPT_DIR, "..")
|
||
|
|
|
||
|
|
sys.path.insert(0, _SCRIPT_DIR)
|
||
|
|
from bench import (
|
||
|
|
BASE_DB,
|
||
|
|
DEFAULT_INSERT_SQL,
|
||
|
|
INDEX_REGISTRY,
|
||
|
|
INSERT_BATCH_SIZE,
|
||
|
|
parse_config,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Build artifacts live under <project root>/dist: the debug sqlite3 shell that
# gets profiled, and the vec0 loadable extension that the workload `.load`s.
SQLITE3_PATH, EXT_PATH = (
    os.path.join(_PROJECT_ROOT, "dist", "sqlite3"),
    os.path.join(_PROJECT_ROOT, "dist", "vec0"),
)
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# SQL generation
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def _query_sql_for_config(params, query_id, k):
|
||
|
|
"""Return a SQL query string for a single KNN query by query_vector id."""
|
||
|
|
index_type = params["index_type"]
|
||
|
|
qvec = f"(SELECT vector FROM base.query_vectors WHERE id = {query_id})"
|
||
|
|
|
||
|
|
if index_type == "baseline":
|
||
|
|
variant = params.get("variant", "float")
|
||
|
|
oversample = params.get("oversample", 8)
|
||
|
|
oversample_k = k * oversample
|
||
|
|
|
||
|
|
if variant == "int8":
|
||
|
|
return (
|
||
|
|
f"WITH coarse AS ("
|
||
|
|
f" SELECT id, embedding FROM vec_items"
|
||
|
|
f" WHERE embedding_int8 MATCH vec_quantize_int8({qvec}, 'unit')"
|
||
|
|
f" LIMIT {oversample_k}"
|
||
|
|
f") "
|
||
|
|
f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance "
|
||
|
|
f"FROM coarse ORDER BY 2 LIMIT {k};"
|
||
|
|
)
|
||
|
|
elif variant == "bit":
|
||
|
|
return (
|
||
|
|
f"WITH coarse AS ("
|
||
|
|
f" SELECT id, embedding FROM vec_items"
|
||
|
|
f" WHERE embedding_bq MATCH vec_quantize_binary({qvec})"
|
||
|
|
f" LIMIT {oversample_k}"
|
||
|
|
f") "
|
||
|
|
f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance "
|
||
|
|
f"FROM coarse ORDER BY 2 LIMIT {k};"
|
||
|
|
)
|
||
|
|
|
||
|
|
# Default MATCH query (baseline-float, rescore, and others)
|
||
|
|
return (
|
||
|
|
f"SELECT id, distance FROM vec_items"
|
||
|
|
f" WHERE embedding MATCH {qvec} AND k = {k};"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def generate_sql(db_path, params, subset_size, n_queries, k, repeats):
    """Generate a complete SQL workload: load ext, create table, insert, query.

    The returned text is a script for the sqlite3 shell. It:
    1. loads the vec0 extension from EXT_PATH;
    2. attaches BASE_DB as schema ``base``;
    3. creates the index table described by ``params``;
    4. inserts rows 0..subset_size in batches of INSERT_BATCH_SIZE;
    5. runs ``n_queries`` KNN queries (query ids ``0..n_queries-1``),
       repeating the whole set ``repeats`` times.

    Args:
        db_path: Unused here; kept so existing callers keep working.
        params: Parsed config. ``params["index_type"]`` selects the
            INDEX_REGISTRY entry that builds the CREATE/INSERT statements.
        subset_size: Number of base rows to insert.
        n_queries: Number of distinct query vectors per pass.
        k: Neighbors requested by each KNN query.
        repeats: Number of passes over the query set.

    Returns:
        The script as newline-separated text, ending with a newline.
    """
    reg = INDEX_REGISTRY[params["index_type"]]

    # Double-quote the .load argument so a path containing spaces survives the
    # sqlite3 shell's whitespace tokenizer; inside double quotes the shell
    # interprets backslash escapes, so escape '\' and '"' first.
    ext_arg = EXT_PATH.replace("\\", "\\\\").replace('"', '\\"')
    # SQL string literal: an embedded single quote must be doubled.
    base_db_literal = os.path.abspath(BASE_DB).replace("'", "''")

    lines = [
        ".bail on",
        f'.load "{ext_arg}"',
        f"ATTACH DATABASE '{base_db_literal}' AS base;",
        "PRAGMA page_size=8192;",
        # Create table
        reg["create_table_sql"](params) + ";",
    ]

    # Inserts: the template carries :lo / :hi placeholders. The script is
    # plain shell input (no parameter binding), so they are substituted as text.
    sql_fn = reg.get("insert_sql")
    insert_sql = sql_fn(params) if sql_fn else None
    if insert_sql is None:
        insert_sql = DEFAULT_INSERT_SQL
    for lo in range(0, subset_size, INSERT_BATCH_SIZE):
        hi = min(lo + INSERT_BATCH_SIZE, subset_size)
        stmt = insert_sql.replace(":lo", str(lo)).replace(":hi", str(hi))
        lines.append(stmt + ";")
        if hi % 10000 == 0 or hi == subset_size:
            lines.append("-- progress: inserted %d/%d" % (hi, subset_size))

    # Queries (repeated). Each query's text depends only on (params, qid, k),
    # so build one pass once and repeat it instead of regenerating it.
    lines.append("-- BEGIN QUERIES")
    one_pass = [_query_sql_for_config(params, qid, k) for qid in range(n_queries)]
    for _rep in range(repeats):
        lines.extend(one_pass)

    # Terminate the final line so the script is a well-formed text file.
    return "\n".join(lines) + "\n"
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Profiling with macOS `sample`
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def _terminate(proc):
    """Kill ``proc`` if it is still running, then reap it; ``None`` is ignored."""
    if proc is None:
        return
    if proc.poll() is None:
        proc.kill()
    proc.wait()
    if proc.stderr is not None:
        # Release our end of the stderr pipe (no-op if already closed).
        proc.stderr.close()


def run_profile(sqlite3_path, db_path, sql_file, sample_output, duration=120):
    """Run sqlite3 under macOS `sample` profiler.

    Starts sqlite3 on ``db_path`` with stdin read from ``sql_file``, then
    immediately attaches `sample` to its PID. Sampling uses a 1 ms interval
    for up to ``duration`` seconds, and ``-mayDie`` tolerates the process
    exiting early. The report is written to ``sample_output``. The workload
    must be long enough for sample to attach and capture useful data.

    Args:
        sqlite3_path: sqlite3 executable to run.
        db_path: Database file passed to sqlite3 on its command line.
        sql_file: Script fed to sqlite3 on stdin.
        sample_output: Path given to ``sample -file``.
        duration: Maximum sampling time in seconds.

    Returns:
        True if sqlite3 exited with status 0. False otherwise; in that case
        sqlite3's stderr is printed to stderr. Whether ``sample_output`` was
        actually written is left to the caller to check.

    On an exception (including KeyboardInterrupt) both child processes are
    killed before the exception propagates.
    """
    with open(sql_file, "r") as sql_fd:
        proc = subprocess.Popen(
            [sqlite3_path, db_path],
            stdin=sql_fd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
        )
    # The child holds its own copy of the stdin descriptor, so closing the
    # parent's handle as soon as Popen returns is safe (and avoids a leak if
    # Popen itself raised).

    pid = proc.pid
    print(f" sqlite3 PID: {pid}")

    sample_proc = None
    try:
        # Attach sample immediately (1ms interval, -mayDie tolerates process exit)
        sample_proc = subprocess.Popen(
            ["sample", str(pid), str(duration), "1", "-mayDie", "-file", sample_output],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
        )

        # Wait for sqlite3 to finish, collecting its error output.
        _, stderr = proc.communicate()
        rc = proc.returncode
        if rc != 0:
            err_text = stderr.decode(errors="replace").strip()
            print(f" sqlite3 failed (rc={rc}):", file=sys.stderr)
            print(f" {err_text}", file=sys.stderr)
            _terminate(sample_proc)
            return False

        # communicate() rather than wait(): it drains sample's stderr pipe, so
        # sample can never block on a full pipe buffer while we wait for it.
        sample_proc.communicate()
        return True
    except BaseException:
        # Ctrl-C or an unexpected error: don't leave either process running.
        _terminate(proc)
        _terminate(sample_proc)
        raise
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Parse `sample` output
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
# Tree-drawing characters used by macOS `sample` to represent hierarchy.
|
||
|
|
# We replace them with spaces so indentation depth reflects tree depth.
|
||
|
|
_TREE_CHARS_RE = re.compile(r"[+!:|]")
|
||
|
|
|
||
|
|
# After tree chars are replaced with spaces, each call-graph line looks like:
|
||
|
|
# " 800 rescore_knn (in vec0.dylib) + 3808,3640,... [0x1a,0x2b,...] file.c:123"
|
||
|
|
# We extract just (indent, count, symbol, module) — everything after "(in ...)"
|
||
|
|
# is decoration we don't need.
|
||
|
|
_LEADING_RE = re.compile(r"^(\s+)(\d+)\s+(.+)")
|
||
|
|
|
||
|
|
|
||
|
|
def _extract_symbol_and_module(rest):
|
||
|
|
"""Given the text after 'count ', extract (symbol, module).
|
||
|
|
|
||
|
|
Handles patterns like:
|
||
|
|
'rescore_knn (in vec0.dylib) + 3808,3640,... [0x...]'
|
||
|
|
'pread (in libsystem_kernel.dylib) + 8 [0x...]'
|
||
|
|
'??? (in <unknown binary>) [0x...]'
|
||
|
|
'start (in dyld) + 2840 [0x198650274]'
|
||
|
|
'Thread_26759239 DispatchQueue_1: ...'
|
||
|
|
"""
|
||
|
|
# Try to find "(in ...)" to split symbol from module
|
||
|
|
m = re.match(r"^(.+?)\s+\(in\s+(.+?)\)", rest)
|
||
|
|
if m:
|
||
|
|
return m.group(1).strip(), m.group(2).strip()
|
||
|
|
# No module — return whole thing as symbol, strip trailing junk
|
||
|
|
sym = re.sub(r"\s+\[0x[0-9a-f].*", "", rest).strip()
|
||
|
|
return sym, ""
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_call_graph_lines(text):
    """Parse call-graph section into list of (depth, count, symbol, module).

    Each line is expected to look like ``<tree prefix><count> <text>``, where
    the tree prefix is whitespace mixed with `sample`'s tree-drawing
    characters. ``depth`` is the width of that prefix, so deeper frames have
    larger depth. Lines that don't match are skipped.

    Args:
        text: The "Call graph:" section of a `sample` report.

    Returns:
        List of (depth, count, symbol, module) tuples in file order.
    """
    # Leading run of whitespace plus the tree characters matched by
    # _TREE_CHARS_RE; it stops at the first other character, the count's first digit.
    prefix_re = re.compile(r"[\s+!:|]*")

    entries = []
    for raw_line in text.split("\n"):
        # Blank out tree characters only inside the leading prefix. Replacing
        # them across the whole line would also corrupt symbol names that
        # contain ':' or '+' (C++ `ns::fn`, Objective-C `+[Cls sel:]`).
        # Replacement is one-for-one, so the prefix width (= depth) is unchanged.
        prefix_len = prefix_re.match(raw_line).end()
        line = _TREE_CHARS_RE.sub(" ", raw_line[:prefix_len]) + raw_line[prefix_len:]

        m = _LEADING_RE.match(line)
        if not m:
            continue
        depth = len(m.group(1))
        count = int(m.group(2))
        symbol, module = _extract_symbol_and_module(m.group(3))
        entries.append((depth, count, symbol, module))
    return entries
|
||
|
|
|
||
|
|
|
||
|
|
def parse_sample_output(filepath):
    """Parse `sample` call-graph output, compute exclusive (self) samples per function.

    Only the text from the "Call graph:" heading up to the
    "Total number in stack" summary (or end of file) is examined. A frame's
    self count is its sample count minus the counts of its direct children.
    Frames with a positive self count are summed per display name, so one
    function reached through several call paths is merged into one entry.

    Args:
        filepath: Path of the text report written by ``sample -file``.

    Returns:
        dict of {display_name: self_sample_count}. display_name is
        "symbol (module)", or just "symbol" when no module was parsed. The dict
        is empty, with a warning printed, if no call graph is found or parsed.
    """
    # Decode explicitly instead of relying on the locale default, and replace
    # undecodable bytes rather than aborting the run over one odd symbol name.
    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        text = f.read()

    # Find "Call graph:" section
    cg_start = text.find("Call graph:")
    if cg_start == -1:
        print(" Warning: no 'Call graph:' section found in sample output")
        return {}

    # End at "Total number in stack" or EOF
    cg_end = text.find("\nTotal number in stack", cg_start)
    if cg_end == -1:
        cg_end = len(text)

    entries = _parse_call_graph_lines(text[cg_start:cg_end])
    if not entries:
        print(" Warning: no call graph entries parsed")
        return {}

    # Compute self (exclusive) samples per function:
    #   self = count - sum(direct_children_counts)
    self_samples = {}
    n_entries = len(entries)
    for i, (depth, count, sym, mod) in enumerate(entries):
        # The subtree of entry i is every following entry that is deeper than
        # it. Direct children are those at the depth of the first such entry.
        children_sum = 0
        child_depth = None
        for j in range(i + 1, n_entries):
            j_depth = entries[j][0]
            if j_depth <= depth:
                break
            if child_depth is None:
                child_depth = j_depth
            if j_depth == child_depth:
                children_sum += entries[j][1]

        self_count = count - children_sum
        if self_count > 0:
            key = f"{sym} ({mod})" if mod else sym
            self_samples[key] = self_samples.get(key, 0) + self_count

    return self_samples
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Display
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def print_profile(title, self_samples, top_n=20):
    """Print the ``top_n`` functions of one profile, ranked by self samples.

    Each row shows the function's share of all self samples in the profile,
    its raw sample count, and its display name. Functions with equal counts
    keep their original order. An empty profile prints a single
    "(no samples)" header instead.
    """
    grand_total = sum(self_samples.values())
    if not grand_total:
        print(f"\n=== {title} (no samples) ===")
        return

    # reverse=True keeps ties in their original order, just like sorting on
    # the negated count would.
    ranked = sorted(self_samples.items(), key=lambda item: item[1], reverse=True)

    print(f"\n=== {title} (top {top_n}, {grand_total} total self-samples) ===")
    for name, hits in ranked[:top_n]:
        share = 100.0 * hits / grand_total
        print(f" {share:5.1f}% {hits:>6} {name}")
|
||
|
|
|
||
|
|
|
||
|
|
# ============================================================================
|
||
|
|
# Main
|
||
|
|
# ============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
def _print_comparison(all_profiles, top_n):
    """Print a side-by-side table of self-sample percentages across configs.

    Rows are the union of every config's top ``top_n`` symbols, ordered by
    their highest percentage in any config. Ties are broken by symbol name so
    the output is reproducible between runs. At most ``2 * top_n`` rows are
    shown; "-" marks a symbol with no samples in that config.

    Args:
        all_profiles: List of (name, desc, self_samples) tuples.
        top_n: Per-config cutoff used to pick the symbols to compare.
    """
    print("\n" + "=" * 80)
    print("COMPARISON")
    print("=" * 80)

    # Per-config totals do not depend on the symbol; compute them once.
    totals = [sum(prof.values()) for _name, _desc, prof in all_profiles]

    # Collect all symbols that appear in top-N of any config
    all_syms = set()
    for _name, _desc, prof in all_profiles:
        ranked = sorted(prof.items(), key=lambda x: -x[1])
        all_syms.update(sym for sym, _count in ranked[:top_n])

    # Build comparison table: one (pct, count) cell per config.
    rows = []
    for sym in all_syms:
        cells = []
        for (_name, _desc, prof), total in zip(all_profiles, totals):
            count = prof.get(sym, 0)
            pct = 100.0 * count / total if total > 0 else 0.0
            cells.append((pct, count))
        rows.append((max(pct for pct, _count in cells), sym, cells))
    # Set iteration order varies with string hash randomization, so use the
    # symbol as a tiebreaker to make equal-percentage rows print stably.
    rows.sort(key=lambda r: (-r[0], r[1]))

    # Header
    header = f"{'function':>40}"
    for name, _desc, _prof in all_profiles:
        header += f" {name:>14}"
    print(header)
    print("-" * len(header))

    for _max_pct, sym, cells in rows[: top_n * 2]:
        display_sym = sym if len(sym) <= 40 else sym[:37] + "..."
        line = f"{display_sym:>40}"
        for pct, count in cells:
            if count > 0:
                line += f" {pct:>13.1f}%"
            else:
                line += f" {'-':>14}"
        print(line)


def main():
    """Command-line entry point: build, profile each config, report hot spots.

    Parses the config specs and rebuilds dist/sqlite3 via ``make cli``. Then,
    for each config, it writes a SQL workload into a temporary directory, runs
    it under `sample`, and prints the hottest functions by self samples. With
    more than one config, a side-by-side comparison follows. The temporary
    directory is removed at the end, including when an error interrupts the
    run, unless ``--keep-temp`` is given.
    """
    # generate_sql() reads the module-level BASE_DB when it writes the ATTACH
    # statement. Rebinding it after argument parsing makes --base-db actually
    # take effect; otherwise the override would only be existence-checked and
    # the default database would still be attached.
    global BASE_DB

    parser = argparse.ArgumentParser(
        description="CPU profiling for sqlite-vec KNN configurations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "configs", nargs="+", help="config specs (name:type=X,key=val,...)"
    )
    parser.add_argument("--subset-size", type=int, required=True)
    parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
    parser.add_argument(
        "-n", type=int, default=50, help="number of distinct queries (default 50)"
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=10,
        help="repeat query set N times for more samples (default 10)",
    )
    parser.add_argument(
        "--top", type=int, default=20, help="show top N functions (default 20)"
    )
    parser.add_argument("--base-db", default=BASE_DB)
    parser.add_argument("--sqlite3", default=SQLITE3_PATH)
    parser.add_argument(
        "--keep-temp",
        action="store_true",
        help="keep temp directory with DBs, SQL, and sample output",
    )
    args = parser.parse_args()
    BASE_DB = args.base_db

    # Check prerequisites
    if not os.path.exists(args.base_db):
        print(f"Error: base DB not found at {args.base_db}", file=sys.stderr)
        print("Run 'make seed' in benchmarks-ann/ first.", file=sys.stderr)
        sys.exit(1)

    if not shutil.which("sample"):
        print("Error: macOS 'sample' tool not found.", file=sys.stderr)
        sys.exit(1)

    # Build CLI
    print("Building dist/sqlite3...")
    result = subprocess.run(
        ["make", "cli"], cwd=_PROJECT_ROOT, capture_output=True, text=True
    )
    if result.returncode != 0:
        print(f"Error: make cli failed:\n{result.stderr}", file=sys.stderr)
        sys.exit(1)
    print(" done.")

    if not os.path.exists(args.sqlite3):
        print(f"Error: sqlite3 not found at {args.sqlite3}", file=sys.stderr)
        sys.exit(1)

    configs = [parse_config(c) for c in args.configs]

    tmpdir = tempfile.mkdtemp(prefix="sqlite-vec-profile-")
    print(f"Working directory: {tmpdir}")

    # try/finally so the (potentially large) per-config databases are cleaned
    # up even if profiling is interrupted or fails unexpectedly.
    try:
        all_profiles = []

        for i, (name, params) in enumerate(configs, 1):
            reg = INDEX_REGISTRY[params["index_type"]]
            desc = reg["describe"](params)
            print(f"\n[{i}/{len(configs)}] {name} ({desc})")

            # Generate SQL workload
            db_path = os.path.join(tmpdir, f"{name}.db")
            sql_text = generate_sql(
                db_path, params, args.subset_size, args.n, args.k, args.repeats
            )
            sql_file = os.path.join(tmpdir, f"{name}.sql")
            with open(sql_file, "w") as f:
                f.write(sql_text)

            total_queries = args.n * args.repeats
            print(
                f" SQL workload: {args.subset_size} inserts + "
                f"{total_queries} queries ({args.n} x {args.repeats} repeats)"
            )

            # Profile
            sample_file = os.path.join(tmpdir, f"{name}.sample.txt")
            print(" Profiling...")
            ok = run_profile(args.sqlite3, db_path, sql_file, sample_file)
            if not ok:
                print(f" FAILED — skipping {name}")
                all_profiles.append((name, desc, {}))
                continue

            if not os.path.exists(sample_file):
                print(" Warning: sample output not created")
                all_profiles.append((name, desc, {}))
                continue

            # Parse
            self_samples = parse_sample_output(sample_file)
            all_profiles.append((name, desc, self_samples))

            # Show individual profile
            print_profile(f"{name} ({desc})", self_samples, args.top)

        # Side-by-side comparison if multiple configs
        if len(all_profiles) > 1:
            _print_comparison(all_profiles, args.top)
    finally:
        if args.keep_temp:
            print(f"\nTemp files kept at: {tmpdir}")
        else:
            # ignore_errors: a cleanup failure must not mask an in-flight error.
            shutil.rmtree(tmpdir, ignore_errors=True)
            print("\nTemp files cleaned up. Use --keep-temp to preserve.")


if __name__ == "__main__":
    main()
|