diff --git a/Makefile b/Makefile
index 1ebdbed..051590e 100644
--- a/Makefile
+++ b/Makefile
@@ -42,6 +42,11 @@ ifndef OMIT_SIMD
ifeq ($(shell uname -sm),Darwin arm64)
CFLAGS += -mcpu=apple-m1 -DSQLITE_VEC_ENABLE_NEON
endif
+ ifeq ($(shell uname -s),Linux)
+ ifneq ($(filter avx,$(shell grep -o 'avx[^ ]*' /proc/cpuinfo 2>/dev/null | head -1)),)
+ CFLAGS += -mavx -DSQLITE_VEC_ENABLE_AVX
+ endif
+ endif
endif
ifdef USE_BREW_SQLITE
@@ -155,6 +160,13 @@ clean:
rm -rf dist
+TARGET_AMALGAMATION=$(prefix)/sqlite-vec.c
+
+amalgamation: $(TARGET_AMALGAMATION)
+
+$(TARGET_AMALGAMATION): sqlite-vec.c $(wildcard sqlite-vec-*.c) scripts/amalgamate.py $(prefix)
+ python3 scripts/amalgamate.py sqlite-vec.c > $@
+
FORMAT_FILES=sqlite-vec.h sqlite-vec.c
format: $(FORMAT_FILES)
clang-format -i $(FORMAT_FILES)
@@ -174,7 +186,7 @@ evidence-of:
test:
sqlite3 :memory: '.read test.sql'
-.PHONY: version loadable static test clean gh-release evidence-of install uninstall
+.PHONY: version loadable static test clean gh-release evidence-of install uninstall amalgamation
publish-release:
./scripts/publish-release.sh
diff --git a/TODO.md b/TODO.md
new file mode 100644
index 0000000..4c3cc19
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1,73 @@
+# TODO: `ann` base branch + consolidated benchmarks
+
+## 1. Create `ann` branch with shared code
+
+### 1.1 Branch setup
+- [x] `git checkout -B ann origin/main`
+- [x] Cherry-pick `624f998` (vec0_distance_full shared distance dispatch)
+- [x] Cherry-pick stdint.h fix for test header
+- [ ] Pull NEON cosine optimization from ivf-yolo3 into shared code
+ - Currently only in ivf branch but is general-purpose (benefits all distance calcs)
+ - Lives in `distance_cosine_float()` — ~57 lines of ARM NEON vectorized cosine
+
+### 1.2 Benchmark infrastructure (`benchmarks-ann/`)
+- [x] Seed data pipeline (`seed/Makefile`, `seed/build_base_db.py`)
+- [x] Ground truth generator (`ground_truth.py`)
+- [x] Results schema (`schema.sql`)
+- [x] Benchmark runner with `INDEX_REGISTRY` extension point (`bench.py`)
+ - Baseline configs (float, int8-rescore, bit-rescore) implemented
+ - Index branches register their types via `INDEX_REGISTRY` dict
+- [x] Makefile with baseline targets
+- [x] README
+
+### 1.3 Rebase feature branches onto `ann`
+- [x] Rebase `diskann-yolo2` onto `ann` (1 commit: DiskANN implementation)
+- [x] Rebase `ivf-yolo3` onto `ann` (1 commit: IVF implementation)
+- [x] Rebase `annoy-yolo2` onto `ann` (2 commits: Annoy implementation + schema fix)
+- [x] Verify each branch has only its index-specific commits remaining
+- [ ] Force-push all 4 branches to origin
+
+---
+
+## 2. Per-branch: register index type in benchmarks
+
+Each index branch should add to `benchmarks-ann/` when rebased onto `ann`:
+
+### 2.1 Register in `bench.py`
+
+Add an `INDEX_REGISTRY` entry. Each entry provides:
+- `defaults` — default param values
+- `create_table_sql(params)` — CREATE VIRTUAL TABLE with INDEXED BY clause
+- `insert_sql(params)` — custom insert SQL, or None for default
+- `post_insert_hook(conn, params)` — training/building step, returns time
+- `run_query(conn, params, query, k)` — custom query, or None for default MATCH
+- `describe(params)` — one-line description for report output
+
+### 2.2 Add configs to `Makefile`
+
+Append index-specific config variables and targets. Example pattern:
+
+```makefile
+DISKANN_CONFIGS = \
+ "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
+ ...
+
+ALL_CONFIGS += $(DISKANN_CONFIGS)
+
+bench-diskann: seed
+ $(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
+ ...
+```
+
+### 2.3 Migrate existing benchmark results/docs
+
+- Move useful results docs (RESULTS.md, etc.) into `benchmarks-ann/results/`
+- Delete redundant per-branch benchmark directories once consolidated infra is proven
+
+---
+
+## 3. Future improvements
+
+- [ ] Reporting script (`report.py`) — query results.db, produce markdown comparison tables
+- [ ] Profiling targets in Makefile (lift from ivf-yolo3's Instruments/perf wrappers)
+- [ ] Pre-computed ground truth integration (use GT DB files instead of on-the-fly brute-force)
diff --git a/benchmarks-ann/.gitignore b/benchmarks-ann/.gitignore
new file mode 100644
index 0000000..c418b76
--- /dev/null
+++ b/benchmarks-ann/.gitignore
@@ -0,0 +1,2 @@
+*.db
+runs/
diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile
new file mode 100644
index 0000000..59e2dcd
--- /dev/null
+++ b/benchmarks-ann/Makefile
@@ -0,0 +1,61 @@
+BENCH = python bench.py
+BASE_DB = seed/base.db
+EXT = ../dist/vec0
+
+# --- Baseline (brute-force) configs ---
+BASELINES = \
+ "brute-float:type=baseline,variant=float" \
+ "brute-int8:type=baseline,variant=int8" \
+ "brute-bit:type=baseline,variant=bit"
+
+# --- Index-specific configs ---
+# Each index branch should add its own configs here. Example:
+#
+# DISKANN_CONFIGS = \
+# "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
+# "diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8"
+#
+# IVF_CONFIGS = \
+# "ivf-n128-p16:type=ivf,nlist=128,nprobe=16"
+#
+# ANNOY_CONFIGS = \
+# "annoy-t50:type=annoy,n_trees=50"
+
+ALL_CONFIGS = $(BASELINES)
+
+.PHONY: seed ground-truth bench-smoke bench-10k bench-50k bench-100k bench-all \
+ report clean
+
+# --- Data preparation ---
+seed:
+ $(MAKE) -C seed
+
+ground-truth: seed
+ python ground_truth.py --subset-size 10000
+ python ground_truth.py --subset-size 50000
+ python ground_truth.py --subset-size 100000
+
+# --- Quick smoke test ---
+bench-smoke: seed
+ $(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \
+ $(BASELINES)
+
+# --- Standard sizes ---
+bench-10k: seed
+ $(BENCH) --subset-size 10000 -k 10 -o runs/10k $(ALL_CONFIGS)
+
+bench-50k: seed
+ $(BENCH) --subset-size 50000 -k 10 -o runs/50k $(ALL_CONFIGS)
+
+bench-100k: seed
+ $(BENCH) --subset-size 100000 -k 10 -o runs/100k $(ALL_CONFIGS)
+
+bench-all: bench-10k bench-50k bench-100k
+
+# --- Report ---
+report:
+ @echo "Use: sqlite3 runs/
/results.db 'SELECT * FROM bench_results ORDER BY recall DESC'"
+
+# --- Cleanup ---
+clean:
+ rm -rf runs/
diff --git a/benchmarks-ann/README.md b/benchmarks-ann/README.md
new file mode 100644
index 0000000..1f7fd5c
--- /dev/null
+++ b/benchmarks-ann/README.md
@@ -0,0 +1,81 @@
+# KNN Benchmarks for sqlite-vec
+
+Benchmarking infrastructure for vec0 KNN configurations. Includes brute-force
+baselines (float, int8, bit); index-specific branches add their own types
+via the `INDEX_REGISTRY` in `bench.py`.
+
+## Prerequisites
+
+- Built `dist/vec0` extension (run `make` from repo root)
+- Python 3.10+
+- `uv` (for seed data prep): `pip install uv`
+
+## Quick start
+
+```bash
+# 1. Download dataset and build seed DB (~3 GB download, ~5 min)
+make seed
+
+# 2. Run a quick smoke test (5k vectors, ~1 min)
+make bench-smoke
+
+# 3. Run full benchmark at 10k
+make bench-10k
+```
+
+## Usage
+
+### Direct invocation
+
+```bash
+python bench.py --subset-size 10000 \
+ "brute-float:type=baseline,variant=float" \
+ "brute-int8:type=baseline,variant=int8" \
+ "brute-bit:type=baseline,variant=bit"
+```
+
+### Config format
+
+`name:type=,key=val,key=val`
+
+| Index type | Keys | Branch |
+|-----------|------|--------|
+| `baseline` | `variant` (float/int8/bit), `oversample` | this branch |
+
+Index branches register additional types in `INDEX_REGISTRY`. See the
+docstring in `bench.py` for the extension API.
+
+### Make targets
+
+| Target | Description |
+|--------|-------------|
+| `make seed` | Download COHERE 1M dataset |
+| `make ground-truth` | Pre-compute ground truth for 10k/50k/100k |
+| `make bench-smoke` | Quick 5k baseline test |
+| `make bench-10k` | All configs at 10k vectors |
+| `make bench-50k` | All configs at 50k vectors |
+| `make bench-100k` | All configs at 100k vectors |
+| `make bench-all` | 10k + 50k + 100k |
+
+## Adding an index type
+
+In your index branch, add an entry to `INDEX_REGISTRY` in `bench.py` and
+append your configs to `ALL_CONFIGS` in the `Makefile`. See the existing
+`baseline` entry and the comments in both files for the pattern.
+
+## Results
+
+Results are stored in `runs//results.db` using the schema in `schema.sql`.
+
+```bash
+sqlite3 runs/10k/results.db "
+ SELECT config_name, recall, mean_ms, qps
+ FROM bench_results
+ ORDER BY recall DESC
+"
+```
+
+## Dataset
+
+[Zilliz COHERE Medium 1M](https://zilliz.com/learn/datasets-for-vector-database-benchmarks):
+768 dimensions, cosine distance, 1M train vectors + 10k query vectors with precomputed neighbors.
diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py
new file mode 100644
index 0000000..93f8f82
--- /dev/null
+++ b/benchmarks-ann/bench.py
@@ -0,0 +1,488 @@
+#!/usr/bin/env python3
+"""Benchmark runner for sqlite-vec KNN configurations.
+
+Measures insert time, build/train time, DB size, KNN latency, and recall
+across different vec0 configurations.
+
+Config format: name:type=,key=val,key=val
+
+ Baseline (brute-force) keys:
+ type=baseline, variant=float|int8|bit, oversample=8
+
+ Index-specific types can be registered via INDEX_REGISTRY (see below).
+
+Usage:
+ python bench.py --subset-size 10000 \
+ "brute-float:type=baseline,variant=float" \
+ "brute-int8:type=baseline,variant=int8" \
+ "brute-bit:type=baseline,variant=bit"
+"""
+import argparse
+import os
+import sqlite3
+import statistics
+import time
+
+_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+EXT_PATH = os.path.join(_SCRIPT_DIR, "..", "dist", "vec0")
+BASE_DB = os.path.join(_SCRIPT_DIR, "seed", "base.db")
+INSERT_BATCH_SIZE = 1000
+
+
+# ============================================================================
+# Index registry — extension point for ANN index branches
+# ============================================================================
+#
+# Each index type provides a dict with:
+# "defaults": dict of default params
+# "create_table_sql": fn(params) -> SQL string
+# "insert_sql": fn(params) -> SQL string (or None for default)
+# "post_insert_hook": fn(conn, params) -> train_time_s (or None)
+# "run_query": fn(conn, params, query, k) -> [(id, distance), ...] (or None for default MATCH)
+# "describe": fn(params) -> str (one-line description)
+#
+# To add a new index type, add an entry here. Example (in your branch):
+#
+# INDEX_REGISTRY["diskann"] = {
+# "defaults": {"R": 72, "L": 128, "quantizer": "binary", "buffer_threshold": 0},
+# "create_table_sql": lambda p: f"CREATE VIRTUAL TABLE vec_items USING vec0(...)",
+# "insert_sql": None,
+# "post_insert_hook": None,
+# "run_query": None,
+# "describe": lambda p: f"diskann q={p['quantizer']} R={p['R']} L={p['L']}",
+# }
+
+INDEX_REGISTRY = {}
+
+
+# ============================================================================
+# Baseline implementation
+# ============================================================================
+
+
+def _baseline_create_table_sql(params):
+ variant = params["variant"]
+ extra = ""
+ if variant == "int8":
+ extra = ", embedding_int8 int8[768]"
+ elif variant == "bit":
+ extra = ", embedding_bq bit[768]"
+ return (
+ f"CREATE VIRTUAL TABLE vec_items USING vec0("
+ f" chunk_size=256,"
+ f" id integer primary key,"
+ f" embedding float[768] distance_metric=cosine"
+ f" {extra})"
+ )
+
+
+def _baseline_insert_sql(params):
+ variant = params["variant"]
+ if variant == "int8":
+ return (
+ "INSERT INTO vec_items(id, embedding, embedding_int8) "
+ "SELECT id, vector, vec_quantize_int8(vector, 'unit') "
+ "FROM base.train WHERE id >= :lo AND id < :hi"
+ )
+ elif variant == "bit":
+ return (
+ "INSERT INTO vec_items(id, embedding, embedding_bq) "
+ "SELECT id, vector, vec_quantize_binary(vector) "
+ "FROM base.train WHERE id >= :lo AND id < :hi"
+ )
+ return None # use default
+
+
+def _baseline_run_query(conn, params, query, k):
+ variant = params["variant"]
+ oversample = params.get("oversample", 8)
+
+ if variant == "int8":
+ return conn.execute(
+ "WITH coarse AS ("
+ " SELECT id, embedding FROM vec_items"
+ " WHERE embedding_int8 MATCH vec_quantize_int8(:query, 'unit')"
+ " LIMIT :oversample_k"
+ ") "
+ "SELECT id, vec_distance_cosine(embedding, :query) as distance "
+ "FROM coarse ORDER BY 2 LIMIT :k",
+ {"query": query, "k": k, "oversample_k": k * oversample},
+ ).fetchall()
+ elif variant == "bit":
+ return conn.execute(
+ "WITH coarse AS ("
+ " SELECT id, embedding FROM vec_items"
+ " WHERE embedding_bq MATCH vec_quantize_binary(:query)"
+ " LIMIT :oversample_k"
+ ") "
+ "SELECT id, vec_distance_cosine(embedding, :query) as distance "
+ "FROM coarse ORDER BY 2 LIMIT :k",
+ {"query": query, "k": k, "oversample_k": k * oversample},
+ ).fetchall()
+
+ return None # use default MATCH
+
+
+def _baseline_describe(params):
+ v = params["variant"]
+ if v in ("int8", "bit"):
+ return f"baseline {v} (os={params['oversample']})"
+ return f"baseline {v}"
+
+
+INDEX_REGISTRY["baseline"] = {
+ "defaults": {"variant": "float", "oversample": 8},
+ "create_table_sql": _baseline_create_table_sql,
+ "insert_sql": _baseline_insert_sql,
+ "post_insert_hook": None,
+ "run_query": _baseline_run_query,
+ "describe": _baseline_describe,
+}
+
+
+# ============================================================================
+# Config parsing
+# ============================================================================
+
+INT_KEYS = {
+ "R", "L", "buffer_threshold", "nlist", "nprobe", "oversample",
+ "n_trees", "search_k",
+}
+
+
+def parse_config(spec):
+ """Parse 'name:type=baseline,key=val,...' into (name, params_dict)."""
+ if ":" in spec:
+ name, opts_str = spec.split(":", 1)
+ else:
+ name, opts_str = spec, ""
+
+ raw = {}
+ if opts_str:
+ for kv in opts_str.split(","):
+ k, v = kv.split("=", 1)
+ raw[k.strip()] = v.strip()
+
+ index_type = raw.pop("type", "baseline")
+ if index_type not in INDEX_REGISTRY:
+ raise ValueError(
+ f"Unknown index type: {index_type}. "
+ f"Available: {', '.join(sorted(INDEX_REGISTRY.keys()))}"
+ )
+
+ reg = INDEX_REGISTRY[index_type]
+ params = dict(reg["defaults"])
+ for k, v in raw.items():
+ if k in INT_KEYS:
+ params[k] = int(v)
+ else:
+ params[k] = v
+ params["index_type"] = index_type
+
+ return name, params
+
+
+# ============================================================================
+# Shared helpers
+# ============================================================================
+
+
+def load_query_vectors(base_db_path, n):
+ conn = sqlite3.connect(base_db_path)
+ rows = conn.execute(
+ "SELECT id, vector FROM query_vectors ORDER BY id LIMIT :n", {"n": n}
+ ).fetchall()
+ conn.close()
+ return [(r[0], r[1]) for r in rows]
+
+
+def insert_loop(conn, sql, subset_size, label=""):
+ t0 = time.perf_counter()
+ for lo in range(0, subset_size, INSERT_BATCH_SIZE):
+ hi = min(lo + INSERT_BATCH_SIZE, subset_size)
+ conn.execute(sql, {"lo": lo, "hi": hi})
+ conn.commit()
+ done = hi
+ if done % 5000 == 0 or done == subset_size:
+ elapsed = time.perf_counter() - t0
+ rate = done / elapsed if elapsed > 0 else 0
+ print(
+ f" [{label}] {done:>8}/{subset_size} "
+ f"{elapsed:.1f}s {rate:.0f} rows/s",
+ flush=True,
+ )
+ return time.perf_counter() - t0
+
+
+def open_bench_db(db_path, ext_path, base_db):
+ if os.path.exists(db_path):
+ os.remove(db_path)
+ conn = sqlite3.connect(db_path)
+ conn.enable_load_extension(True)
+ conn.load_extension(ext_path)
+ conn.execute("PRAGMA page_size=8192")
+ conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
+ return conn
+
+
+DEFAULT_INSERT_SQL = (
+ "INSERT INTO vec_items(id, embedding) "
+ "SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi"
+)
+
+
+# ============================================================================
+# Build
+# ============================================================================
+
+
+def build_index(base_db, ext_path, name, params, subset_size, out_dir):
+ db_path = os.path.join(out_dir, f"{name}.{subset_size}.db")
+ conn = open_bench_db(db_path, ext_path, base_db)
+
+ reg = INDEX_REGISTRY[params["index_type"]]
+
+ conn.execute(reg["create_table_sql"](params))
+
+ label = params["index_type"]
+ print(f" Inserting {subset_size} vectors...")
+
+ sql_fn = reg.get("insert_sql")
+ sql = sql_fn(params) if sql_fn else None
+ if sql is None:
+ sql = DEFAULT_INSERT_SQL
+
+ insert_time = insert_loop(conn, sql, subset_size, label)
+
+ train_time = 0.0
+ hook = reg.get("post_insert_hook")
+ if hook:
+ train_time = hook(conn, params)
+
+ row_count = conn.execute("SELECT count(*) FROM vec_items").fetchone()[0]
+ conn.close()
+ file_size_mb = os.path.getsize(db_path) / (1024 * 1024)
+
+ return {
+ "db_path": db_path,
+ "insert_time_s": round(insert_time, 3),
+ "train_time_s": round(train_time, 3),
+ "total_time_s": round(insert_time + train_time, 3),
+ "insert_per_vec_ms": round((insert_time / row_count) * 1000, 2)
+ if row_count
+ else 0,
+ "rows": row_count,
+ "file_size_mb": round(file_size_mb, 2),
+ }
+
+
+# ============================================================================
+# KNN measurement
+# ============================================================================
+
+
+def _default_match_query(conn, query, k):
+ return conn.execute(
+ "SELECT id, distance FROM vec_items "
+ "WHERE embedding MATCH :query AND k = :k",
+ {"query": query, "k": k},
+ ).fetchall()
+
+
+def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50):
+ conn = sqlite3.connect(db_path)
+ conn.enable_load_extension(True)
+ conn.load_extension(ext_path)
+ conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
+
+ query_vectors = load_query_vectors(base_db, n)
+
+ reg = INDEX_REGISTRY[params["index_type"]]
+ query_fn = reg.get("run_query")
+
+ times_ms = []
+ recalls = []
+ for qid, query in query_vectors:
+ t0 = time.perf_counter()
+
+ results = None
+ if query_fn:
+ results = query_fn(conn, params, query, k)
+ if results is None:
+ results = _default_match_query(conn, query, k)
+
+ elapsed_ms = (time.perf_counter() - t0) * 1000
+ times_ms.append(elapsed_ms)
+ result_ids = set(r[0] for r in results)
+
+ # Ground truth: use pre-computed neighbors table for full dataset,
+ # otherwise brute-force over the subset
+ if subset_size >= 1000000:
+ gt_rows = conn.execute(
+ "SELECT CAST(neighbors_id AS INTEGER) FROM base.neighbors "
+ "WHERE query_vector_id = :qid AND rank < :k",
+ {"qid": qid, "k": k},
+ ).fetchall()
+ else:
+ gt_rows = conn.execute(
+ "SELECT id FROM ("
+ " SELECT id, vec_distance_cosine(vector, :query) as dist "
+ " FROM base.train WHERE id < :n ORDER BY dist LIMIT :k"
+ ")",
+ {"query": query, "k": k, "n": subset_size},
+ ).fetchall()
+ gt_ids = set(r[0] for r in gt_rows)
+
+ if gt_ids:
+ recalls.append(len(result_ids & gt_ids) / len(gt_ids))
+ else:
+ recalls.append(0.0)
+
+ conn.close()
+
+ return {
+ "mean_ms": round(statistics.mean(times_ms), 2),
+ "median_ms": round(statistics.median(times_ms), 2),
+ "p99_ms": round(sorted(times_ms)[int(len(times_ms) * 0.99)], 2)
+ if len(times_ms) > 1
+ else round(times_ms[0], 2),
+ "total_ms": round(sum(times_ms), 2),
+ "recall": round(statistics.mean(recalls), 4),
+ }
+
+
+# ============================================================================
+# Results persistence
+# ============================================================================
+
+
+def save_results(results_path, rows):
+ db = sqlite3.connect(results_path)
+ db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read())
+ for r in rows:
+ db.execute(
+ "INSERT OR REPLACE INTO build_results "
+ "(config_name, index_type, subset_size, db_path, "
+ " insert_time_s, train_time_s, total_time_s, rows, file_size_mb) "
+ "VALUES (?,?,?,?,?,?,?,?,?)",
+ (
+ r["name"], r["index_type"], r["n_vectors"], r["db_path"],
+ r["insert_time_s"], r["train_time_s"], r["total_time_s"],
+ r["rows"], r["file_size_mb"],
+ ),
+ )
+ db.execute(
+ "INSERT OR REPLACE INTO bench_results "
+ "(config_name, index_type, subset_size, k, n, "
+ " mean_ms, median_ms, p99_ms, total_ms, qps, recall, db_path) "
+ "VALUES (?,?,?,?,?,?,?,?,?,?,?,?)",
+ (
+ r["name"], r["index_type"], r["n_vectors"], r["k"], r["n_queries"],
+ r["mean_ms"], r["median_ms"], r["p99_ms"], r["total_ms"],
+ round(r["n_queries"] / (r["total_ms"] / 1000), 1)
+ if r["total_ms"] > 0 else 0,
+ r["recall"], r["db_path"],
+ ),
+ )
+ db.commit()
+ db.close()
+
+
+# ============================================================================
+# Reporting
+# ============================================================================
+
+
+def print_report(all_results):
+ print(
+ f"\n{'name':>20} {'N':>7} {'type':>10} {'config':>28} "
+ f"{'ins(s)':>7} {'train':>6} {'MB':>7} "
+ f"{'qry(ms)':>8} {'recall':>7}"
+ )
+ print("-" * 115)
+ for r in all_results:
+ train = f"{r['train_time_s']:.1f}" if r["train_time_s"] > 0 else "-"
+ print(
+ f"{r['name']:>20} {r['n_vectors']:>7} {r['index_type']:>10} "
+ f"{r['config_desc']:>28} "
+ f"{r['insert_time_s']:>7.1f} {train:>6} {r['file_size_mb']:>7.1f} "
+ f"{r['mean_ms']:>8.2f} {r['recall']:>7.4f}"
+ )
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Benchmark runner for sqlite-vec KNN configurations",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=__doc__,
+ )
+ parser.add_argument("configs", nargs="+", help="config specs (name:type=X,key=val,...)")
+ parser.add_argument("--subset-size", type=int, required=True)
+ parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
+ parser.add_argument("-n", type=int, default=50, help="number of queries (default 50)")
+ parser.add_argument("--base-db", default=BASE_DB)
+ parser.add_argument("--ext", default=EXT_PATH)
+ parser.add_argument("-o", "--out-dir", default="runs")
+ parser.add_argument("--results-db", default=None,
+ help="path to results DB (default: /results.db)")
+ args = parser.parse_args()
+
+ os.makedirs(args.out_dir, exist_ok=True)
+ results_db = args.results_db or os.path.join(args.out_dir, "results.db")
+ configs = [parse_config(c) for c in args.configs]
+
+ all_results = []
+ for i, (name, params) in enumerate(configs, 1):
+ reg = INDEX_REGISTRY[params["index_type"]]
+ desc = reg["describe"](params)
+ print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()})")
+
+ build = build_index(
+ args.base_db, args.ext, name, params, args.subset_size, args.out_dir
+ )
+ train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
+ print(
+ f" Build: {build['insert_time_s']}s insert{train_str} "
+ f"{build['file_size_mb']} MB"
+ )
+
+ print(f" Measuring KNN (k={args.k}, n={args.n})...")
+ knn = measure_knn(
+ build["db_path"], args.ext, args.base_db,
+ params, args.subset_size, k=args.k, n=args.n,
+ )
+ print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
+
+ all_results.append({
+ "name": name,
+ "n_vectors": args.subset_size,
+ "index_type": params["index_type"],
+ "config_desc": desc,
+ "db_path": build["db_path"],
+ "insert_time_s": build["insert_time_s"],
+ "train_time_s": build["train_time_s"],
+ "total_time_s": build["total_time_s"],
+ "insert_per_vec_ms": build["insert_per_vec_ms"],
+ "rows": build["rows"],
+ "file_size_mb": build["file_size_mb"],
+ "k": args.k,
+ "n_queries": args.n,
+ "mean_ms": knn["mean_ms"],
+ "median_ms": knn["median_ms"],
+ "p99_ms": knn["p99_ms"],
+ "total_ms": knn["total_ms"],
+ "recall": knn["recall"],
+ })
+
+ print_report(all_results)
+ save_results(results_db, all_results)
+ print(f"\nResults saved to {results_db}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks-ann/ground_truth.py b/benchmarks-ann/ground_truth.py
new file mode 100644
index 0000000..636a495
--- /dev/null
+++ b/benchmarks-ann/ground_truth.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python3
+"""Compute per-subset ground truth for ANN benchmarks.
+
+For subset sizes < 1M, builds a temporary vec0 float table with the first N
+vectors and runs brute-force KNN to get correct ground truth per subset.
+
+For 1M (the full dataset), converts the existing `neighbors` table.
+
+Output: ground_truth.{subset_size}.db with table:
+ ground_truth(query_vector_id, rank, neighbor_id, distance)
+
+Usage:
+ python ground_truth.py --subset-size 50000
+ python ground_truth.py --subset-size 1000000
+"""
+import argparse
+import os
+import sqlite3
+import time
+
+_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+EXT_PATH = os.path.join(_SCRIPT_DIR, "..", "dist", "vec0")
+BASE_DB = os.path.join(_SCRIPT_DIR, "seed", "base.db")
+FULL_DATASET_SIZE = 1_000_000
+
+
+def gen_ground_truth_subset(base_db, ext_path, subset_size, n_queries, k, out_path):
+ """Build ground truth by brute-force KNN over the first `subset_size` vectors."""
+ if os.path.exists(out_path):
+ os.remove(out_path)
+
+ conn = sqlite3.connect(out_path)
+ conn.enable_load_extension(True)
+ conn.load_extension(ext_path)
+
+ conn.execute(
+ "CREATE TABLE ground_truth ("
+ " query_vector_id INTEGER NOT NULL,"
+ " rank INTEGER NOT NULL,"
+ " neighbor_id INTEGER NOT NULL,"
+ " distance REAL NOT NULL,"
+ " PRIMARY KEY (query_vector_id, rank)"
+ ")"
+ )
+
+ conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
+
+ print(f" Building temp vec0 table with {subset_size} vectors...")
+ conn.execute(
+ "CREATE VIRTUAL TABLE tmp_vec USING vec0("
+ " id integer primary key,"
+ " embedding float[768] distance_metric=cosine"
+ ")"
+ )
+
+ t0 = time.perf_counter()
+ conn.execute(
+ "INSERT INTO tmp_vec(id, embedding) "
+ "SELECT id, vector FROM base.train WHERE id < :n",
+ {"n": subset_size},
+ )
+ conn.commit()
+ build_time = time.perf_counter() - t0
+ print(f" Temp table built in {build_time:.1f}s")
+
+ query_vectors = conn.execute(
+ "SELECT id, vector FROM base.query_vectors ORDER BY id LIMIT :n",
+ {"n": n_queries},
+ ).fetchall()
+
+ print(f" Running brute-force KNN for {len(query_vectors)} queries, k={k}...")
+ t0 = time.perf_counter()
+
+ for i, (qid, qvec) in enumerate(query_vectors):
+ results = conn.execute(
+ "SELECT id, distance FROM tmp_vec "
+ "WHERE embedding MATCH :query AND k = :k",
+ {"query": qvec, "k": k},
+ ).fetchall()
+
+ for rank, (nid, dist) in enumerate(results):
+ conn.execute(
+ "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id, distance) "
+ "VALUES (?, ?, ?, ?)",
+ (qid, rank, nid, dist),
+ )
+
+ if (i + 1) % 10 == 0 or i == 0:
+ elapsed = time.perf_counter() - t0
+ eta = (elapsed / (i + 1)) * (len(query_vectors) - i - 1)
+ print(
+ f" {i+1}/{len(query_vectors)} queries "
+ f"elapsed={elapsed:.1f}s eta={eta:.1f}s",
+ flush=True,
+ )
+
+ conn.commit()
+ conn.execute("DROP TABLE tmp_vec")
+ conn.execute("DETACH DATABASE base")
+ conn.commit()
+
+ elapsed = time.perf_counter() - t0
+ total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0]
+ conn.close()
+ print(f" Ground truth: {total_rows} rows in {elapsed:.1f}s -> {out_path}")
+
+
+def gen_ground_truth_full(base_db, n_queries, k, out_path):
+ """Convert the existing neighbors table for the full 1M dataset."""
+ if os.path.exists(out_path):
+ os.remove(out_path)
+
+ conn = sqlite3.connect(out_path)
+ conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
+
+ conn.execute(
+ "CREATE TABLE ground_truth ("
+ " query_vector_id INTEGER NOT NULL,"
+ " rank INTEGER NOT NULL,"
+ " neighbor_id INTEGER NOT NULL,"
+ " distance REAL,"
+ " PRIMARY KEY (query_vector_id, rank)"
+ ")"
+ )
+
+ conn.execute(
+ "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id) "
+ "SELECT query_vector_id, rank, CAST(neighbors_id AS INTEGER) "
+ "FROM base.neighbors "
+ "WHERE query_vector_id < :n AND rank < :k",
+ {"n": n_queries, "k": k},
+ )
+ conn.commit()
+
+ total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0]
+ conn.execute("DETACH DATABASE base")
+ conn.close()
+ print(f" Ground truth (full): {total_rows} rows -> {out_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Generate per-subset ground truth")
+ parser.add_argument(
+ "--subset-size", type=int, required=True, help="number of vectors in subset"
+ )
+ parser.add_argument("-n", type=int, default=100, help="number of query vectors")
+ parser.add_argument("-k", type=int, default=100, help="max k for ground truth")
+ parser.add_argument("--base-db", default=BASE_DB)
+ parser.add_argument("--ext", default=EXT_PATH)
+ parser.add_argument(
+ "-o", "--out-dir", default=os.path.join(_SCRIPT_DIR, "seed"),
+ help="output directory for ground_truth.{N}.db",
+ )
+ args = parser.parse_args()
+
+ os.makedirs(args.out_dir, exist_ok=True)
+ out_path = os.path.join(args.out_dir, f"ground_truth.{args.subset_size}.db")
+
+ if args.subset_size >= FULL_DATASET_SIZE:
+ gen_ground_truth_full(args.base_db, args.n, args.k, out_path)
+ else:
+ gen_ground_truth_subset(
+ args.base_db, args.ext, args.subset_size, args.n, args.k, out_path
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks-ann/profile.py b/benchmarks-ann/profile.py
new file mode 100644
index 0000000..0792373
--- /dev/null
+++ b/benchmarks-ann/profile.py
@@ -0,0 +1,440 @@
+#!/usr/bin/env python3
+"""CPU profiling for sqlite-vec KNN configurations using macOS `sample` tool.
+
+Builds dist/sqlite3 (with -g3), generates a SQL workload (inserts + repeated
+KNN queries) for each config, profiles the sqlite3 process with `sample`, and
+prints the top-N hottest functions by self (exclusive) CPU samples.
+
+Usage:
+ cd benchmarks-ann
+ uv run profile.py --subset-size 50000 -n 50 \\
+ "baseline-int8:type=baseline,variant=int8,oversample=8" \\
+ "rescore-int8:type=rescore,quantizer=int8,oversample=8"
+"""
+
+import argparse
+import os
+import re
+import shutil
+import subprocess
+import sys
+import tempfile
+
+_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+_PROJECT_ROOT = os.path.join(_SCRIPT_DIR, "..")
+
+sys.path.insert(0, _SCRIPT_DIR)
+from bench import (
+ BASE_DB,
+ DEFAULT_INSERT_SQL,
+ INDEX_REGISTRY,
+ INSERT_BATCH_SIZE,
+ parse_config,
+)
+
+SQLITE3_PATH = os.path.join(_PROJECT_ROOT, "dist", "sqlite3")
+EXT_PATH = os.path.join(_PROJECT_ROOT, "dist", "vec0")
+
+
+# ============================================================================
+# SQL generation
+# ============================================================================
+
+
+def _query_sql_for_config(params, query_id, k):
+ """Return a SQL query string for a single KNN query by query_vector id."""
+ index_type = params["index_type"]
+ qvec = f"(SELECT vector FROM base.query_vectors WHERE id = {query_id})"
+
+ if index_type == "baseline":
+ variant = params.get("variant", "float")
+ oversample = params.get("oversample", 8)
+ oversample_k = k * oversample
+
+ if variant == "int8":
+ return (
+ f"WITH coarse AS ("
+ f" SELECT id, embedding FROM vec_items"
+ f" WHERE embedding_int8 MATCH vec_quantize_int8({qvec}, 'unit')"
+ f" LIMIT {oversample_k}"
+ f") "
+ f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance "
+ f"FROM coarse ORDER BY 2 LIMIT {k};"
+ )
+ elif variant == "bit":
+ return (
+ f"WITH coarse AS ("
+ f" SELECT id, embedding FROM vec_items"
+ f" WHERE embedding_bq MATCH vec_quantize_binary({qvec})"
+ f" LIMIT {oversample_k}"
+ f") "
+ f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance "
+ f"FROM coarse ORDER BY 2 LIMIT {k};"
+ )
+
+ # Default MATCH query (baseline-float, rescore, and others)
+ return (
+ f"SELECT id, distance FROM vec_items"
+ f" WHERE embedding MATCH {qvec} AND k = {k};"
+ )
+
+
+def generate_sql(db_path, params, subset_size, n_queries, k, repeats):
+ """Generate a complete SQL workload: load ext, create table, insert, query."""
+ lines = []
+ lines.append(".bail on")
+ lines.append(f".load {EXT_PATH}")
+ lines.append(f"ATTACH DATABASE '{os.path.abspath(BASE_DB)}' AS base;")
+ lines.append("PRAGMA page_size=8192;")
+
+ # Create table
+ reg = INDEX_REGISTRY[params["index_type"]]
+ lines.append(reg["create_table_sql"](params) + ";")
+
+ # Inserts
+ sql_fn = reg.get("insert_sql")
+ insert_sql = sql_fn(params) if sql_fn else None
+ if insert_sql is None:
+ insert_sql = DEFAULT_INSERT_SQL
+ for lo in range(0, subset_size, INSERT_BATCH_SIZE):
+ hi = min(lo + INSERT_BATCH_SIZE, subset_size)
+ stmt = insert_sql.replace(":lo", str(lo)).replace(":hi", str(hi))
+ lines.append(stmt + ";")
+ if hi % 10000 == 0 or hi == subset_size:
+ lines.append("-- progress: inserted %d/%d" % (hi, subset_size))
+
+ # Queries (repeated)
+ lines.append("-- BEGIN QUERIES")
+ for _rep in range(repeats):
+ for qid in range(n_queries):
+ lines.append(_query_sql_for_config(params, qid, k))
+
+ return "\n".join(lines)
+
+
+# ============================================================================
+# Profiling with macOS `sample`
+# ============================================================================
+
+
+def run_profile(sqlite3_path, db_path, sql_file, sample_output, duration=120):
+ """Run sqlite3 under macOS `sample` profiler.
+
+ Starts sqlite3 directly with stdin from the SQL file, then immediately
+ attaches `sample` to its PID with -mayDie (tolerates process exit).
+ The workload must be long enough for sample to attach and capture useful data.
+ """
+ sql_fd = open(sql_file, "r")
+ proc = subprocess.Popen(
+ [sqlite3_path, db_path],
+ stdin=sql_fd,
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.PIPE,
+ )
+
+ pid = proc.pid
+ print(f" sqlite3 PID: {pid}")
+
+ # Attach sample immediately (1ms interval, -mayDie tolerates process exit)
+ sample_proc = subprocess.Popen(
+ ["sample", str(pid), str(duration), "1", "-mayDie", "-file", sample_output],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.PIPE,
+ )
+
+ # Wait for sqlite3 to finish
+ _, stderr = proc.communicate()
+ sql_fd.close()
+ rc = proc.returncode
+ if rc != 0:
+ print(f" sqlite3 failed (rc={rc}):", file=sys.stderr)
+ print(f" {stderr.decode().strip()}", file=sys.stderr)
+ sample_proc.kill()
+ return False
+
+ # Wait for sample to finish
+ sample_proc.wait()
+ return True
+
+
+# ============================================================================
+# Parse `sample` output
+# ============================================================================
+
+# Tree-drawing characters used by macOS `sample` to represent hierarchy.
+# We replace them with spaces so indentation depth reflects tree depth.
+_TREE_CHARS_RE = re.compile(r"[+!:|]")
+
+# After tree chars are replaced with spaces, each call-graph line looks like:
+# " 800 rescore_knn (in vec0.dylib) + 3808,3640,... [0x1a,0x2b,...] file.c:123"
+# We extract just (indent, count, symbol, module) — everything after "(in ...)"
+# is decoration we don't need.
+_LEADING_RE = re.compile(r"^(\s+)(\d+)\s+(.+)")
+
+
+def _extract_symbol_and_module(rest):
+ """Given the text after 'count ', extract (symbol, module).
+
+ Handles patterns like:
+ 'rescore_knn (in vec0.dylib) + 3808,3640,... [0x...]'
+ 'pread (in libsystem_kernel.dylib) + 8 [0x...]'
+ '??? (in ) [0x...]'
+ 'start (in dyld) + 2840 [0x198650274]'
+ 'Thread_26759239 DispatchQueue_1: ...'
+ """
+ # Try to find "(in ...)" to split symbol from module
+ m = re.match(r"^(.+?)\s+\(in\s+(.+?)\)", rest)
+ if m:
+ return m.group(1).strip(), m.group(2).strip()
+ # No module — return whole thing as symbol, strip trailing junk
+ sym = re.sub(r"\s+\[0x[0-9a-f].*", "", rest).strip()
+ return sym, ""
+
+
+def _parse_call_graph_lines(text):
+ """Parse call-graph section into list of (depth, count, symbol, module)."""
+ entries = []
+ for raw_line in text.split("\n"):
+ # Strip tree-drawing characters, replace with spaces to preserve depth
+ line = _TREE_CHARS_RE.sub(" ", raw_line)
+ m = _LEADING_RE.match(line)
+ if not m:
+ continue
+ depth = len(m.group(1))
+ count = int(m.group(2))
+ rest = m.group(3)
+ symbol, module = _extract_symbol_and_module(rest)
+ entries.append((depth, count, symbol, module))
+ return entries
+
+
+def parse_sample_output(filepath):
+ """Parse `sample` call-graph output, compute exclusive (self) samples per function.
+
+ Returns dict of {display_name: self_sample_count}.
+ """
+ with open(filepath, "r") as f:
+ text = f.read()
+
+ # Find "Call graph:" section
+ cg_start = text.find("Call graph:")
+ if cg_start == -1:
+ print(" Warning: no 'Call graph:' section found in sample output")
+ return {}
+
+ # End at "Total number in stack" or EOF
+ cg_end = text.find("\nTotal number in stack", cg_start)
+ if cg_end == -1:
+ cg_end = len(text)
+
+ entries = _parse_call_graph_lines(text[cg_start:cg_end])
+
+ if not entries:
+ print(" Warning: no call graph entries parsed")
+ return {}
+
+ # Compute self (exclusive) samples per function:
+ # self = count - sum(direct_children_counts)
+ self_samples = {}
+ for i, (depth, count, sym, mod) in enumerate(entries):
+ children_sum = 0
+ child_depth = None
+ for j in range(i + 1, len(entries)):
+ j_depth = entries[j][0]
+ if j_depth <= depth:
+ break
+ if child_depth is None:
+ child_depth = j_depth
+ if j_depth == child_depth:
+ children_sum += entries[j][1]
+
+ self_count = count - children_sum
+ if self_count > 0:
+ key = f"{sym} ({mod})" if mod else sym
+ self_samples[key] = self_samples.get(key, 0) + self_count
+
+ return self_samples
+
+
+# ============================================================================
+# Display
+# ============================================================================
+
+
+def print_profile(title, self_samples, top_n=20):
+ total = sum(self_samples.values())
+ if total == 0:
+ print(f"\n=== {title} (no samples) ===")
+ return
+
+ sorted_syms = sorted(self_samples.items(), key=lambda x: -x[1])
+
+ print(f"\n=== {title} (top {top_n}, {total} total self-samples) ===")
+ for sym, count in sorted_syms[:top_n]:
+ pct = 100.0 * count / total
+ print(f" {pct:5.1f}% {count:>6} {sym}")
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="CPU profiling for sqlite-vec KNN configurations",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=__doc__,
+ )
+ parser.add_argument(
+ "configs", nargs="+", help="config specs (name:type=X,key=val,...)"
+ )
+ parser.add_argument("--subset-size", type=int, required=True)
+ parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
+ parser.add_argument(
+ "-n", type=int, default=50, help="number of distinct queries (default 50)"
+ )
+ parser.add_argument(
+ "--repeats",
+ type=int,
+ default=10,
+ help="repeat query set N times for more samples (default 10)",
+ )
+ parser.add_argument(
+ "--top", type=int, default=20, help="show top N functions (default 20)"
+ )
+ parser.add_argument("--base-db", default=BASE_DB)
+ parser.add_argument("--sqlite3", default=SQLITE3_PATH)
+ parser.add_argument(
+ "--keep-temp",
+ action="store_true",
+ help="keep temp directory with DBs, SQL, and sample output",
+ )
+ args = parser.parse_args()
+
+ # Check prerequisites
+ if not os.path.exists(args.base_db):
+ print(f"Error: base DB not found at {args.base_db}", file=sys.stderr)
+ print("Run 'make seed' in benchmarks-ann/ first.", file=sys.stderr)
+ sys.exit(1)
+
+ if not shutil.which("sample"):
+ print("Error: macOS 'sample' tool not found.", file=sys.stderr)
+ sys.exit(1)
+
+ # Build CLI
+ print("Building dist/sqlite3...")
+ result = subprocess.run(
+ ["make", "cli"], cwd=_PROJECT_ROOT, capture_output=True, text=True
+ )
+ if result.returncode != 0:
+ print(f"Error: make cli failed:\n{result.stderr}", file=sys.stderr)
+ sys.exit(1)
+ print(" done.")
+
+ if not os.path.exists(args.sqlite3):
+ print(f"Error: sqlite3 not found at {args.sqlite3}", file=sys.stderr)
+ sys.exit(1)
+
+ configs = [parse_config(c) for c in args.configs]
+
+ tmpdir = tempfile.mkdtemp(prefix="sqlite-vec-profile-")
+ print(f"Working directory: {tmpdir}")
+
+ all_profiles = []
+
+ for i, (name, params) in enumerate(configs, 1):
+ reg = INDEX_REGISTRY[params["index_type"]]
+ desc = reg["describe"](params)
+ print(f"\n[{i}/{len(configs)}] {name} ({desc})")
+
+ # Generate SQL workload
+ db_path = os.path.join(tmpdir, f"{name}.db")
+ sql_text = generate_sql(
+ db_path, params, args.subset_size, args.n, args.k, args.repeats
+ )
+ sql_file = os.path.join(tmpdir, f"{name}.sql")
+ with open(sql_file, "w") as f:
+ f.write(sql_text)
+
+ total_queries = args.n * args.repeats
+ print(
+ f" SQL workload: {args.subset_size} inserts + "
+ f"{total_queries} queries ({args.n} x {args.repeats} repeats)"
+ )
+
+ # Profile
+ sample_file = os.path.join(tmpdir, f"{name}.sample.txt")
+ print(f" Profiling...")
+ ok = run_profile(args.sqlite3, db_path, sql_file, sample_file)
+ if not ok:
+ print(f" FAILED — skipping {name}")
+ all_profiles.append((name, desc, {}))
+ continue
+
+ if not os.path.exists(sample_file):
+ print(f" Warning: sample output not created")
+ all_profiles.append((name, desc, {}))
+ continue
+
+ # Parse
+ self_samples = parse_sample_output(sample_file)
+ all_profiles.append((name, desc, self_samples))
+
+ # Show individual profile
+ print_profile(f"{name} ({desc})", self_samples, args.top)
+
+ # Side-by-side comparison if multiple configs
+ if len(all_profiles) > 1:
+ print("\n" + "=" * 80)
+ print("COMPARISON")
+ print("=" * 80)
+
+ # Collect all symbols that appear in top-N of any config
+ all_syms = set()
+ for _name, _desc, prof in all_profiles:
+ sorted_syms = sorted(prof.items(), key=lambda x: -x[1])
+ for sym, _count in sorted_syms[: args.top]:
+ all_syms.add(sym)
+
+ # Build comparison table
+ rows = []
+ for sym in all_syms:
+ row = [sym]
+ for _name, _desc, prof in all_profiles:
+ total = sum(prof.values())
+ count = prof.get(sym, 0)
+ pct = 100.0 * count / total if total > 0 else 0.0
+ row.append((pct, count))
+ max_pct = max(r[0] for r in row[1:])
+ rows.append((max_pct, row))
+
+ rows.sort(key=lambda x: -x[0])
+
+ # Header
+ header = f"{'function':>40}"
+ for name, desc, _ in all_profiles:
+ header += f" {name:>14}"
+ print(header)
+ print("-" * len(header))
+
+ for _sort_key, row in rows[: args.top * 2]:
+ sym = row[0]
+ display_sym = sym if len(sym) <= 40 else sym[:37] + "..."
+ line = f"{display_sym:>40}"
+ for pct, count in row[1:]:
+ if count > 0:
+ line += f" {pct:>13.1f}%"
+ else:
+ line += f" {'-':>14}"
+ print(line)
+
+ if args.keep_temp:
+ print(f"\nTemp files kept at: {tmpdir}")
+ else:
+ shutil.rmtree(tmpdir)
+ print(f"\nTemp files cleaned up. Use --keep-temp to preserve.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks-ann/schema.sql b/benchmarks-ann/schema.sql
new file mode 100644
index 0000000..681df4e
--- /dev/null
+++ b/benchmarks-ann/schema.sql
@@ -0,0 +1,35 @@
+-- Canonical results schema for vec0 KNN benchmark comparisons.
+-- The index_type column is a free-form TEXT field. Baseline configs use
+-- "baseline"; index-specific branches add their own types (registered
+-- via INDEX_REGISTRY in bench.py).
+
+CREATE TABLE IF NOT EXISTS build_results (
+ config_name TEXT NOT NULL,
+ index_type TEXT NOT NULL,
+ subset_size INTEGER NOT NULL,
+ db_path TEXT NOT NULL,
+ insert_time_s REAL NOT NULL,
+ train_time_s REAL, -- NULL when no training/build step is needed
+ total_time_s REAL NOT NULL,
+ rows INTEGER NOT NULL,
+ file_size_mb REAL NOT NULL,
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
+ PRIMARY KEY (config_name, subset_size)
+);
+
+CREATE TABLE IF NOT EXISTS bench_results (
+ config_name TEXT NOT NULL,
+ index_type TEXT NOT NULL,
+ subset_size INTEGER NOT NULL,
+ k INTEGER NOT NULL,
+ n INTEGER NOT NULL,
+ mean_ms REAL NOT NULL,
+ median_ms REAL NOT NULL,
+ p99_ms REAL NOT NULL,
+ total_ms REAL NOT NULL,
+ qps REAL NOT NULL,
+ recall REAL NOT NULL,
+ db_path TEXT NOT NULL,
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
+ PRIMARY KEY (config_name, subset_size, k)
+);
diff --git a/benchmarks-ann/seed/.gitignore b/benchmarks-ann/seed/.gitignore
new file mode 100644
index 0000000..8efed50
--- /dev/null
+++ b/benchmarks-ann/seed/.gitignore
@@ -0,0 +1,2 @@
+*.parquet
+base.db
diff --git a/benchmarks-ann/seed/Makefile b/benchmarks-ann/seed/Makefile
new file mode 100644
index 0000000..186bf66
--- /dev/null
+++ b/benchmarks-ann/seed/Makefile
@@ -0,0 +1,24 @@
+BASE_URL = https://assets.zilliz.com/benchmark/cohere_medium_1m
+
+PARQUETS = train.parquet test.parquet neighbors.parquet
+
+.PHONY: all download base.db clean
+
+all: base.db
+
+download: $(PARQUETS)
+
+train.parquet:
+ curl -L -o $@ $(BASE_URL)/train.parquet
+
+test.parquet:
+ curl -L -o $@ $(BASE_URL)/test.parquet
+
+neighbors.parquet:
+ curl -L -o $@ $(BASE_URL)/neighbors.parquet
+
+base.db: $(PARQUETS) build_base_db.py
+ uv run --with pandas --with pyarrow python build_base_db.py
+
+clean:
+ rm -f base.db
diff --git a/benchmarks-ann/seed/build_base_db.py b/benchmarks-ann/seed/build_base_db.py
new file mode 100644
index 0000000..33d280d
--- /dev/null
+++ b/benchmarks-ann/seed/build_base_db.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+"""Build base.db from downloaded parquet files.
+
+Reads train.parquet, test.parquet, neighbors.parquet and creates a SQLite
+database with tables: train, query_vectors, neighbors.
+
+Usage:
+ uv run --with pandas --with pyarrow python build_base_db.py
+"""
+import json
+import os
+import sqlite3
+import struct
+import sys
+import time
+
+import pandas as pd
+
+
+def float_list_to_blob(floats):
+ """Pack a list of floats into a little-endian f32 blob."""
+ return struct.pack(f"<{len(floats)}f", *floats)
+
+
+def main():
+ seed_dir = os.path.dirname(os.path.abspath(__file__))
+ db_path = os.path.join(seed_dir, "base.db")
+
+ train_path = os.path.join(seed_dir, "train.parquet")
+ test_path = os.path.join(seed_dir, "test.parquet")
+ neighbors_path = os.path.join(seed_dir, "neighbors.parquet")
+
+ for p in (train_path, test_path, neighbors_path):
+ if not os.path.exists(p):
+ print(f"ERROR: {p} not found. Run 'make download' first.")
+ sys.exit(1)
+
+ if os.path.exists(db_path):
+ os.remove(db_path)
+
+ conn = sqlite3.connect(db_path)
+ conn.execute("PRAGMA journal_mode=WAL")
+ conn.execute("PRAGMA page_size=4096")
+
+ # --- query_vectors (from test.parquet) ---
+ print("Loading test.parquet (query vectors)...")
+ t0 = time.perf_counter()
+ df_test = pd.read_parquet(test_path)
+ conn.execute(
+ "CREATE TABLE query_vectors (id INTEGER PRIMARY KEY, vector BLOB)"
+ )
+ rows = []
+ for _, row in df_test.iterrows():
+ rows.append((int(row["id"]), float_list_to_blob(row["emb"])))
+ conn.executemany("INSERT INTO query_vectors (id, vector) VALUES (?, ?)", rows)
+ conn.commit()
+ print(f" {len(rows)} query vectors in {time.perf_counter() - t0:.1f}s")
+
+ # --- neighbors (from neighbors.parquet) ---
+ print("Loading neighbors.parquet...")
+ t0 = time.perf_counter()
+ df_neighbors = pd.read_parquet(neighbors_path)
+ conn.execute(
+ "CREATE TABLE neighbors ("
+ " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
+ " UNIQUE(query_vector_id, rank))"
+ )
+ rows = []
+ for _, row in df_neighbors.iterrows():
+ qid = int(row["id"])
+ # neighbors_id may be a numpy array or JSON string
+ nids = row["neighbors_id"]
+ if isinstance(nids, str):
+ nids = json.loads(nids)
+ for rank, nid in enumerate(nids):
+ rows.append((qid, rank, str(int(nid))))
+ conn.executemany(
+ "INSERT INTO neighbors (query_vector_id, rank, neighbors_id) VALUES (?, ?, ?)",
+ rows,
+ )
+ conn.commit()
+ print(f" {len(rows)} neighbor rows in {time.perf_counter() - t0:.1f}s")
+
+ # --- train (from train.parquet) ---
+ print("Loading train.parquet (1M vectors, this takes a few minutes)...")
+ t0 = time.perf_counter()
+ conn.execute(
+ "CREATE TABLE train (id INTEGER PRIMARY KEY, vector BLOB)"
+ )
+
+ batch_size = 10000
+ df_iter = pd.read_parquet(train_path)
+ total = len(df_iter)
+
+ for start in range(0, total, batch_size):
+ chunk = df_iter.iloc[start : start + batch_size]
+ rows = []
+ for _, row in chunk.iterrows():
+ rows.append((int(row["id"]), float_list_to_blob(row["emb"])))
+ conn.executemany("INSERT INTO train (id, vector) VALUES (?, ?)", rows)
+ conn.commit()
+
+ done = min(start + batch_size, total)
+ elapsed = time.perf_counter() - t0
+ rate = done / elapsed if elapsed > 0 else 0
+ eta = (total - done) / rate if rate > 0 else 0
+ print(
+ f" {done:>8}/{total} {elapsed:.0f}s {rate:.0f} rows/s eta {eta:.0f}s",
+ flush=True,
+ )
+
+ elapsed = time.perf_counter() - t0
+ print(f" {total} train vectors in {elapsed:.1f}s")
+
+ conn.close()
+ size_mb = os.path.getsize(db_path) / (1024 * 1024)
+ print(f"\nDone: {db_path} ({size_mb:.0f} MB)")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks/exhaustive-memory/bench.py b/benchmarks/exhaustive-memory/bench.py
index c9da831..7c969d6 100644
--- a/benchmarks/exhaustive-memory/bench.py
+++ b/benchmarks/exhaustive-memory/bench.py
@@ -248,59 +248,6 @@ def bench_libsql(base, query, page_size, k) -> BenchResult:
return BenchResult(f"libsql ({page_size})", build_time, times)
-def register_np(db, array, name):
- ptr = array.__array_interface__["data"][0]
- nvectors, dimensions = array.__array_interface__["shape"]
- element_type = array.__array_interface__["typestr"]
-
- assert element_type == " BenchResult:
- print(f"sqlite-vec static...")
-
- db = sqlite3.connect(":memory:")
- db.enable_load_extension(True)
- db.load_extension("../../dist/vec0")
-
-
-
- t = time.time()
- register_np(db, base, "base")
- build_time = time.time() - t
-
- times = []
- results = []
- for (
- idx,
- q,
- ) in enumerate(query):
- t0 = time.time()
- result = db.execute(
- """
- select
- rowid
- from base
- where vector match ?
- and k = ?
- order by distance
- """,
- [q.tobytes(), k],
- ).fetchall()
- assert len(result) == k
- times.append(time.time() - t0)
- return BenchResult(f"sqlite-vec static", build_time, times)
-
def bench_faiss(base, query, k) -> BenchResult:
import faiss
dimensions = base.shape[1]
@@ -438,8 +385,6 @@ def suite(name, base, query, k, benchmarks):
for b in benchmarks:
if b == "faiss":
results.append(bench_faiss(base, query, k=k))
- elif b == "vec-static":
- results.append(bench_sqlite_vec_static(base, query, k=k))
elif b.startswith("vec-scalar"):
_, page_size = b.split('.')
results.append(bench_sqlite_vec_scalar(base, query, page_size, k=k))
@@ -541,7 +486,7 @@ def parse_args():
help="Number of queries to use. Defaults all",
)
parser.add_argument(
- "-x", help="type of runs to make", default="faiss,vec-scalar.4096,vec-static,vec-vec0.4096.16,usearch,duckdb,hnswlib,numpy"
+ "-x", help="type of runs to make", default="faiss,vec-scalar.4096,vec-vec0.4096.16,usearch,duckdb,hnswlib,numpy"
)
args = parser.parse_args()
diff --git a/benchmarks/profiling/build-from-npy.sql b/benchmarks/profiling/build-from-npy.sql
index 134df70..92ef59c 100644
--- a/benchmarks/profiling/build-from-npy.sql
+++ b/benchmarks/profiling/build-from-npy.sql
@@ -8,10 +8,3 @@ create virtual table vec_items using vec0(
embedding float[1536]
);
--- 65s (limit 1e5), ~615MB on disk
-insert into vec_items
- select
- rowid,
- vector
- from vec_npy_each(vec_npy_file('examples/dbpedia-openai/data/vectors.npy'))
- limit 1e5;
diff --git a/benchmarks/self-params/build.py b/benchmarks/self-params/build.py
index bc6e388..c5d9fc1 100644
--- a/benchmarks/self-params/build.py
+++ b/benchmarks/self-params/build.py
@@ -6,7 +6,6 @@ def connect(path):
db = sqlite3.connect(path)
db.enable_load_extension(True)
db.load_extension("../dist/vec0")
- db.execute("select load_extension('../dist/vec0', 'sqlite3_vec_fs_read_init')")
db.enable_load_extension(False)
return db
@@ -18,8 +17,6 @@ page_sizes = [ # 4096, 8192,
chunk_sizes = [128, 256, 1024, 2048]
types = ["f32", "int8", "bit"]
-SRC = "../examples/dbpedia-openai/data/vectors.npy"
-
for page_size in page_sizes:
for chunk_size in chunk_sizes:
for t in types:
@@ -42,15 +39,8 @@ for page_size in page_sizes:
func = "vec_quantize_i8(vector, 'unit')"
if t == "bit":
func = "vec_quantize_binary(vector)"
- db.execute(
- f"""
- insert into vec_items
- select rowid, {func}
- from vec_npy_each(vec_npy_file(?))
- limit 100000
- """,
- [SRC],
- )
+ # TODO: replace with non-npy data loading
+ pass
elapsed = time.time() - t0
print(elapsed)
diff --git a/bindings/go/ncruces/go-sqlite3.patch b/bindings/go/ncruces/go-sqlite3.patch
index f202bc3..03bead9 100644
--- a/bindings/go/ncruces/go-sqlite3.patch
+++ b/bindings/go/ncruces/go-sqlite3.patch
@@ -6,7 +6,6 @@ index ed2aaec..4cc0b0e 100755
-Wl,--initial-memory=327680 \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \
-+ -DSQLITE_VEC_OMIT_FS=1 \
$(awk '{print "-Wl,--export="$0}' exports.txt)
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
diff --git a/bindings/python/extra_init.py b/bindings/python/extra_init.py
index 267bc41..4408855 100644
--- a/bindings/python/extra_init.py
+++ b/bindings/python/extra_init.py
@@ -1,6 +1,5 @@
from typing import List
from struct import pack
-from sqlite3 import Connection
def serialize_float32(vector: List[float]) -> bytes:
@@ -13,33 +12,3 @@ def serialize_int8(vector: List[int]) -> bytes:
return pack("%sb" % len(vector), *vector)
-try:
- import numpy.typing as npt
-
- def register_numpy(db: Connection, name: str, array: npt.NDArray):
- """ayoo"""
-
- ptr = array.__array_interface__["data"][0]
- nvectors, dimensions = array.__array_interface__["shape"]
- element_type = array.__array_interface__["typestr"]
-
- assert element_type == " dist/sqlite-vec.c
+"""
+
+import re
+import sys
+import os
+
+
+def strip_lsp_block(content):
+ """Remove the LSP-support pattern:
+ #ifndef SQLITE_VEC_H
+ #include "sqlite-vec.c" // ...
+ #endif
+ """
+ pattern = re.compile(
+ r'^\s*#ifndef\s+SQLITE_VEC_H\s*\n'
+ r'\s*#include\s+"sqlite-vec\.c"[^\n]*\n'
+ r'\s*#endif[^\n]*\n',
+ re.MULTILINE,
+ )
+ return pattern.sub('', content)
+
+
+def strip_include_guard(content, guard_macro):
+ """Remove the include guard pair:
+ #ifndef GUARD_MACRO
+ #define GUARD_MACRO
+ ...content...
+ (trailing #endif removed)
+ """
+ # Strip the #ifndef / #define pair at the top
+ header_pattern = re.compile(
+ r'^\s*#ifndef\s+' + re.escape(guard_macro) + r'\s*\n'
+ r'\s*#define\s+' + re.escape(guard_macro) + r'\s*\n',
+ re.MULTILINE,
+ )
+ content = header_pattern.sub('', content, count=1)
+
+ # Strip the trailing #endif (last one in file that closes the guard)
+ # Find the last #endif and remove it
+ lines = content.rstrip('\n').split('\n')
+ for i in range(len(lines) - 1, -1, -1):
+ if re.match(r'^\s*#endif', lines[i]):
+ lines.pop(i)
+ break
+
+ return '\n'.join(lines) + '\n'
+
+
+def detect_include_guard(content):
+ """Detect an include guard macro like SQLITE_VEC_IVF_C."""
+ m = re.match(
+ r'\s*(?:/\*[\s\S]*?\*/\s*)?' # optional block comment
+ r'#ifndef\s+(SQLITE_VEC_\w+_C)\s*\n'
+ r'#define\s+\1',
+ content,
+ )
+ return m.group(1) if m else None
+
+
+def inline_include(match, base_dir):
+ """Replace an #include "sqlite-vec-*.c" with the file's contents."""
+ filename = match.group(1)
+ filepath = os.path.join(base_dir, filename)
+
+ if not os.path.exists(filepath):
+ print(f"Warning: {filepath} not found, leaving #include in place", file=sys.stderr)
+ return match.group(0)
+
+ with open(filepath, 'r') as f:
+ content = f.read()
+
+ # Strip LSP-support block
+ content = strip_lsp_block(content)
+
+ # Strip include guard if present
+ guard = detect_include_guard(content)
+ if guard:
+ content = strip_include_guard(content, guard)
+
+ separator = '/' * 78
+ header = f'\n{separator}\n// Begin inlined: {filename}\n{separator}\n\n'
+ footer = f'\n{separator}\n// End inlined: {filename}\n{separator}\n'
+
+ return header + content.strip('\n') + footer
+
+
+def amalgamate(input_path):
+ base_dir = os.path.dirname(os.path.abspath(input_path))
+
+ with open(input_path, 'r') as f:
+ content = f.read()
+
+ # Replace #include "sqlite-vec-*.c" with inlined contents
+ include_pattern = re.compile(r'^#include\s+"(sqlite-vec-[^"]+\.c)"\s*$', re.MULTILINE)
+ content = include_pattern.sub(lambda m: inline_include(m, base_dir), content)
+
+ return content
+
+
+def main():
+ if len(sys.argv) != 2:
+ print(f"Usage: {sys.argv[0]} ", file=sys.stderr)
+ sys.exit(1)
+
+ result = amalgamate(sys.argv[1])
+ sys.stdout.write(result)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/site/api-reference.md b/site/api-reference.md
index bd144ea..ba8c648 100644
--- a/site/api-reference.md
+++ b/site/api-reference.md
@@ -568,65 +568,6 @@ select 'todo';
-- 'todo'
-```
-
-## NumPy Utilities {#numpy}
-
-Functions to read data from or work with [NumPy arrays](https://numpy.org/doc/stable/reference/generated/numpy.array.html).
-
-### `vec_npy_each(vector)` {#vec_npy_each}
-
-xxx
-
-
-```sql
--- db.execute('select quote(?)', [to_npy(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))]).fetchone()
-select
- rowid,
- vector,
- vec_type(vector),
- vec_to_json(vector)
-from vec_npy_each(
- X'934E554D5059010076007B276465736372273A20273C6634272C2027666F727472616E5F6F72646572273A2046616C73652C20277368617065273A2028332C2031292C207D202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020200A0000803F0000004000004040'
-)
-/*
-┌───────┬─────────────┬──────────────────┬─────────────────────┐
-│ rowid │ vector │ vec_type(vector) │ vec_to_json(vector) │
-├───────┼─────────────┼──────────────────┼─────────────────────┤
-│ 0 │ X'0000803F' │ 'float32' │ '[1.000000]' │
-├───────┼─────────────┼──────────────────┼─────────────────────┤
-│ 1 │ X'00000040' │ 'float32' │ '[2.000000]' │
-├───────┼─────────────┼──────────────────┼─────────────────────┤
-│ 2 │ X'00004040' │ 'float32' │ '[3.000000]' │
-└───────┴─────────────┴──────────────────┴─────────────────────┘
-
-*/
-
-
--- db.execute('select quote(?)', [to_npy(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))]).fetchone()
-select
- rowid,
- vector,
- vec_type(vector),
- vec_to_json(vector)
-from vec_npy_each(
- X'934E554D5059010076007B276465736372273A20273C6634272C2027666F727472616E5F6F72646572273A2046616C73652C20277368617065273A2028332C2031292C207D202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020200A0000803F0000004000004040'
-)
-/*
-┌───────┬─────────────┬──────────────────┬─────────────────────┐
-│ rowid │ vector │ vec_type(vector) │ vec_to_json(vector) │
-├───────┼─────────────┼──────────────────┼─────────────────────┤
-│ 0 │ X'0000803F' │ 'float32' │ '[1.000000]' │
-├───────┼─────────────┼──────────────────┼─────────────────────┤
-│ 1 │ X'00000040' │ 'float32' │ '[2.000000]' │
-├───────┼─────────────┼──────────────────┼─────────────────────┤
-│ 2 │ X'00004040' │ 'float32' │ '[3.000000]' │
-└───────┴─────────────┴──────────────────┴─────────────────────┘
-
-*/
-
-
-
```
## Meta {#meta}
diff --git a/site/compiling.md b/site/compiling.md
index 9ce3c83..b3b2e33 100644
--- a/site/compiling.md
+++ b/site/compiling.md
@@ -59,5 +59,4 @@ The current compile-time flags are:
- `SQLITE_VEC_ENABLE_AVX`, enables AVX CPU instructions for some vector search operations
- `SQLITE_VEC_ENABLE_NEON`, enables NEON CPU instructions for some vector search operations
-- `SQLITE_VEC_OMIT_FS`, removes some obsure SQL functions and features that use the filesystem, meant for some WASM builds where there's no available filesystem
- `SQLITE_VEC_STATIC`, meant for statically linking `sqlite-vec`
diff --git a/sqlite-vec.c b/sqlite-vec.c
index c1874a7..390123b 100644
--- a/sqlite-vec.c
+++ b/sqlite-vec.c
@@ -11,7 +11,7 @@
#include
#include
-#ifndef SQLITE_VEC_OMIT_FS
+#ifdef SQLITE_VEC_DEBUG
#include
#endif
@@ -224,6 +224,63 @@ static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v,
return sqrt(sum_scalar);
}
+static f32 cosine_float_neon(const void *pVect1v, const void *pVect2v,
+ const void *qty_ptr) {
+ f32 *pVect1 = (f32 *)pVect1v;
+ f32 *pVect2 = (f32 *)pVect2v;
+ size_t qty = *((size_t *)qty_ptr);
+ size_t qty16 = qty >> 4;
+ const f32 *pEnd1 = pVect1 + (qty16 << 4);
+
+ float32x4_t dot0 = vdupq_n_f32(0), dot1 = vdupq_n_f32(0);
+ float32x4_t dot2 = vdupq_n_f32(0), dot3 = vdupq_n_f32(0);
+ float32x4_t amag0 = vdupq_n_f32(0), amag1 = vdupq_n_f32(0);
+ float32x4_t amag2 = vdupq_n_f32(0), amag3 = vdupq_n_f32(0);
+ float32x4_t bmag0 = vdupq_n_f32(0), bmag1 = vdupq_n_f32(0);
+ float32x4_t bmag2 = vdupq_n_f32(0), bmag3 = vdupq_n_f32(0);
+
+ while (pVect1 < pEnd1) {
+ float32x4_t v1, v2;
+ v1 = vld1q_f32(pVect1); pVect1 += 4;
+ v2 = vld1q_f32(pVect2); pVect2 += 4;
+ dot0 = vfmaq_f32(dot0, v1, v2);
+ amag0 = vfmaq_f32(amag0, v1, v1);
+ bmag0 = vfmaq_f32(bmag0, v2, v2);
+
+ v1 = vld1q_f32(pVect1); pVect1 += 4;
+ v2 = vld1q_f32(pVect2); pVect2 += 4;
+ dot1 = vfmaq_f32(dot1, v1, v2);
+ amag1 = vfmaq_f32(amag1, v1, v1);
+ bmag1 = vfmaq_f32(bmag1, v2, v2);
+
+ v1 = vld1q_f32(pVect1); pVect1 += 4;
+ v2 = vld1q_f32(pVect2); pVect2 += 4;
+ dot2 = vfmaq_f32(dot2, v1, v2);
+ amag2 = vfmaq_f32(amag2, v1, v1);
+ bmag2 = vfmaq_f32(bmag2, v2, v2);
+
+ v1 = vld1q_f32(pVect1); pVect1 += 4;
+ v2 = vld1q_f32(pVect2); pVect2 += 4;
+ dot3 = vfmaq_f32(dot3, v1, v2);
+ amag3 = vfmaq_f32(amag3, v1, v1);
+ bmag3 = vfmaq_f32(bmag3, v2, v2);
+ }
+
+ f32 dot_s = vaddvq_f32(vaddq_f32(vaddq_f32(dot0, dot1), vaddq_f32(dot2, dot3)));
+ f32 amag_s = vaddvq_f32(vaddq_f32(vaddq_f32(amag0, amag1), vaddq_f32(amag2, amag3)));
+ f32 bmag_s = vaddvq_f32(vaddq_f32(vaddq_f32(bmag0, bmag1), vaddq_f32(bmag2, bmag3)));
+
+ const f32 *pEnd2 = pVect1 + (qty - (qty16 << 4));
+ while (pVect1 < pEnd2) {
+ dot_s += *pVect1 * *pVect2;
+ amag_s += *pVect1 * *pVect1;
+ bmag_s += *pVect2 * *pVect2;
+ pVect1++; pVect2++;
+ }
+
+ return 1.0f - (dot_s / (sqrtf(amag_s) * sqrtf(bmag_s)));
+}
+
static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) {
i8 *pVect1 = (i8 *)pVect1v;
@@ -462,6 +519,11 @@ static double distance_l1_f32(const void *a, const void *b, const void *d) {
static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) {
+#ifdef SQLITE_VEC_ENABLE_NEON
+ if ((*(const size_t *)qty_ptr) > 16) {
+ return cosine_float_neon(pVect1v, pVect2v, qty_ptr);
+ }
+#endif
f32 *pVect1 = (f32 *)pVect1v;
f32 *pVect2 = (f32 *)pVect2v;
size_t qty = *((size_t *)qty_ptr);
@@ -478,8 +540,7 @@ static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v,
}
return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
}
-static f32 distance_cosine_int8(const void *pA, const void *pB,
- const void *pD) {
+static f32 cosine_int8(const void *pA, const void *pB, const void *pD) {
i8 *a = (i8 *)pA;
i8 *b = (i8 *)pB;
size_t d = *((size_t *)pD);
@@ -497,6 +558,125 @@ static f32 distance_cosine_int8(const void *pA, const void *pB,
return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
}
+#ifdef SQLITE_VEC_ENABLE_NEON
+static f32 cosine_int8_neon(const void *pA, const void *pB, const void *pD) {
+ const i8 *a = (const i8 *)pA;
+ const i8 *b = (const i8 *)pB;
+ size_t d = *((const size_t *)pD);
+ const i8 *aEnd = a + d;
+
+ int32x4_t dot_acc1 = vdupq_n_s32(0);
+ int32x4_t dot_acc2 = vdupq_n_s32(0);
+ int32x4_t aMag_acc1 = vdupq_n_s32(0);
+ int32x4_t aMag_acc2 = vdupq_n_s32(0);
+ int32x4_t bMag_acc1 = vdupq_n_s32(0);
+ int32x4_t bMag_acc2 = vdupq_n_s32(0);
+
+ while (a < aEnd - 31) {
+ int8x16_t va1 = vld1q_s8(a);
+ int8x16_t vb1 = vld1q_s8(b);
+ int16x8_t a1_lo = vmovl_s8(vget_low_s8(va1));
+ int16x8_t a1_hi = vmovl_s8(vget_high_s8(va1));
+ int16x8_t b1_lo = vmovl_s8(vget_low_s8(vb1));
+ int16x8_t b1_hi = vmovl_s8(vget_high_s8(vb1));
+
+ dot_acc1 = vmlal_s16(dot_acc1, vget_low_s16(a1_lo), vget_low_s16(b1_lo));
+ dot_acc1 = vmlal_s16(dot_acc1, vget_high_s16(a1_lo), vget_high_s16(b1_lo));
+ dot_acc2 = vmlal_s16(dot_acc2, vget_low_s16(a1_hi), vget_low_s16(b1_hi));
+ dot_acc2 = vmlal_s16(dot_acc2, vget_high_s16(a1_hi), vget_high_s16(b1_hi));
+
+ aMag_acc1 = vmlal_s16(aMag_acc1, vget_low_s16(a1_lo), vget_low_s16(a1_lo));
+ aMag_acc1 = vmlal_s16(aMag_acc1, vget_high_s16(a1_lo), vget_high_s16(a1_lo));
+ aMag_acc2 = vmlal_s16(aMag_acc2, vget_low_s16(a1_hi), vget_low_s16(a1_hi));
+ aMag_acc2 = vmlal_s16(aMag_acc2, vget_high_s16(a1_hi), vget_high_s16(a1_hi));
+
+ bMag_acc1 = vmlal_s16(bMag_acc1, vget_low_s16(b1_lo), vget_low_s16(b1_lo));
+ bMag_acc1 = vmlal_s16(bMag_acc1, vget_high_s16(b1_lo), vget_high_s16(b1_lo));
+ bMag_acc2 = vmlal_s16(bMag_acc2, vget_low_s16(b1_hi), vget_low_s16(b1_hi));
+ bMag_acc2 = vmlal_s16(bMag_acc2, vget_high_s16(b1_hi), vget_high_s16(b1_hi));
+
+ int8x16_t va2 = vld1q_s8(a + 16);
+ int8x16_t vb2 = vld1q_s8(b + 16);
+ int16x8_t a2_lo = vmovl_s8(vget_low_s8(va2));
+ int16x8_t a2_hi = vmovl_s8(vget_high_s8(va2));
+ int16x8_t b2_lo = vmovl_s8(vget_low_s8(vb2));
+ int16x8_t b2_hi = vmovl_s8(vget_high_s8(vb2));
+
+ dot_acc1 = vmlal_s16(dot_acc1, vget_low_s16(a2_lo), vget_low_s16(b2_lo));
+ dot_acc1 = vmlal_s16(dot_acc1, vget_high_s16(a2_lo), vget_high_s16(b2_lo));
+ dot_acc2 = vmlal_s16(dot_acc2, vget_low_s16(a2_hi), vget_low_s16(b2_hi));
+ dot_acc2 = vmlal_s16(dot_acc2, vget_high_s16(a2_hi), vget_high_s16(b2_hi));
+
+ aMag_acc1 = vmlal_s16(aMag_acc1, vget_low_s16(a2_lo), vget_low_s16(a2_lo));
+ aMag_acc1 = vmlal_s16(aMag_acc1, vget_high_s16(a2_lo), vget_high_s16(a2_lo));
+ aMag_acc2 = vmlal_s16(aMag_acc2, vget_low_s16(a2_hi), vget_low_s16(a2_hi));
+ aMag_acc2 = vmlal_s16(aMag_acc2, vget_high_s16(a2_hi), vget_high_s16(a2_hi));
+
+ bMag_acc1 = vmlal_s16(bMag_acc1, vget_low_s16(b2_lo), vget_low_s16(b2_lo));
+ bMag_acc1 = vmlal_s16(bMag_acc1, vget_high_s16(b2_lo), vget_high_s16(b2_lo));
+ bMag_acc2 = vmlal_s16(bMag_acc2, vget_low_s16(b2_hi), vget_low_s16(b2_hi));
+ bMag_acc2 = vmlal_s16(bMag_acc2, vget_high_s16(b2_hi), vget_high_s16(b2_hi));
+
+ a += 32;
+ b += 32;
+ }
+
+ while (a < aEnd - 15) {
+ int8x16_t va = vld1q_s8(a);
+ int8x16_t vb = vld1q_s8(b);
+ int16x8_t a_lo = vmovl_s8(vget_low_s8(va));
+ int16x8_t a_hi = vmovl_s8(vget_high_s8(va));
+ int16x8_t b_lo = vmovl_s8(vget_low_s8(vb));
+ int16x8_t b_hi = vmovl_s8(vget_high_s8(vb));
+
+ dot_acc1 = vmlal_s16(dot_acc1, vget_low_s16(a_lo), vget_low_s16(b_lo));
+ dot_acc1 = vmlal_s16(dot_acc1, vget_high_s16(a_lo), vget_high_s16(b_lo));
+ dot_acc1 = vmlal_s16(dot_acc1, vget_low_s16(a_hi), vget_low_s16(b_hi));
+ dot_acc1 = vmlal_s16(dot_acc1, vget_high_s16(a_hi), vget_high_s16(b_hi));
+
+ aMag_acc1 = vmlal_s16(aMag_acc1, vget_low_s16(a_lo), vget_low_s16(a_lo));
+ aMag_acc1 = vmlal_s16(aMag_acc1, vget_high_s16(a_lo), vget_high_s16(a_lo));
+ aMag_acc1 = vmlal_s16(aMag_acc1, vget_low_s16(a_hi), vget_low_s16(a_hi));
+ aMag_acc1 = vmlal_s16(aMag_acc1, vget_high_s16(a_hi), vget_high_s16(a_hi));
+
+ bMag_acc1 = vmlal_s16(bMag_acc1, vget_low_s16(b_lo), vget_low_s16(b_lo));
+ bMag_acc1 = vmlal_s16(bMag_acc1, vget_high_s16(b_lo), vget_high_s16(b_lo));
+ bMag_acc1 = vmlal_s16(bMag_acc1, vget_low_s16(b_hi), vget_low_s16(b_hi));
+ bMag_acc1 = vmlal_s16(bMag_acc1, vget_high_s16(b_hi), vget_high_s16(b_hi));
+
+ a += 16;
+ b += 16;
+ }
+
+ int32x4_t dot_sum = vaddq_s32(dot_acc1, dot_acc2);
+ int32x4_t aMag_sum = vaddq_s32(aMag_acc1, aMag_acc2);
+ int32x4_t bMag_sum = vaddq_s32(bMag_acc1, bMag_acc2);
+
+ i32 dot = vaddvq_s32(dot_sum);
+ i32 aMag = vaddvq_s32(aMag_sum);
+ i32 bMag = vaddvq_s32(bMag_sum);
+
+ while (a < aEnd) {
+ dot += (i32)*a * (i32)*b;
+ aMag += (i32)*a * (i32)*a;
+ bMag += (i32)*b * (i32)*b;
+ a++;
+ b++;
+ }
+
+ return 1.0f - ((f32)dot / (sqrtf((f32)aMag) * sqrtf((f32)bMag)));
+}
+#endif
+
+static f32 distance_cosine_int8(const void *a, const void *b, const void *d) {
+#ifdef SQLITE_VEC_ENABLE_NEON
+ if ((*(const size_t *)d) > 15) {
+ return cosine_int8_neon(a, b, d);
+ }
+#endif
+ return cosine_int8(a, b, d);
+}
+
// https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34
static u8 hamdist_table[256] = {
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4,
@@ -511,6 +691,59 @@ static u8 hamdist_table[256] = {
4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};
+#ifdef SQLITE_VEC_ENABLE_NEON
+static f32 distance_hamming_neon(const u8 *a, const u8 *b, size_t n_bytes) {
+ const u8 *pEnd = a + n_bytes;
+
+ uint32x4_t acc1 = vdupq_n_u32(0);
+ uint32x4_t acc2 = vdupq_n_u32(0);
+ uint32x4_t acc3 = vdupq_n_u32(0);
+ uint32x4_t acc4 = vdupq_n_u32(0);
+
+ while (a <= pEnd - 64) {
+ uint8x16_t v1 = vld1q_u8(a);
+ uint8x16_t v2 = vld1q_u8(b);
+ acc1 = vaddq_u32(acc1, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2)))));
+
+ v1 = vld1q_u8(a + 16);
+ v2 = vld1q_u8(b + 16);
+ acc2 = vaddq_u32(acc2, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2)))));
+
+ v1 = vld1q_u8(a + 32);
+ v2 = vld1q_u8(b + 32);
+ acc3 = vaddq_u32(acc3, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2)))));
+
+ v1 = vld1q_u8(a + 48);
+ v2 = vld1q_u8(b + 48);
+ acc4 = vaddq_u32(acc4, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2)))));
+
+ a += 64;
+ b += 64;
+ }
+
+ while (a <= pEnd - 16) {
+ uint8x16_t v1 = vld1q_u8(a);
+ uint8x16_t v2 = vld1q_u8(b);
+ acc1 = vaddq_u32(acc1, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2)))));
+ a += 16;
+ b += 16;
+ }
+
+ acc1 = vaddq_u32(acc1, acc2);
+ acc3 = vaddq_u32(acc3, acc4);
+ acc1 = vaddq_u32(acc1, acc3);
+ u32 sum = vaddvq_u32(acc1);
+
+ while (a < pEnd) {
+ sum += hamdist_table[*a ^ *b];
+ a++;
+ b++;
+ }
+
+ return (f32)sum;
+}
+#endif
+
static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) {
int same = 0;
for (unsigned long i = 0; i < n; i++) {
@@ -555,11 +788,18 @@ static f32 distance_hamming_u64(u64 *a, u64 *b, size_t n) {
*/
static f32 distance_hamming(const void *a, const void *b, const void *d) {
size_t dimensions = *((size_t *)d);
+ size_t n_bytes = dimensions / CHAR_BIT;
+
+#ifdef SQLITE_VEC_ENABLE_NEON
+ if (dimensions >= 128) {
+ return distance_hamming_neon((const u8 *)a, (const u8 *)b, n_bytes);
+ }
+#endif
if ((dimensions % 64) == 0) {
- return distance_hamming_u64((u64 *)a, (u64 *)b, dimensions / 8 / CHAR_BIT);
+ return distance_hamming_u64((u64 *)a, (u64 *)b, n_bytes / sizeof(u64));
}
- return distance_hamming_u8((u8 *)a, (u8 *)b, dimensions / CHAR_BIT);
+ return distance_hamming_u8((u8 *)a, (u8 *)b, n_bytes);
}
#ifdef SQLITE_VEC_TEST
@@ -1065,33 +1305,6 @@ int ensure_vector_match(sqlite3_value *aValue, sqlite3_value *bValue, void **a,
int _cmp(const void *a, const void *b) { return (*(i64 *)a - *(i64 *)b); }
-struct VecNpyFile {
- char *path;
- size_t pathLength;
-};
-#define SQLITE_VEC_NPY_FILE_NAME "vec0-npy-file"
-
-#ifndef SQLITE_VEC_OMIT_FS
-static void vec_npy_file(sqlite3_context *context, int argc,
- sqlite3_value **argv) {
- assert(argc == 1);
- char *path = (char *)sqlite3_value_text(argv[0]);
- size_t pathLength = sqlite3_value_bytes(argv[0]);
- struct VecNpyFile *f;
-
- f = sqlite3_malloc(sizeof(*f));
- if (!f) {
- sqlite3_result_error_nomem(context);
- return;
- }
- memset(f, 0, sizeof(*f));
-
- f->path = path;
- f->pathLength = pathLength;
- sqlite3_result_pointer(context, f, SQLITE_VEC_NPY_FILE_NAME, sqlite3_free);
-}
-#endif
-
#pragma region scalar functions
static void vec_f32(sqlite3_context *context, int argc, sqlite3_value **argv) {
assert(argc == 1);
@@ -2281,12 +2494,53 @@ enum Vec0DistanceMetrics {
VEC0_DISTANCE_METRIC_L1 = 3,
};
+/**
+ * Compute distance between two full-precision vectors using the appropriate
+ * distance function for the given element type and metric.
+ * Shared utility used by ANN index implementations.
+ */
+static f32 vec0_distance_full(
+ const void *a, const void *b, size_t dimensions,
+ enum VectorElementType elementType,
+ enum Vec0DistanceMetrics metric) {
+ switch (elementType) {
+ case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
+ switch (metric) {
+ case VEC0_DISTANCE_METRIC_L2:
+ return distance_l2_sqr_float(a, b, &dimensions);
+ case VEC0_DISTANCE_METRIC_COSINE:
+ return distance_cosine_float(a, b, &dimensions);
+ case VEC0_DISTANCE_METRIC_L1:
+ return (f32)distance_l1_f32(a, b, &dimensions);
+ }
+ break;
+ case SQLITE_VEC_ELEMENT_TYPE_INT8:
+ switch (metric) {
+ case VEC0_DISTANCE_METRIC_L2:
+ return distance_l2_sqr_int8(a, b, &dimensions);
+ case VEC0_DISTANCE_METRIC_COSINE:
+ return distance_cosine_int8(a, b, &dimensions);
+ case VEC0_DISTANCE_METRIC_L1:
+ return (f32)distance_l1_int8(a, b, &dimensions);
+ }
+ break;
+ case SQLITE_VEC_ELEMENT_TYPE_BIT:
+ return distance_hamming(a, b, &dimensions);
+ }
+ return 0.0f;
+}
+
+enum Vec0IndexType {
+ VEC0_INDEX_TYPE_FLAT = 1,
+};
+
struct VectorColumnDefinition {
char *name;
int name_length;
size_t dimensions;
enum VectorElementType element_type;
enum Vec0DistanceMetrics distance_metric;
+ enum Vec0IndexType index_type;
};
struct Vec0PartitionColumnDefinition {
@@ -2346,6 +2600,7 @@ int vec0_parse_vector_column(const char *source, int source_length,
int nameLength;
enum VectorElementType elementType;
enum Vec0DistanceMetrics distanceMetric = VEC0_DISTANCE_METRIC_L2;
+ enum Vec0IndexType indexType = VEC0_INDEX_TYPE_FLAT;
int dimensions;
vec0_scanner_init(&scanner, source, source_length);
@@ -2449,6 +2704,40 @@ int vec0_parse_vector_column(const char *source, int source_length,
return SQLITE_ERROR;
}
}
+ else if (sqlite3_strnicmp(key, "indexed", keyLength) == 0) {
+ // expect "by"
+ rc = vec0_scanner_next(&scanner, &token);
+ if (rc != VEC0_TOKEN_RESULT_SOME ||
+ token.token_type != TOKEN_TYPE_IDENTIFIER ||
+ sqlite3_strnicmp(token.start, "by", token.end - token.start) != 0) {
+ return SQLITE_ERROR;
+ }
+ // expect index type name
+ rc = vec0_scanner_next(&scanner, &token);
+ if (rc != VEC0_TOKEN_RESULT_SOME ||
+ token.token_type != TOKEN_TYPE_IDENTIFIER) {
+ return SQLITE_ERROR;
+ }
+ int indexNameLen = token.end - token.start;
+ if (sqlite3_strnicmp(token.start, "flat", indexNameLen) == 0) {
+ indexType = VEC0_INDEX_TYPE_FLAT;
+ // expect '('
+ rc = vec0_scanner_next(&scanner, &token);
+ if (rc != VEC0_TOKEN_RESULT_SOME ||
+ token.token_type != TOKEN_TYPE_LPAREN) {
+ return SQLITE_ERROR;
+ }
+ // expect ')'
+ rc = vec0_scanner_next(&scanner, &token);
+ if (rc != VEC0_TOKEN_RESULT_SOME ||
+ token.token_type != TOKEN_TYPE_RPAREN) {
+ return SQLITE_ERROR;
+ }
+ } else {
+ // unknown index type
+ return SQLITE_ERROR;
+ }
+ }
// unknown key
else {
return SQLITE_ERROR;
@@ -2463,6 +2752,7 @@ int vec0_parse_vector_column(const char *source, int source_length,
outColumn->distance_metric = distanceMetric;
outColumn->element_type = elementType;
outColumn->dimensions = dimensions;
+ outColumn->index_type = indexType;
return SQLITE_OK;
}
@@ -2660,758 +2950,6 @@ static sqlite3_module vec_eachModule = {
#pragma endregion
-#pragma region vec_npy_each table function
-
-enum NpyTokenType {
- NPY_TOKEN_TYPE_IDENTIFIER,
- NPY_TOKEN_TYPE_NUMBER,
- NPY_TOKEN_TYPE_LPAREN,
- NPY_TOKEN_TYPE_RPAREN,
- NPY_TOKEN_TYPE_LBRACE,
- NPY_TOKEN_TYPE_RBRACE,
- NPY_TOKEN_TYPE_COLON,
- NPY_TOKEN_TYPE_COMMA,
- NPY_TOKEN_TYPE_STRING,
- NPY_TOKEN_TYPE_FALSE,
-};
-
-struct NpyToken {
- enum NpyTokenType token_type;
- unsigned char *start;
- unsigned char *end;
-};
-
-int npy_token_next(unsigned char *start, unsigned char *end,
- struct NpyToken *out) {
- unsigned char *ptr = start;
- while (ptr < end) {
- unsigned char curr = *ptr;
- if (is_whitespace(curr)) {
- ptr++;
- continue;
- } else if (curr == '(') {
- out->start = ptr++;
- out->end = ptr;
- out->token_type = NPY_TOKEN_TYPE_LPAREN;
- return VEC0_TOKEN_RESULT_SOME;
- } else if (curr == ')') {
- out->start = ptr++;
- out->end = ptr;
- out->token_type = NPY_TOKEN_TYPE_RPAREN;
- return VEC0_TOKEN_RESULT_SOME;
- } else if (curr == '{') {
- out->start = ptr++;
- out->end = ptr;
- out->token_type = NPY_TOKEN_TYPE_LBRACE;
- return VEC0_TOKEN_RESULT_SOME;
- } else if (curr == '}') {
- out->start = ptr++;
- out->end = ptr;
- out->token_type = NPY_TOKEN_TYPE_RBRACE;
- return VEC0_TOKEN_RESULT_SOME;
- } else if (curr == ':') {
- out->start = ptr++;
- out->end = ptr;
- out->token_type = NPY_TOKEN_TYPE_COLON;
- return VEC0_TOKEN_RESULT_SOME;
- } else if (curr == ',') {
- out->start = ptr++;
- out->end = ptr;
- out->token_type = NPY_TOKEN_TYPE_COMMA;
- return VEC0_TOKEN_RESULT_SOME;
- } else if (curr == '\'') {
- unsigned char *start = ptr;
- ptr++;
- while (ptr < end) {
- if ((*ptr) == '\'') {
- break;
- }
- ptr++;
- }
- if (ptr >= end || (*ptr) != '\'') {
- return VEC0_TOKEN_RESULT_ERROR;
- }
- out->start = start;
- out->end = ++ptr;
- out->token_type = NPY_TOKEN_TYPE_STRING;
- return VEC0_TOKEN_RESULT_SOME;
- } else if (curr == 'F' &&
- strncmp((char *)ptr, "False", strlen("False")) == 0) {
- out->start = ptr;
- out->end = (ptr + (int)strlen("False"));
- ptr = out->end;
- out->token_type = NPY_TOKEN_TYPE_FALSE;
- return VEC0_TOKEN_RESULT_SOME;
- } else if (is_digit(curr)) {
- unsigned char *start = ptr;
- while (ptr < end && (is_digit(*ptr))) {
- ptr++;
- }
- out->start = start;
- out->end = ptr;
- out->token_type = NPY_TOKEN_TYPE_NUMBER;
- return VEC0_TOKEN_RESULT_SOME;
- } else {
- return VEC0_TOKEN_RESULT_ERROR;
- }
- }
- return VEC0_TOKEN_RESULT_ERROR;
-}
-
-struct NpyScanner {
- unsigned char *start;
- unsigned char *end;
- unsigned char *ptr;
-};
-
-void npy_scanner_init(struct NpyScanner *scanner, const unsigned char *source,
- int source_length) {
- scanner->start = (unsigned char *)source;
- scanner->end = (unsigned char *)source + source_length;
- scanner->ptr = (unsigned char *)source;
-}
-
-int npy_scanner_next(struct NpyScanner *scanner, struct NpyToken *out) {
- int rc = npy_token_next(scanner->start, scanner->end, out);
- if (rc == VEC0_TOKEN_RESULT_SOME) {
- scanner->start = out->end;
- }
- return rc;
-}
-
-#define NPY_PARSE_ERROR "Error parsing numpy array: "
-int parse_npy_header(sqlite3_vtab *pVTab, const unsigned char *header,
- size_t headerLength,
- enum VectorElementType *out_element_type,
- int *fortran_order, size_t *numElements,
- size_t *numDimensions) {
-
- struct NpyScanner scanner;
- struct NpyToken token;
- int rc;
- npy_scanner_init(&scanner, header, headerLength);
-
- if (npy_scanner_next(&scanner, &token) != VEC0_TOKEN_RESULT_SOME &&
- token.token_type != NPY_TOKEN_TYPE_LBRACE) {
- vtab_set_error(pVTab,
- NPY_PARSE_ERROR "numpy header did not start with '{'");
- return SQLITE_ERROR;
- }
- while (1) {
- rc = npy_scanner_next(&scanner, &token);
- if (rc != VEC0_TOKEN_RESULT_SOME) {
- vtab_set_error(pVTab, NPY_PARSE_ERROR "expected key in numpy header");
- return SQLITE_ERROR;
- }
-
- if (token.token_type == NPY_TOKEN_TYPE_RBRACE) {
- break;
- }
- if (token.token_type != NPY_TOKEN_TYPE_STRING) {
- vtab_set_error(pVTab, NPY_PARSE_ERROR
- "expected a string as key in numpy header");
- return SQLITE_ERROR;
- }
- unsigned char *key = token.start;
-
- rc = npy_scanner_next(&scanner, &token);
- if ((rc != VEC0_TOKEN_RESULT_SOME) ||
- (token.token_type != NPY_TOKEN_TYPE_COLON)) {
- vtab_set_error(pVTab, NPY_PARSE_ERROR
- "expected a ':' after key in numpy header");
- return SQLITE_ERROR;
- }
-
- if (strncmp((char *)key, "'descr'", strlen("'descr'")) == 0) {
- rc = npy_scanner_next(&scanner, &token);
- if ((rc != VEC0_TOKEN_RESULT_SOME) ||
- (token.token_type != NPY_TOKEN_TYPE_STRING)) {
- vtab_set_error(pVTab, NPY_PARSE_ERROR
- "expected a string value after 'descr' key");
- return SQLITE_ERROR;
- }
- if (strncmp((char *)token.start, "'maxChunks = 1024;
- pCur->chunksBufferSize =
- (vector_byte_size(element_type, numDimensions)) * pCur->maxChunks;
- pCur->chunksBuffer = sqlite3_malloc(pCur->chunksBufferSize);
- if (pCur->chunksBufferSize && !pCur->chunksBuffer) {
- return SQLITE_NOMEM;
- }
-
- pCur->currentChunkSize =
- fread(pCur->chunksBuffer, vector_byte_size(element_type, numDimensions),
- pCur->maxChunks, file);
-
- pCur->currentChunkIndex = 0;
- pCur->elementType = element_type;
- pCur->nElements = numElements;
- pCur->nDimensions = numDimensions;
- pCur->input_type = VEC_NPY_EACH_INPUT_FILE;
-
- pCur->eof = pCur->currentChunkSize == 0;
- pCur->file = file;
- return SQLITE_OK;
-}
-#endif
-
-int parse_npy_buffer(sqlite3_vtab *pVTab, const unsigned char *buffer,
- int bufferLength, void **data, size_t *numElements,
- size_t *numDimensions,
- enum VectorElementType *element_type) {
-
- if (bufferLength < 10) {
- // IMP: V03312_20150
- vtab_set_error(pVTab, "numpy array too short");
- return SQLITE_ERROR;
- }
- if (memcmp(NPY_MAGIC, buffer, sizeof(NPY_MAGIC)) != 0) {
- // V11954_28792
- vtab_set_error(pVTab, "numpy array does not contain the 'magic' header");
- return SQLITE_ERROR;
- }
-
- u8 major = buffer[6];
- u8 minor = buffer[7];
- uint16_t headerLength = 0;
- memcpy(&headerLength, &buffer[8], sizeof(uint16_t));
-
- i32 totalHeaderLength = sizeof(NPY_MAGIC) + sizeof(major) + sizeof(minor) +
- sizeof(headerLength) + headerLength;
- i32 dataSize = bufferLength - totalHeaderLength;
-
- if (dataSize < 0) {
- vtab_set_error(pVTab, "numpy array header length is invalid");
- return SQLITE_ERROR;
- }
-
- const unsigned char *header = &buffer[10];
- int fortran_order;
-
- int rc = parse_npy_header(pVTab, header, headerLength, element_type,
- &fortran_order, numElements, numDimensions);
- if (rc != SQLITE_OK) {
- return rc;
- }
-
- i32 expectedDataSize =
- (*numElements * vector_byte_size(*element_type, *numDimensions));
- if (expectedDataSize != dataSize) {
- vtab_set_error(pVTab,
- "numpy array error: Expected a data size of %d, found %d",
- expectedDataSize, dataSize);
- return SQLITE_ERROR;
- }
-
- *data = (void *)&buffer[totalHeaderLength];
- return SQLITE_OK;
-}
-
-static int vec_npy_eachConnect(sqlite3 *db, void *pAux, int argc,
- const char *const *argv, sqlite3_vtab **ppVtab,
- char **pzErr) {
- UNUSED_PARAMETER(pAux);
- UNUSED_PARAMETER(argc);
- UNUSED_PARAMETER(argv);
- UNUSED_PARAMETER(pzErr);
- vec_npy_each_vtab *pNew;
- int rc;
-
- rc = sqlite3_declare_vtab(db, "CREATE TABLE x(vector, input hidden)");
-#define VEC_NPY_EACH_COLUMN_VECTOR 0
-#define VEC_NPY_EACH_COLUMN_INPUT 1
- if (rc == SQLITE_OK) {
- pNew = sqlite3_malloc(sizeof(*pNew));
- *ppVtab = (sqlite3_vtab *)pNew;
- if (pNew == 0)
- return SQLITE_NOMEM;
- memset(pNew, 0, sizeof(*pNew));
- }
- return rc;
-}
-
-static int vec_npy_eachDisconnect(sqlite3_vtab *pVtab) {
- vec_npy_each_vtab *p = (vec_npy_each_vtab *)pVtab;
- sqlite3_free(p);
- return SQLITE_OK;
-}
-
-static int vec_npy_eachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) {
- UNUSED_PARAMETER(p);
- vec_npy_each_cursor *pCur;
- pCur = sqlite3_malloc(sizeof(*pCur));
- if (pCur == 0)
- return SQLITE_NOMEM;
- memset(pCur, 0, sizeof(*pCur));
- *ppCursor = &pCur->base;
- return SQLITE_OK;
-}
-
-static int vec_npy_eachClose(sqlite3_vtab_cursor *cur) {
- vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur;
-#ifndef SQLITE_VEC_OMIT_FS
- if (pCur->file) {
- fclose(pCur->file);
- pCur->file = NULL;
- }
-#endif
- if (pCur->chunksBuffer) {
- sqlite3_free(pCur->chunksBuffer);
- pCur->chunksBuffer = NULL;
- }
- if (pCur->vector) {
- pCur->vector = NULL;
- }
- sqlite3_free(pCur);
- return SQLITE_OK;
-}
-
-static int vec_npy_eachBestIndex(sqlite3_vtab *pVTab,
- sqlite3_index_info *pIdxInfo) {
- int hasInput;
- for (int i = 0; i < pIdxInfo->nConstraint; i++) {
- const struct sqlite3_index_constraint *pCons = &pIdxInfo->aConstraint[i];
- // printf("i=%d iColumn=%d, op=%d, usable=%d\n", i, pCons->iColumn,
- // pCons->op, pCons->usable);
- switch (pCons->iColumn) {
- case VEC_NPY_EACH_COLUMN_INPUT: {
- if (pCons->op == SQLITE_INDEX_CONSTRAINT_EQ && pCons->usable) {
- hasInput = 1;
- pIdxInfo->aConstraintUsage[i].argvIndex = 1;
- pIdxInfo->aConstraintUsage[i].omit = 1;
- }
- break;
- }
- }
- }
- if (!hasInput) {
- pVTab->zErrMsg = sqlite3_mprintf("input argument is required");
- return SQLITE_ERROR;
- }
-
- pIdxInfo->estimatedCost = (double)100000;
- pIdxInfo->estimatedRows = 100000;
-
- return SQLITE_OK;
-}
-
-static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
- const char *idxStr, int argc,
- sqlite3_value **argv) {
- UNUSED_PARAMETER(idxNum);
- UNUSED_PARAMETER(idxStr);
- assert(argc == 1);
- int rc;
-
- vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)pVtabCursor;
-
-#ifndef SQLITE_VEC_OMIT_FS
- if (pCur->file) {
- fclose(pCur->file);
- pCur->file = NULL;
- }
-#endif
- if (pCur->chunksBuffer) {
- sqlite3_free(pCur->chunksBuffer);
- pCur->chunksBuffer = NULL;
- }
- if (pCur->vector) {
- pCur->vector = NULL;
- }
-
-#ifndef SQLITE_VEC_OMIT_FS
- struct VecNpyFile *f = NULL;
- if ((f = sqlite3_value_pointer(argv[0], SQLITE_VEC_NPY_FILE_NAME))) {
- FILE *file = fopen(f->path, "r");
- if (!file) {
- vtab_set_error(pVtabCursor->pVtab, "Could not open numpy file");
- return SQLITE_ERROR;
- }
-
- rc = parse_npy_file(pVtabCursor->pVtab, file, pCur);
- if (rc != SQLITE_OK) {
-#ifndef SQLITE_VEC_OMIT_FS
- fclose(file);
-#endif
- return rc;
- }
-
- } else
-#endif
- {
-
- const unsigned char *input = sqlite3_value_blob(argv[0]);
- int inputLength = sqlite3_value_bytes(argv[0]);
- void *data;
- size_t numElements;
- size_t numDimensions;
- enum VectorElementType element_type;
-
- rc = parse_npy_buffer(pVtabCursor->pVtab, input, inputLength, &data,
- &numElements, &numDimensions, &element_type);
- if (rc != SQLITE_OK) {
- return rc;
- }
-
- pCur->vector = data;
- pCur->elementType = element_type;
- pCur->nElements = numElements;
- pCur->nDimensions = numDimensions;
- pCur->input_type = VEC_NPY_EACH_INPUT_BUFFER;
- }
-
- pCur->iRowid = 0;
- return SQLITE_OK;
-}
-
-static int vec_npy_eachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) {
- vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur;
- *pRowid = pCur->iRowid;
- return SQLITE_OK;
-}
-
-static int vec_npy_eachEof(sqlite3_vtab_cursor *cur) {
- vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur;
- if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) {
- return (!pCur->nElements) || (size_t)pCur->iRowid >= pCur->nElements;
- }
- return pCur->eof;
-}
-
-static int vec_npy_eachNext(sqlite3_vtab_cursor *cur) {
- vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur;
- pCur->iRowid++;
- if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) {
- return SQLITE_OK;
- }
-
-#ifndef SQLITE_VEC_OMIT_FS
- // else: input is a file
- pCur->currentChunkIndex++;
- if (pCur->currentChunkIndex >= pCur->currentChunkSize) {
- pCur->currentChunkSize =
- fread(pCur->chunksBuffer,
- vector_byte_size(pCur->elementType, pCur->nDimensions),
- pCur->maxChunks, pCur->file);
- if (!pCur->currentChunkSize) {
- pCur->eof = 1;
- }
- pCur->currentChunkIndex = 0;
- }
-#endif
- return SQLITE_OK;
-}
-
-static int vec_npy_eachColumnBuffer(vec_npy_each_cursor *pCur,
- sqlite3_context *context, int i) {
- switch (i) {
- case VEC_NPY_EACH_COLUMN_VECTOR: {
- sqlite3_result_subtype(context, pCur->elementType);
- switch (pCur->elementType) {
- case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
- sqlite3_result_blob(
- context,
- &((unsigned char *)
- pCur->vector)[pCur->iRowid * pCur->nDimensions * sizeof(f32)],
- pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT);
-
- break;
- }
- case SQLITE_VEC_ELEMENT_TYPE_INT8:
- case SQLITE_VEC_ELEMENT_TYPE_BIT: {
- // https://github.com/asg017/sqlite-vec/issues/42
- sqlite3_result_error(context,
- "vec_npy_each only supports float32 vectors", -1);
- break;
- }
- }
-
- break;
- }
- }
- return SQLITE_OK;
-}
-static int vec_npy_eachColumnFile(vec_npy_each_cursor *pCur,
- sqlite3_context *context, int i) {
- switch (i) {
- case VEC_NPY_EACH_COLUMN_VECTOR: {
- switch (pCur->elementType) {
- case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
- sqlite3_result_blob(
- context,
- &((unsigned char *)
- pCur->chunksBuffer)[pCur->currentChunkIndex *
- pCur->nDimensions * sizeof(f32)],
- pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT);
- break;
- }
- case SQLITE_VEC_ELEMENT_TYPE_INT8:
- case SQLITE_VEC_ELEMENT_TYPE_BIT: {
- // https://github.com/asg017/sqlite-vec/issues/42
- sqlite3_result_error(context,
- "vec_npy_each only supports float32 vectors", -1);
- break;
- }
- }
- break;
- }
- }
- return SQLITE_OK;
-}
-static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur,
- sqlite3_context *context, int i) {
- vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur;
- switch (pCur->input_type) {
- case VEC_NPY_EACH_INPUT_BUFFER:
- return vec_npy_eachColumnBuffer(pCur, context, i);
- case VEC_NPY_EACH_INPUT_FILE:
- return vec_npy_eachColumnFile(pCur, context, i);
- }
- return SQLITE_ERROR;
-}
-
-static sqlite3_module vec_npy_eachModule = {
- /* iVersion */ 0,
- /* xCreate */ 0,
- /* xConnect */ vec_npy_eachConnect,
- /* xBestIndex */ vec_npy_eachBestIndex,
- /* xDisconnect */ vec_npy_eachDisconnect,
- /* xDestroy */ 0,
- /* xOpen */ vec_npy_eachOpen,
- /* xClose */ vec_npy_eachClose,
- /* xFilter */ vec_npy_eachFilter,
- /* xNext */ vec_npy_eachNext,
- /* xEof */ vec_npy_eachEof,
- /* xColumn */ vec_npy_eachColumn,
- /* xRowid */ vec_npy_eachRowid,
- /* xUpdate */ 0,
- /* xBegin */ 0,
- /* xSync */ 0,
- /* xCommit */ 0,
- /* xRollback */ 0,
- /* xFindMethod */ 0,
- /* xRename */ 0,
- /* xSavepoint */ 0,
- /* xRelease */ 0,
- /* xRollbackTo */ 0,
- /* xShadowName */ 0,
-#if SQLITE_VERSION_NUMBER >= 3044000
- /* xIntegrity */ 0,
-#endif
-};
-
-#pragma endregion
#pragma region vec0 virtual table
@@ -5959,6 +5497,65 @@ int min_idx(const f32 *distances, i32 n, u8 *candidates, i32 *out, i32 k,
assert(k > 0);
assert(k <= n);
+#ifdef SQLITE_VEC_EXPERIMENTAL_MIN_IDX
+ // Max-heap variant: O(n log k) single-pass.
+ // out[0..heap_size-1] stores indices; heap ordered by distances descending
+ // so out[0] is always the index of the LARGEST distance in the top-k.
+ (void)bTaken;
+ int heap_size = 0;
+
+ #define HEAP_SIFT_UP(pos) do { \
+ int _c = (pos); \
+ while (_c > 0) { \
+ int _p = (_c - 1) / 2; \
+ if (distances[out[_p]] < distances[out[_c]]) { \
+ i32 _tmp = out[_p]; out[_p] = out[_c]; out[_c] = _tmp; \
+ _c = _p; \
+ } else break; \
+ } \
+ } while(0)
+
+ #define HEAP_SIFT_DOWN(pos, sz) do { \
+ int _p = (pos); \
+ for (;;) { \
+ int _l = 2*_p + 1, _r = 2*_p + 2, _largest = _p; \
+ if (_l < (sz) && distances[out[_l]] > distances[out[_largest]]) \
+ _largest = _l; \
+ if (_r < (sz) && distances[out[_r]] > distances[out[_largest]]) \
+ _largest = _r; \
+ if (_largest == _p) break; \
+ i32 _tmp = out[_p]; out[_p] = out[_largest]; out[_largest] = _tmp; \
+ _p = _largest; \
+ } \
+ } while(0)
+
+ for (int i = 0; i < n; i++) {
+ if (!bitmap_get(candidates, i))
+ continue;
+ if (heap_size < k) {
+ out[heap_size] = i;
+ heap_size++;
+ HEAP_SIFT_UP(heap_size - 1);
+ } else if (distances[i] < distances[out[0]]) {
+ out[0] = i;
+ HEAP_SIFT_DOWN(0, heap_size);
+ }
+ }
+
+ // Heapsort to produce ascending order.
+ for (int i = heap_size - 1; i > 0; i--) {
+ i32 tmp = out[0]; out[0] = out[i]; out[i] = tmp;
+ HEAP_SIFT_DOWN(0, i);
+ }
+
+ #undef HEAP_SIFT_UP
+ #undef HEAP_SIFT_DOWN
+
+ *k_used = heap_size;
+ return SQLITE_OK;
+
+#else
+ // Original: O(n*k) repeated linear scan with bitmap.
bitmap_clear(bTaken, n);
for (int ik = 0; ik < k; ik++) {
@@ -5984,6 +5581,7 @@ int min_idx(const f32 *distances, i32 n, u8 *candidates, i32 *out, i32 k,
}
*k_used = k;
return SQLITE_OK;
+#endif
}
int vec0_get_metadata_text_long_value(
@@ -9388,652 +8986,6 @@ static sqlite3_module vec0Module = {
};
#pragma endregion
-static char *POINTER_NAME_STATIC_BLOB_DEF = "vec0-static_blob_def";
-struct static_blob_definition {
- void *p;
- size_t dimensions;
- size_t nvectors;
- enum VectorElementType element_type;
-};
-static void vec_static_blob_from_raw(sqlite3_context *context, int argc,
- sqlite3_value **argv) {
-
- assert(argc == 4);
- struct static_blob_definition *p;
- p = sqlite3_malloc(sizeof(*p));
- if (!p) {
- sqlite3_result_error_nomem(context);
- return;
- }
- memset(p, 0, sizeof(*p));
- p->p = (void *)sqlite3_value_int64(argv[0]);
- p->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32;
- p->dimensions = sqlite3_value_int64(argv[2]);
- p->nvectors = sqlite3_value_int64(argv[3]);
- sqlite3_result_pointer(context, p, POINTER_NAME_STATIC_BLOB_DEF,
- sqlite3_free);
-}
-#pragma region vec_static_blobs() table function
-
-#define MAX_STATIC_BLOBS 16
-
-typedef struct static_blob static_blob;
-struct static_blob {
- char *name;
- void *p;
- size_t dimensions;
- size_t nvectors;
- enum VectorElementType element_type;
-};
-
-typedef struct vec_static_blob_data vec_static_blob_data;
-struct vec_static_blob_data {
- static_blob static_blobs[MAX_STATIC_BLOBS];
-};
-
-typedef struct vec_static_blobs_vtab vec_static_blobs_vtab;
-struct vec_static_blobs_vtab {
- sqlite3_vtab base;
- vec_static_blob_data *data;
-};
-
-typedef struct vec_static_blobs_cursor vec_static_blobs_cursor;
-struct vec_static_blobs_cursor {
- sqlite3_vtab_cursor base;
- sqlite3_int64 iRowid;
-};
-
-static int vec_static_blobsConnect(sqlite3 *db, void *pAux, int argc,
- const char *const *argv,
- sqlite3_vtab **ppVtab, char **pzErr) {
- UNUSED_PARAMETER(argc);
- UNUSED_PARAMETER(argv);
- UNUSED_PARAMETER(pzErr);
-
- vec_static_blobs_vtab *pNew;
-#define VEC_STATIC_BLOBS_NAME 0
-#define VEC_STATIC_BLOBS_DATA 1
-#define VEC_STATIC_BLOBS_DIMENSIONS 2
-#define VEC_STATIC_BLOBS_COUNT 3
- int rc = sqlite3_declare_vtab(
- db, "CREATE TABLE x(name, data, dimensions hidden, count hidden)");
- if (rc == SQLITE_OK) {
- pNew = sqlite3_malloc(sizeof(*pNew));
- *ppVtab = (sqlite3_vtab *)pNew;
- if (pNew == 0)
- return SQLITE_NOMEM;
- memset(pNew, 0, sizeof(*pNew));
- pNew->data = pAux;
- }
- return rc;
-}
-
-static int vec_static_blobsDisconnect(sqlite3_vtab *pVtab) {
- vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pVtab;
- sqlite3_free(p);
- return SQLITE_OK;
-}
-
-static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc,
- sqlite3_value **argv, sqlite_int64 *pRowid) {
- UNUSED_PARAMETER(pRowid);
- vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pVTab;
- // DELETE operation
- if (argc == 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) {
- return SQLITE_ERROR;
- }
- // INSERT operation
- else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) {
- const char *key =
- (const char *)sqlite3_value_text(argv[2 + VEC_STATIC_BLOBS_NAME]);
- int idx = -1;
- for (int i = 0; i < MAX_STATIC_BLOBS; i++) {
- if (!p->data->static_blobs[i].name) {
- p->data->static_blobs[i].name = sqlite3_mprintf("%s", key);
- idx = i;
- break;
- }
- }
- if (idx < 0)
- abort();
- struct static_blob_definition *def = sqlite3_value_pointer(
- argv[2 + VEC_STATIC_BLOBS_DATA], POINTER_NAME_STATIC_BLOB_DEF);
- p->data->static_blobs[idx].p = def->p;
- p->data->static_blobs[idx].dimensions = def->dimensions;
- p->data->static_blobs[idx].nvectors = def->nvectors;
- p->data->static_blobs[idx].element_type = def->element_type;
-
- return SQLITE_OK;
- }
- // UPDATE operation
- else if (argc > 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) {
- return SQLITE_ERROR;
- }
- return SQLITE_ERROR;
-}
-
-static int vec_static_blobsOpen(sqlite3_vtab *p,
- sqlite3_vtab_cursor **ppCursor) {
- UNUSED_PARAMETER(p);
- vec_static_blobs_cursor *pCur;
- pCur = sqlite3_malloc(sizeof(*pCur));
- if (pCur == 0)
- return SQLITE_NOMEM;
- memset(pCur, 0, sizeof(*pCur));
- *ppCursor = &pCur->base;
- return SQLITE_OK;
-}
-
-static int vec_static_blobsClose(sqlite3_vtab_cursor *cur) {
- vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur;
- sqlite3_free(pCur);
- return SQLITE_OK;
-}
-
-static int vec_static_blobsBestIndex(sqlite3_vtab *pVTab,
- sqlite3_index_info *pIdxInfo) {
- UNUSED_PARAMETER(pVTab);
- pIdxInfo->idxNum = 1;
- pIdxInfo->estimatedCost = (double)10;
- pIdxInfo->estimatedRows = 10;
- return SQLITE_OK;
-}
-
-static int vec_static_blobsNext(sqlite3_vtab_cursor *cur);
-static int vec_static_blobsFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
- const char *idxStr, int argc,
- sqlite3_value **argv) {
- UNUSED_PARAMETER(idxNum);
- UNUSED_PARAMETER(idxStr);
- UNUSED_PARAMETER(argc);
- UNUSED_PARAMETER(argv);
- vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)pVtabCursor;
- pCur->iRowid = -1;
- vec_static_blobsNext(pVtabCursor);
- return SQLITE_OK;
-}
-
-static int vec_static_blobsRowid(sqlite3_vtab_cursor *cur,
- sqlite_int64 *pRowid) {
- vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur;
- *pRowid = pCur->iRowid;
- return SQLITE_OK;
-}
-
-static int vec_static_blobsNext(sqlite3_vtab_cursor *cur) {
- vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur;
- vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pCur->base.pVtab;
- pCur->iRowid++;
- while (pCur->iRowid < MAX_STATIC_BLOBS) {
- if (p->data->static_blobs[pCur->iRowid].name) {
- return SQLITE_OK;
- }
- pCur->iRowid++;
- }
- return SQLITE_OK;
-}
-
-static int vec_static_blobsEof(sqlite3_vtab_cursor *cur) {
- vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur;
- return pCur->iRowid >= MAX_STATIC_BLOBS;
-}
-
-static int vec_static_blobsColumn(sqlite3_vtab_cursor *cur,
- sqlite3_context *context, int i) {
- vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur;
- vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)cur->pVtab;
- switch (i) {
- case VEC_STATIC_BLOBS_NAME:
- sqlite3_result_text(context, p->data->static_blobs[pCur->iRowid].name, -1,
- SQLITE_TRANSIENT);
- break;
- case VEC_STATIC_BLOBS_DATA:
- sqlite3_result_null(context);
- break;
- case VEC_STATIC_BLOBS_DIMENSIONS:
- sqlite3_result_int64(context,
- p->data->static_blobs[pCur->iRowid].dimensions);
- break;
- case VEC_STATIC_BLOBS_COUNT:
- sqlite3_result_int64(context, p->data->static_blobs[pCur->iRowid].nvectors);
- break;
- }
- return SQLITE_OK;
-}
-
-static sqlite3_module vec_static_blobsModule = {
- /* iVersion */ 3,
- /* xCreate */ 0,
- /* xConnect */ vec_static_blobsConnect,
- /* xBestIndex */ vec_static_blobsBestIndex,
- /* xDisconnect */ vec_static_blobsDisconnect,
- /* xDestroy */ 0,
- /* xOpen */ vec_static_blobsOpen,
- /* xClose */ vec_static_blobsClose,
- /* xFilter */ vec_static_blobsFilter,
- /* xNext */ vec_static_blobsNext,
- /* xEof */ vec_static_blobsEof,
- /* xColumn */ vec_static_blobsColumn,
- /* xRowid */ vec_static_blobsRowid,
- /* xUpdate */ vec_static_blobsUpdate,
- /* xBegin */ 0,
- /* xSync */ 0,
- /* xCommit */ 0,
- /* xRollback */ 0,
- /* xFindMethod */ 0,
- /* xRename */ 0,
- /* xSavepoint */ 0,
- /* xRelease */ 0,
- /* xRollbackTo */ 0,
- /* xShadowName */ 0,
-#if SQLITE_VERSION_NUMBER >= 3044000
- /* xIntegrity */ 0
-#endif
-};
-#pragma endregion
-
-#pragma region vec_static_blob_entries() table function
-
-typedef struct vec_static_blob_entries_vtab vec_static_blob_entries_vtab;
-struct vec_static_blob_entries_vtab {
- sqlite3_vtab base;
- static_blob *blob;
-};
-typedef enum {
- VEC_SBE__QUERYPLAN_FULLSCAN = 1,
- VEC_SBE__QUERYPLAN_KNN = 2
-} vec_sbe_query_plan;
-
-struct sbe_query_knn_data {
- i64 k;
- i64 k_used;
- // Array of rowids of size k. Must be freed with sqlite3_free().
- i32 *rowids;
- // Array of distances of size k. Must be freed with sqlite3_free().
- f32 *distances;
- i64 current_idx;
-};
-void sbe_query_knn_data_clear(struct sbe_query_knn_data *knn_data) {
- if (!knn_data)
- return;
-
- if (knn_data->rowids) {
- sqlite3_free(knn_data->rowids);
- knn_data->rowids = NULL;
- }
- if (knn_data->distances) {
- sqlite3_free(knn_data->distances);
- knn_data->distances = NULL;
- }
-}
-
-typedef struct vec_static_blob_entries_cursor vec_static_blob_entries_cursor;
-struct vec_static_blob_entries_cursor {
- sqlite3_vtab_cursor base;
- sqlite3_int64 iRowid;
- vec_sbe_query_plan query_plan;
- struct sbe_query_knn_data *knn_data;
-};
-
-static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc,
- const char *const *argv,
- sqlite3_vtab **ppVtab, char **pzErr) {
- UNUSED_PARAMETER(argc);
- UNUSED_PARAMETER(argv);
- UNUSED_PARAMETER(pzErr);
- vec_static_blob_data *blob_data = pAux;
- int idx = -1;
- for (int i = 0; i < MAX_STATIC_BLOBS; i++) {
- if (!blob_data->static_blobs[i].name)
- continue;
- if (strncmp(blob_data->static_blobs[i].name, argv[3],
- strlen(blob_data->static_blobs[i].name)) == 0) {
- idx = i;
- break;
- }
- }
- if (idx < 0)
- abort();
- vec_static_blob_entries_vtab *pNew;
-#define VEC_STATIC_BLOB_ENTRIES_VECTOR 0
-#define VEC_STATIC_BLOB_ENTRIES_DISTANCE 1
-#define VEC_STATIC_BLOB_ENTRIES_K 2
- int rc = sqlite3_declare_vtab(
- db, "CREATE TABLE x(vector, distance hidden, k hidden)");
- if (rc == SQLITE_OK) {
- pNew = sqlite3_malloc(sizeof(*pNew));
- *ppVtab = (sqlite3_vtab *)pNew;
- if (pNew == 0)
- return SQLITE_NOMEM;
- memset(pNew, 0, sizeof(*pNew));
- pNew->blob = &blob_data->static_blobs[idx];
- }
- return rc;
-}
-
-static int vec_static_blob_entriesCreate(sqlite3 *db, void *pAux, int argc,
- const char *const *argv,
- sqlite3_vtab **ppVtab, char **pzErr) {
- return vec_static_blob_entriesConnect(db, pAux, argc, argv, ppVtab, pzErr);
-}
-
-static int vec_static_blob_entriesDisconnect(sqlite3_vtab *pVtab) {
- vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVtab;
- sqlite3_free(p);
- return SQLITE_OK;
-}
-
-static int vec_static_blob_entriesOpen(sqlite3_vtab *p,
- sqlite3_vtab_cursor **ppCursor) {
- UNUSED_PARAMETER(p);
- vec_static_blob_entries_cursor *pCur;
- pCur = sqlite3_malloc(sizeof(*pCur));
- if (pCur == 0)
- return SQLITE_NOMEM;
- memset(pCur, 0, sizeof(*pCur));
- *ppCursor = &pCur->base;
- return SQLITE_OK;
-}
-
-static int vec_static_blob_entriesClose(sqlite3_vtab_cursor *cur) {
- vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
- sqlite3_free(pCur->knn_data);
- sqlite3_free(pCur);
- return SQLITE_OK;
-}
-
-static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
- sqlite3_index_info *pIdxInfo) {
- vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVTab;
- int iMatchTerm = -1;
- int iLimitTerm = -1;
- // int iRowidTerm = -1; // https://github.com/asg017/sqlite-vec/issues/47
- int iKTerm = -1;
-
- for (int i = 0; i < pIdxInfo->nConstraint; i++) {
- if (!pIdxInfo->aConstraint[i].usable)
- continue;
-
- int iColumn = pIdxInfo->aConstraint[i].iColumn;
- int op = pIdxInfo->aConstraint[i].op;
- if (op == SQLITE_INDEX_CONSTRAINT_MATCH &&
- iColumn == VEC_STATIC_BLOB_ENTRIES_VECTOR) {
- if (iMatchTerm > -1) {
- // https://github.com/asg017/sqlite-vec/issues/51
- return SQLITE_ERROR;
- }
- iMatchTerm = i;
- }
- if (op == SQLITE_INDEX_CONSTRAINT_LIMIT) {
- iLimitTerm = i;
- }
- if (op == SQLITE_INDEX_CONSTRAINT_EQ &&
- iColumn == VEC_STATIC_BLOB_ENTRIES_K) {
- iKTerm = i;
- }
- }
- if (iMatchTerm >= 0) {
- if (iLimitTerm < 0 && iKTerm < 0) {
- // https://github.com/asg017/sqlite-vec/issues/51
- return SQLITE_ERROR;
- }
- if (iLimitTerm >= 0 && iKTerm >= 0) {
- return SQLITE_ERROR; // limit or k, not both
- }
- if (pIdxInfo->nOrderBy < 1) {
- vtab_set_error(pVTab, "ORDER BY distance required");
- return SQLITE_CONSTRAINT;
- }
- if (pIdxInfo->nOrderBy > 1) {
- // https://github.com/asg017/sqlite-vec/issues/51
- vtab_set_error(pVTab, "more than 1 ORDER BY clause provided");
- return SQLITE_CONSTRAINT;
- }
- if (pIdxInfo->aOrderBy[0].iColumn != VEC_STATIC_BLOB_ENTRIES_DISTANCE) {
- vtab_set_error(pVTab, "ORDER BY must be on the distance column");
- return SQLITE_CONSTRAINT;
- }
- if (pIdxInfo->aOrderBy[0].desc) {
- vtab_set_error(pVTab,
- "Only ascending in ORDER BY distance clause is supported, "
- "DESC is not supported yet.");
- return SQLITE_CONSTRAINT;
- }
-
- pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_KNN;
- pIdxInfo->estimatedCost = (double)10;
- pIdxInfo->estimatedRows = 10;
-
- pIdxInfo->orderByConsumed = 1;
- pIdxInfo->aConstraintUsage[iMatchTerm].argvIndex = 1;
- pIdxInfo->aConstraintUsage[iMatchTerm].omit = 1;
- if (iLimitTerm >= 0) {
- pIdxInfo->aConstraintUsage[iLimitTerm].argvIndex = 2;
- pIdxInfo->aConstraintUsage[iLimitTerm].omit = 1;
- } else {
- pIdxInfo->aConstraintUsage[iKTerm].argvIndex = 2;
- pIdxInfo->aConstraintUsage[iKTerm].omit = 1;
- }
-
- } else {
- pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_FULLSCAN;
- pIdxInfo->estimatedCost = (double)p->blob->nvectors;
- pIdxInfo->estimatedRows = p->blob->nvectors;
- }
- return SQLITE_OK;
-}
-
-static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
- int idxNum, const char *idxStr,
- int argc, sqlite3_value **argv) {
- UNUSED_PARAMETER(idxStr);
- assert(argc >= 0 && argc <= 3);
- vec_static_blob_entries_cursor *pCur =
- (vec_static_blob_entries_cursor *)pVtabCursor;
- vec_static_blob_entries_vtab *p =
- (vec_static_blob_entries_vtab *)pCur->base.pVtab;
-
- if (idxNum == VEC_SBE__QUERYPLAN_KNN) {
- assert(argc == 2);
- pCur->query_plan = VEC_SBE__QUERYPLAN_KNN;
- struct sbe_query_knn_data *knn_data;
- knn_data = sqlite3_malloc(sizeof(*knn_data));
- if (!knn_data) {
- return SQLITE_NOMEM;
- }
- memset(knn_data, 0, sizeof(*knn_data));
-
- void *queryVector;
- size_t dimensions;
- enum VectorElementType elementType;
- vector_cleanup cleanup;
- char *err;
- int rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
- &cleanup, &err);
- if (rc != SQLITE_OK) {
- return SQLITE_ERROR;
- }
- if (elementType != p->blob->element_type) {
- return SQLITE_ERROR;
- }
- if (dimensions != p->blob->dimensions) {
- return SQLITE_ERROR;
- }
-
- i64 k = min(sqlite3_value_int64(argv[1]), (i64)p->blob->nvectors);
- if (k < 0) {
- // HANDLE https://github.com/asg017/sqlite-vec/issues/55
- return SQLITE_ERROR;
- }
- if (k == 0) {
- knn_data->k = 0;
- pCur->knn_data = knn_data;
- return SQLITE_OK;
- }
-
- size_t bsize = (p->blob->nvectors + 7) & ~7;
-
- i32 *topk_rowids = sqlite3_malloc(k * sizeof(i32));
- if (!topk_rowids) {
- // HANDLE https://github.com/asg017/sqlite-vec/issues/55
- return SQLITE_ERROR;
- }
- f32 *distances = sqlite3_malloc(bsize * sizeof(f32));
- if (!distances) {
- // HANDLE https://github.com/asg017/sqlite-vec/issues/55
- return SQLITE_ERROR;
- }
-
- for (size_t i = 0; i < p->blob->nvectors; i++) {
- // https://github.com/asg017/sqlite-vec/issues/52
- float *v = ((float *)p->blob->p) + (i * p->blob->dimensions);
- distances[i] =
- distance_l2_sqr_float(v, (float *)queryVector, &p->blob->dimensions);
- }
- u8 *candidates = bitmap_new(bsize);
- assert(candidates);
-
- u8 *taken = bitmap_new(bsize);
- assert(taken);
-
- bitmap_fill(candidates, bsize);
- for (size_t i = bsize; i >= p->blob->nvectors; i--) {
- bitmap_set(candidates, i, 0);
- }
- i32 k_used = 0;
- min_idx(distances, bsize, candidates, topk_rowids, k, taken, &k_used);
- knn_data->current_idx = 0;
- knn_data->distances = distances;
- knn_data->k = k;
- knn_data->rowids = topk_rowids;
-
- pCur->knn_data = knn_data;
- } else {
- pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN;
- pCur->iRowid = 0;
- }
-
- return SQLITE_OK;
-}
-
-static int vec_static_blob_entriesRowid(sqlite3_vtab_cursor *cur,
- sqlite_int64 *pRowid) {
- vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
- switch (pCur->query_plan) {
- case VEC_SBE__QUERYPLAN_FULLSCAN: {
- *pRowid = pCur->iRowid;
- return SQLITE_OK;
- }
- case VEC_SBE__QUERYPLAN_KNN: {
- i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx];
- *pRowid = (sqlite3_int64)rowid;
- return SQLITE_OK;
- }
- }
- return SQLITE_ERROR;
-}
-
-static int vec_static_blob_entriesNext(sqlite3_vtab_cursor *cur) {
- vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
- switch (pCur->query_plan) {
- case VEC_SBE__QUERYPLAN_FULLSCAN: {
- pCur->iRowid++;
- return SQLITE_OK;
- }
- case VEC_SBE__QUERYPLAN_KNN: {
- pCur->knn_data->current_idx++;
- return SQLITE_OK;
- }
- }
- return SQLITE_ERROR;
-}
-
-static int vec_static_blob_entriesEof(sqlite3_vtab_cursor *cur) {
- vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
- vec_static_blob_entries_vtab *p =
- (vec_static_blob_entries_vtab *)pCur->base.pVtab;
- switch (pCur->query_plan) {
- case VEC_SBE__QUERYPLAN_FULLSCAN: {
- return (size_t)pCur->iRowid >= p->blob->nvectors;
- }
- case VEC_SBE__QUERYPLAN_KNN: {
- return pCur->knn_data->current_idx >= pCur->knn_data->k;
- }
- }
- return SQLITE_ERROR;
-}
-
-static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur,
- sqlite3_context *context, int i) {
- vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
- vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)cur->pVtab;
-
- switch (pCur->query_plan) {
- case VEC_SBE__QUERYPLAN_FULLSCAN: {
- switch (i) {
- case VEC_STATIC_BLOB_ENTRIES_VECTOR:
-
- sqlite3_result_blob(
- context,
- ((unsigned char *)p->blob->p) +
- (pCur->iRowid * p->blob->dimensions * sizeof(float)),
- p->blob->dimensions * sizeof(float), SQLITE_TRANSIENT);
- sqlite3_result_subtype(context, p->blob->element_type);
- break;
- }
- return SQLITE_OK;
- }
- case VEC_SBE__QUERYPLAN_KNN: {
- switch (i) {
- case VEC_STATIC_BLOB_ENTRIES_VECTOR: {
- i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx];
- sqlite3_result_blob(context,
- ((unsigned char *)p->blob->p) +
- (rowid * p->blob->dimensions * sizeof(float)),
- p->blob->dimensions * sizeof(float),
- SQLITE_TRANSIENT);
- sqlite3_result_subtype(context, p->blob->element_type);
- break;
- }
- }
- return SQLITE_OK;
- }
- }
- return SQLITE_ERROR;
-}
-
-static sqlite3_module vec_static_blob_entriesModule = {
- /* iVersion */ 3,
- /* xCreate */
- vec_static_blob_entriesCreate, // handle rm?
- // https://github.com/asg017/sqlite-vec/issues/55
- /* xConnect */ vec_static_blob_entriesConnect,
- /* xBestIndex */ vec_static_blob_entriesBestIndex,
- /* xDisconnect */ vec_static_blob_entriesDisconnect,
- /* xDestroy */ vec_static_blob_entriesDisconnect,
- /* xOpen */ vec_static_blob_entriesOpen,
- /* xClose */ vec_static_blob_entriesClose,
- /* xFilter */ vec_static_blob_entriesFilter,
- /* xNext */ vec_static_blob_entriesNext,
- /* xEof */ vec_static_blob_entriesEof,
- /* xColumn */ vec_static_blob_entriesColumn,
- /* xRowid */ vec_static_blob_entriesRowid,
- /* xUpdate */ 0,
- /* xBegin */ 0,
- /* xSync */ 0,
- /* xCommit */ 0,
- /* xRollback */ 0,
- /* xFindMethod */ 0,
- /* xRename */ 0,
- /* xSavepoint */ 0,
- /* xRelease */ 0,
- /* xRollbackTo */ 0,
- /* xShadowName */ 0,
-#if SQLITE_VERSION_NUMBER >= 3044000
- /* xIntegrity */ 0
-#endif
-};
-#pragma endregion
#ifdef SQLITE_VEC_ENABLE_AVX
#define SQLITE_VEC_DEBUG_BUILD_AVX "avx"
@@ -10139,55 +9091,4 @@ SQLITE_VEC_API int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg,
return SQLITE_OK;
}
-#ifndef SQLITE_VEC_OMIT_FS
-SQLITE_VEC_API int sqlite3_vec_numpy_init(sqlite3 *db, char **pzErrMsg,
- const sqlite3_api_routines *pApi) {
- UNUSED_PARAMETER(pzErrMsg);
-#ifndef SQLITE_CORE
- SQLITE_EXTENSION_INIT2(pApi);
-#endif
- int rc = SQLITE_OK;
- rc = sqlite3_create_function_v2(db, "vec_npy_file", 1, SQLITE_RESULT_SUBTYPE,
- NULL, vec_npy_file, NULL, NULL, NULL);
- if(rc != SQLITE_OK) {
- return rc;
- }
- rc = sqlite3_create_module_v2(db, "vec_npy_each", &vec_npy_eachModule, NULL, NULL);
- return rc;
-}
-#endif
-SQLITE_VEC_API int
-sqlite3_vec_static_blobs_init(sqlite3 *db, char **pzErrMsg,
- const sqlite3_api_routines *pApi) {
- UNUSED_PARAMETER(pzErrMsg);
-#ifndef SQLITE_CORE
- SQLITE_EXTENSION_INIT2(pApi);
-#endif
-
- int rc = SQLITE_OK;
- vec_static_blob_data *static_blob_data;
- static_blob_data = sqlite3_malloc(sizeof(*static_blob_data));
- if (!static_blob_data) {
- return SQLITE_NOMEM;
- }
- memset(static_blob_data, 0, sizeof(*static_blob_data));
-
- rc = sqlite3_create_function_v2(
- db, "vec_static_blob_from_raw", 4,
- DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL,
- vec_static_blob_from_raw, NULL, NULL, NULL);
- if (rc != SQLITE_OK)
- return rc;
-
- rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule,
- static_blob_data, sqlite3_free);
- if (rc != SQLITE_OK)
- return rc;
- rc = sqlite3_create_module_v2(db, "vec_static_blob_entries",
- &vec_static_blob_entriesModule,
- static_blob_data, NULL);
- if (rc != SQLITE_OK)
- return rc;
- return rc;
-}
diff --git a/tests/correctness/test-correctness.py b/tests/correctness/test-correctness.py
index cb01f8f..9ed0319 100644
--- a/tests/correctness/test-correctness.py
+++ b/tests/correctness/test-correctness.py
@@ -48,7 +48,6 @@ import json
db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
db.load_extension("../../dist/vec0")
-db.execute("select load_extension('../../dist/vec0', 'sqlite3_vec_fs_read_init')")
db.enable_load_extension(False)
results = db.execute(
@@ -75,17 +74,21 @@ print(b)
db.execute('PRAGMA page_size=16384')
-print("Loading into sqlite-vec vec0 table...")
-t0 = time.time()
-db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)")
-db.execute('insert into v select rowid, vector from vec_npy_each(vec_npy_file("dbpedia_openai_3_large_00.npy"))')
-print(time.time() - t0)
-
print("loading numpy array...")
t0 = time.time()
base = np.load('dbpedia_openai_3_large_00.npy')
print(time.time() - t0)
+print("Loading into sqlite-vec vec0 table...")
+t0 = time.time()
+db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)")
+with db:
+ db.executemany(
+ "insert into v(rowid, a) values (?, ?)",
+ [(i, row.tobytes()) for i, row in enumerate(base)],
+ )
+print(time.time() - t0)
+
np.random.seed(1)
queries = base[np.random.choice(base.shape[0], 20, replace=False), :]
diff --git a/tests/fuzz/numpy.c b/tests/fuzz/numpy.c
deleted file mode 100644
index 9e2900b..0000000
--- a/tests/fuzz/numpy.c
+++ /dev/null
@@ -1,37 +0,0 @@
-#include
-#include
-
-#include
-#include
-#include
-#include "sqlite-vec.h"
-#include "sqlite3.h"
-#include
-
-extern int sqlite3_vec_numpy_init(sqlite3 *db, char **pzErrMsg,
- const sqlite3_api_routines *pApi);
-
-int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
- int rc = SQLITE_OK;
- sqlite3 *db;
- sqlite3_stmt *stmt;
-
- rc = sqlite3_open(":memory:", &db);
- assert(rc == SQLITE_OK);
- rc = sqlite3_vec_init(db, NULL, NULL);
- assert(rc == SQLITE_OK);
- rc = sqlite3_vec_numpy_init(db, NULL, NULL);
- assert(rc == SQLITE_OK);
-
- rc = sqlite3_prepare_v2(db, "select * from vec_npy_each(?)", -1, &stmt, NULL);
- assert(rc == SQLITE_OK);
- sqlite3_bind_blob(stmt, 1, data, size, SQLITE_STATIC);
- rc = sqlite3_step(stmt);
- while (rc == SQLITE_ROW) {
- rc = sqlite3_step(stmt);
- }
-
- sqlite3_finalize(stmt);
- sqlite3_close(db);
- return 0;
-}
diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h
index a540849..a02c72a 100644
--- a/tests/sqlite-vec-internal.h
+++ b/tests/sqlite-vec-internal.h
@@ -3,6 +3,7 @@
#include
#include
+#include
int min_idx(
const float *distances,
@@ -62,12 +63,17 @@ enum Vec0DistanceMetrics {
VEC0_DISTANCE_METRIC_L1 = 3,
};
+enum Vec0IndexType {
+ VEC0_INDEX_TYPE_FLAT = 1,
+};
+
struct VectorColumnDefinition {
char *name;
int name_length;
size_t dimensions;
enum VectorElementType element_type;
enum Vec0DistanceMetrics distance_metric;
+ enum Vec0IndexType index_type;
};
int vec0_parse_vector_column(const char *source, int source_length,
diff --git a/tests/test-loadable.py b/tests/test-loadable.py
index bc4eed1..40c6a5e 100644
--- a/tests/test-loadable.py
+++ b/tests/test-loadable.py
@@ -119,151 +119,9 @@ FUNCTIONS = [
MODULES = [
"vec0",
"vec_each",
- # "vec_static_blob_entries",
- # "vec_static_blobs",
]
-def register_numpy(db, name: str, array):
- ptr = array.__array_interface__["data"][0]
- nvectors, dimensions = array.__array_interface__["shape"]
- element_type = array.__array_interface__["typestr"]
-
- assert element_type == "\x9a\x99\x99>",
- },
- {
- "vector": b"fff?\xcd\xccL?",
- },
- ]
- assert execute_all(db, "select rowid, (vector) from z") == [
- {
- "rowid": 0,
- "vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=",
- },
- {
- "rowid": 1,
- "vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>",
- },
- {
- "rowid": 2,
- "vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>",
- },
- {
- "rowid": 3,
- "vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>",
- },
- {
- "rowid": 4,
- "vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?",
- },
- ]
- assert execute_all(
- db,
- "select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
- [np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)],
- ) == [
- {
- "rowid": 2,
- "v": "[0.300000,0.300000,0.300000,0.300000]",
- },
- {
- "rowid": 3,
- "v": "[0.400000,0.400000,0.400000,0.400000]",
- },
- {
- "rowid": 1,
- "v": "[0.200000,0.200000,0.200000,0.200000]",
- },
- ]
- assert execute_all(
- db,
- "select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
- [np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)],
- ) == [
- {
- "rowid": 4,
- "v": "[0.500000,0.500000,0.500000,0.500000]",
- },
- {
- "rowid": 3,
- "v": "[0.400000,0.400000,0.400000,0.400000]",
- },
- {
- "rowid": 2,
- "v": "[0.300000,0.300000,0.300000,0.300000]",
- },
- ]
-
-
def test_limits():
db = connect(EXT_PATH)
with _raises(
@@ -1618,231 +1476,6 @@ def test_vec_each():
vec_each_f32(None)
-import io
-
-
-def to_npy(arr):
- buf = io.BytesIO()
- np.save(buf, arr)
- buf.seek(0)
- return buf.read()
-
-
-def test_vec_npy_each():
- db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init")
- vec_npy_each = lambda *args: execute_all(
- db, "select rowid, * from vec_npy_each(?)", args
- )
- assert vec_npy_each(to_npy(np.array([1.1, 2.2, 3.3], dtype=np.float32))) == [
- {
- "rowid": 0,
- "vector": _f32([1.1, 2.2, 3.3]),
- },
- ]
- assert vec_npy_each(to_npy(np.array([[1.1, 2.2, 3.3]], dtype=np.float32))) == [
- {
- "rowid": 0,
- "vector": _f32([1.1, 2.2, 3.3]),
- },
- ]
- assert vec_npy_each(
- to_npy(np.array([[1.1, 2.2, 3.3], [9.9, 8.8, 7.7]], dtype=np.float32))
- ) == [
- {
- "rowid": 0,
- "vector": _f32([1.1, 2.2, 3.3]),
- },
- {
- "rowid": 1,
- "vector": _f32([9.9, 8.8, 7.7]),
- },
- ]
-
- assert vec_npy_each(to_npy(np.array([], dtype=np.float32))) == []
-
-
-def test_vec_npy_each_errors():
- db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init")
- vec_npy_each = lambda *args: execute_all(
- db, "select rowid, * from vec_npy_each(?)", args
- )
-
- full = b"\x93NUMPY\x01\x00v\x00{'descr': ' 8 bits per byte * 64 bytes = 512
+ for (int i = 0; i < 128; i += 2) {
+ a[i] = 0xFF;
+ }
+ d = _test_distance_hamming(a, b, 1024);
+ assert(d == 512.0f);
+ }
+
printf(" All distance_hamming tests passed.\n");
}
diff --git a/tmp-static.py b/tmp-static.py
deleted file mode 100644
index a3b5f37..0000000
--- a/tmp-static.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import sqlite3
-import numpy as np
-
-db = sqlite3.connect(":memory:")
-
-db.enable_load_extension(True)
-db.load_extension("./dist/vec0")
-db.execute("select load_extension('./dist/vec0', 'sqlite3_vec_raw_init')")
-db.enable_load_extension(False)
-
-x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)
-y = np.array([[0.2, 0.3], [0.9, 0.8], [0.6, 0.5]], dtype=np.float32)
-z = np.array(
- [
- [0.1, 0.1, 0.1, 0.1],
- [0.2, 0.2, 0.2, 0.2],
- [0.3, 0.3, 0.3, 0.3],
- [0.4, 0.4, 0.4, 0.4],
- [0.5, 0.5, 0.5, 0.5],
- ],
- dtype=np.float32,
-)
-
-
-def register_np(array, name):
- ptr = array.__array_interface__["data"][0]
- nvectors, dimensions = array.__array_interface__["shape"]
- element_type = array.__array_interface__["typestr"]
-
- assert element_type == "