commit 4c8ad629e03eb15833f257344e7c6116f2cf9b9b Author: Alex Garcia Date: Sat Apr 20 13:38:58 2024 -0700 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..81db11e --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +/target +.vscode +sift/ +*.tar.gz +*.db +*.bin +*.out +venv/ + +vendor/ +dist/ + +*.pyc +*.db-journal +*.svg + +alexandria/ +openai/ +examples/supabase-dbpedia +examples/ann-filtering +examples/dbpedia-openai +examples/imdb + +sqlite-vec.h diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4b8d4a4 --- /dev/null +++ b/Makefile @@ -0,0 +1,296 @@ + +COMMIT=$(shell git rev-parse HEAD) +VERSION=$(shell cat VERSION) +DATE=$(shell date +'%FT%TZ%z') + + +ifeq ($(shell uname -s),Darwin) +CONFIG_DARWIN=y +else ifeq ($(OS),Windows_NT) +CONFIG_WINDOWS=y +else +CONFIG_LINUX=y +endif + +ifdef CONFIG_DARWIN +LOADABLE_EXTENSION=dylib +endif + +ifdef CONFIG_LINUX +LOADABLE_EXTENSION=so +endif + +ifdef CONFIG_WINDOWS +LOADABLE_EXTENSION=dll +endif + + +ifdef python +PYTHON=$(python) +else +PYTHON=python3 +endif + +ifndef OMIT_SIMD + ifeq ($(shell uname -sm),Darwin x86_64) + CFLAGS += -mavx -DSQLITE_VEC_ENABLE_AVX + endif + ifeq ($(shell uname -sm),Darwin arm64) + CFLAGS += -mcpu=apple-m1 -DSQLITE_VEC_ENABLE_NEON + endif +endif + +ifdef IS_MACOS_ARM +RENAME_WHEELS_ARGS=--is-macos-arm +else +RENAME_WHEELS_ARGS= +endif + +prefix=dist +$(prefix): + mkdir -p $(prefix) + +TARGET_LOADABLE=$(prefix)/vec0.$(LOADABLE_EXTENSION) +TARGET_STATIC=$(prefix)/libsqlite_vec0.a +TARGET_STATIC_H=$(prefix)/sqlite-vec.h +TARGET_CLI=$(prefix)/sqlite3 + +loadable: $(TARGET_LOADABLE) +static: $(TARGET_STATIC) +cli: $(TARGET_CLI) + +all: loadable static cli + +$(TARGET_LOADABLE): sqlite-vec.c sqlite-vec.h $(prefix) + gcc \ + -fPIC -shared \ + -Wall -Wextra \ + -Ivendor/ \ + -O3 \ + $(CFLAGS) \ + $< -o $@ + +$(TARGET_STATIC): sqlite-vec.c sqlite-vec.h $(prefix) + gcc -Ivendor/sqlite -Ivendor/vec $(CFLAGS) -DSQLITE_CORE \ + -O3 -c $< -o $(prefix)/.objs/vec.o + ar rcs $@ $(prefix)/.objs/vec.o + +$(TARGET_STATIC_H): sqlite-vec.h $(prefix) + cp $< $@ + + +OBJS_DIR=$(prefix)/.objs +LIBS_DIR=$(prefix)/.libs +BUILD_DIR=$(prefix)/.build + +$(OBJS_DIR): $(prefix) + mkdir -p $@ + +$(LIBS_DIR): $(prefix) + mkdir -p $@ + +$(BUILD_DIR): $(prefix) + mkdir -p $@ + +$(OBJS_DIR)/sqlite3.o: vendor/sqlite3.c $(OBJS_DIR) + gcc -c -g3 -O3 -DSQLITE_EXTRA_INIT=core_init -DSQLITE_CORE -DSQLITE_ENABLE_STMT_SCANSTATUS -DSQLITE_ENABLE_BYTECODE_VTAB -DSQLITE_ENABLE_EXPLAIN_COMMENTS -I./vendor $< -o $@ + +$(LIBS_DIR)/sqlite3.a: $(OBJS_DIR)/sqlite3.o $(LIBS_DIR) + ar rcs $@ $< + +$(BUILD_DIR)/shell-new.c: vendor/shell.c $(BUILD_DIR) + sed 's/\/\*extra-version-info\*\//EXTRA_TODO/g' $< > $@ + +$(OBJS_DIR)/shell.o: $(BUILD_DIR)/shell-new.c $(OBJS_DIR) + gcc -c -g3 -O3 \ + -DHAVE_EDITLINE=1 -I./vendor \ + -DSQLITE_ENABLE_STMT_SCANSTATUS -DSQLITE_ENABLE_BYTECODE_VTAB -DSQLITE_ENABLE_EXPLAIN_COMMENTS \ + -DEXTRA_TODO="\"CUSTOM BUILD: sqlite-vec\n\"" \ + $< -o $@ + +$(LIBS_DIR)/shell.a: $(OBJS_DIR)/shell.o $(LIBS_DIR) + ar rcs $@ $< + +$(OBJS_DIR)/sqlite-vec.o: sqlite-vec.c $(OBJS_DIR) + gcc -c -g3 -I./vendor $(CFLAGS) $< -o $@ + +$(LIBS_DIR)/sqlite-vec.a: $(OBJS_DIR)/sqlite-vec.o $(LIBS_DIR) + ar rcs $@ $< + +$(TARGET_CLI): $(LIBS_DIR)/sqlite-vec.a $(LIBS_DIR)/shell.a $(LIBS_DIR)/sqlite3.a examples/sqlite3-cli/core_init.c $(prefix) + gcc -g3 \ + -Ivendor/sqlite -I./ \ + -DSQLITE_CORE \ + -DSQLITE_THREADSAFE=0 -DSQLITE_ENABLE_FTS4 \ + -DSQLITE_ENABLE_STMT_SCANSTATUS -DSQLITE_ENABLE_BYTECODE_VTAB -DSQLITE_ENABLE_EXPLAIN_COMMENTS \ + -DSQLITE_EXTRA_INIT=core_init \ + $(CFLAGS) \ + -lreadline -DHAVE_EDITLINE=1 \ + -ldl -lm -lreadline \ + $(LIBS_DIR)/shell.a $(LIBS_DIR)/sqlite3.a $(LIBS_DIR)/sqlite-vec.a examples/sqlite3-cli/core_init.c -o $@ + + +sqlite-vec.h: sqlite-vec.h.tmpl VERSION + VERSION=$(shell cat VERSION) \ + DATE=$(shell date -r VERSION +'%FT%TZ%z') \ + SOURCE=$(shell git log -n 1 --pretty=format:%H -- VERSION) \ + envsubst < $< > $@ + +clean: + rm -rf dist + + +FORMAT_FILES=sqlite-vec.h sqlite-vec.c +format: $(FORMAT_FILES) + clang-format -i $(FORMAT_FILES) + black tests/test-loadable.py + +lint: SHELL:=/bin/bash +lint: + diff -u <(cat $(FORMAT_FILES)) <(clang-format $(FORMAT_FILES)) + +test: + sqlite3 :memory: '.read test.sql' + +.PHONY: version loadable static test clean gh-release \ + ruby + +publish-release: + ./scripts/publish_release.sh + +TARGET_WHEELS=$(prefix)/wheels +INTERMEDIATE_PYPACKAGE_EXTENSION=bindings/python/sqlite_vec/ + +$(TARGET_WHEELS): $(prefix) + mkdir -p $(TARGET_WHEELS) + +bindings/ruby/lib/version.rb: bindings/ruby/lib/version.rb.tmpl VERSION + VERSION=$(VERSION) envsubst < $< > $@ + +bindings/python/sqlite_vec/version.py: bindings/python/sqlite_vec/version.py.tmpl VERSION + VERSION=$(VERSION) envsubst < $< > $@ + echo "✅ generated $@" + +bindings/datasette/datasette_sqlite_vec/version.py: bindings/datasette/datasette_sqlite_vec/version.py.tmpl VERSION + VERSION=$(VERSION) envsubst < $< > $@ + echo "✅ generated $@" + +python: $(TARGET_WHEELS) $(TARGET_LOADABLE) bindings/python/setup.py bindings/python/sqlite_vec/__init__.py scripts/rename-wheels.py + cp $(TARGET_LOADABLE) $(INTERMEDIATE_PYPACKAGE_EXTENSION) + rm $(TARGET_WHEELS)/*.wheel || true + pip3 wheel bindings/python/ -w $(TARGET_WHEELS) + python3 scripts/rename-wheels.py $(TARGET_WHEELS) $(RENAME_WHEELS_ARGS) + echo "✅ generated python wheel" + +datasette: $(TARGET_WHEELS) bindings/datasette/setup.py bindings/datasette/datasette_sqlite_vec/__init__.py + rm $(TARGET_WHEELS)/datasette* || true + pip3 wheel bindings/datasette/ --no-deps -w $(TARGET_WHEELS) + +bindings/sqlite-utils/pyproject.toml: bindings/sqlite-utils/pyproject.toml.tmpl VERSION + VERSION=$(VERSION) envsubst < $< > $@ + echo "✅ generated $@" + +bindings/sqlite-utils/sqlite_utils_sqlite_vec/version.py: bindings/sqlite-utils/sqlite_utils_sqlite_vec/version.py.tmpl VERSION + VERSION=$(VERSION) envsubst < $< > $@ + echo "✅ generated $@" + +sqlite-utils: $(TARGET_WHEELS) bindings/sqlite-utils/pyproject.toml bindings/sqlite-utils/sqlite_utils_sqlite_vec/version.py + python3 -m build bindings/sqlite-utils -w -o $(TARGET_WHEELS) + +node: VERSION bindings/node/platform-package.README.md.tmpl bindings/node/platform-package.package.json.tmpl bindings/node/sqlite-vec/package.json.tmpl scripts/node_generate_platform_packages.sh + scripts/node_generate_platform_packages.sh + +deno: VERSION bindings/deno/deno.json.tmpl + scripts/deno_generate_package.sh + + +version: + make bindings/ruby/lib/version.rb + make bindings/python/sqlite_vec/version.py + make bindings/datasette/datasette_sqlite_vec/version.py + make bindings/datasette/datasette_sqlite_vec/version.py + make bindings/sqlite-utils/pyproject.toml bindings/sqlite-utils/sqlite_utils_sqlite_vec/version.py + make node + make deno + +test-loadable: loadable + $(PYTHON) -m pytest -vv -s tests/test-loadable.py + +test-loadable-snapshot-update: loadable + $(PYTHON) -m pytest -vv tests/test-loadable.py --snapshot-update + +test-loadable-watch: + watchexec -w sqlite-vec.c -w tests/test-loadable.py -w Makefile --clear -- make test-loadable + + + +# ███████████████████████████████ WASM SECTION ███████████████████████████████ + +WASM_DIR=$(prefix)/.wasm + +$(WASM_DIR): $(prefix) + mkdir -p $@ + +SQLITE_WASM_VERSION=3450300 +SQLITE_WASM_YEAR=2024 +SQLITE_WASM_SRCZIP=$(BUILD_DIR)/sqlite-src.zip +SQLITE_WASM_COMPILED_SQLITE3C=$(BUILD_DIR)/sqlite-src-$(SQLITE_WASM_VERSION)/sqlite3.c +SQLITE_WASM_COMPILED_MJS=$(BUILD_DIR)/sqlite-src-$(SQLITE_WASM_VERSION)/ext/wasm/jswasm/sqlite3.mjs +SQLITE_WASM_COMPILED_WASM=$(BUILD_DIR)/sqlite-src-$(SQLITE_WASM_VERSION)/ext/wasm/jswasm/sqlite3.wasm + +TARGET_WASM_LIB=$(WASM_DIR)/libsqlite_vec.wasm.a +TARGET_WASM_MJS=$(WASM_DIR)/sqlite3.mjs +TARGET_WASM_WASM=$(WASM_DIR)/sqlite3.wasm +TARGET_WASM=$(TARGET_WASM_MJS) $(TARGET_WASM_WASM) + +$(SQLITE_WASM_SRCZIP): $(BUILD_DIR) + curl -o $@ https://www.sqlite.org/$(SQLITE_WASM_YEAR)/sqlite-src-$(SQLITE_WASM_VERSION).zip + +$(SQLITE_WASM_COMPILED_SQLITE3C): $(SQLITE_WASM_SRCZIP) $(BUILD_DIR) + unzip -q -o $< -d $(BUILD_DIR) + (cd $(BUILD_DIR)/sqlite-src-$(SQLITE_WASM_VERSION)/ && ./configure --enable-all && make sqlite3.c) + +$(TARGET_WASM_LIB): examples/wasm/wasm.c sqlite-vec.c $(BUILD_DIR) $(WASM_DIR) + emcc -O3 -I./ -Ivendor -DSQLITE_CORE -c examples/wasm/wasm.c -o $(BUILD_DIR)/wasm.wasm.o + emcc -O3 -I./ -Ivendor -DSQLITE_CORE -c sqlite-vec.c -o $(BUILD_DIR)/sqlite-vec.wasm.o + emar rcs $@ $(BUILD_DIR)/wasm.wasm.o $(BUILD_DIR)/sqlite-vec.wasm.o + +$(SQLITE_WASM_COMPILED_MJS) $(SQLITE_WASM_COMPILED_WASM): $(SQLITE_WASM_COMPILED_SQLITE3C) $(TARGET_WASM_LIB) + (cd $(BUILD_DIR)/sqlite-src-$(SQLITE_WASM_VERSION)/ext/wasm && \ + make sqlite3_wasm_extra_init.c=../../../../.wasm/libsqlite_vec.wasm.a "emcc.flags=-s EXTRA_EXPORTED_RUNTIME_METHODS=['ENV'] -s FETCH") + +$(TARGET_WASM_MJS): $(SQLITE_WASM_COMPILED_MJS) + cp $< $@ + +$(TARGET_WASM_WASM): $(SQLITE_WASM_COMPILED_WASM) + cp $< $@ + +wasm: $(TARGET_WASM) + +# ███████████████████████████████ END WASM ███████████████████████████████ + + +# ███████████████████████████████ SITE SECTION ███████████████████████████████ + +WASM_TOOLKIT_NPM_TARGZ=$(BUILD_DIR)/sqlite-wasm-toolkit-npm.tar.gz + +SITE_DIR=$(prefix)/.site +TARGET_SITE=$(prefix)/.site/index.html + +$(WASM_TOOLKIT_NPM_TARGZ): + curl -o $@ -q https://registry.npmjs.org/@alex.garcia/sqlite-wasm-toolkit/-/sqlite-wasm-toolkit-0.0.1-alpha.7.tgz + +$(SITE_DIR)/slim.js $(SITE_DIR)/slim.css: $(WASM_TOOLKIT_NPM_TARGZ) + tar -xvzf $< -C $(SITE_DIR) --strip-components=2 package/dist/slim.js package/dist/slim.css + + +$(SITE_DIR): + mkdir -p $(SITE_DIR) + +# $(TARGET_WASM_MJS) $(TARGET_WASM_WASM) +$(TARGET_SITE): site/index.html $(SITE_DIR)/slim.js $(SITE_DIR)/slim.css + cp $(TARGET_WASM_MJS) $(SITE_DIR) + cp $(TARGET_WASM_WASM) $(SITE_DIR) + cp $< $@ +site: $(TARGET_SITE) +# ███████████████████████████████ END SITE ███████████████████████████████ diff --git a/README.md b/README.md new file mode 100644 index 0000000..93f5162 --- /dev/null +++ b/README.md @@ -0,0 +1,73 @@ +# `sqlite-vec` + +An extremely small, "fast enough" vector search SQLite extension that runs +anywhere! A successor to [sqlite-vss](https://github.com/asg017/sqlite-vss) + +> [!IMPORTANT] +> *`sqlite-vec` is a work-in-progress and not ready for general usage! I plan to launch a "beta" version in the next month or so. Watch this repo for updates.* + +- Store and query float, int8, and binary vectors in `vec0` virtual tables +- Pre-filter vectors with `rowid IN (...)` subqueries +- Written in pure C, no dependencies, + runs anywhere SQLite runs (Linux/MacOS/Windows, in the browser with WASM, + Raspberry Pis, etc.) + +## Sample usage + +```sql +.load ./vec0 + +create virtual table vec_examples using vec0( + sample_embedding float[8] +); + +-- vectors can be provided as JSON or in a compact binary format +insert into vec_examples(rowid, sample_embedding) + values + (1, '[-0.200, 0.250, 0.341, -0.211, 0.645, 0.935, -0.316, -0.924]'), + (2, '[0.443, -0.501, 0.355, -0.771, 0.707, -0.708, -0.185, 0.362]'), + (3, '[0.716, -0.927, 0.134, 0.052, -0.669, 0.793, -0.634, -0.162]'), + (4, '[-0.710, 0.330, 0.656, 0.041, -0.990, 0.726, 0.385, -0.958]'); + + +-- KNN style query goes brrrr +select + rowid, + distance +from vec_examples +where sample_embedding match '[0.890, 0.544, 0.825, 0.961, 0.358, 0.0196, 0.521, 0.175]' +order by distance +limit 2; +/* +┌───────┬──────────────────┐ +│ rowid │ distance │ +├───────┼──────────────────┤ +│ 2 │ 2.38687372207642 │ +│ 1 │ 2.38978505134583 │ +└───────┴──────────────────┘ +*/ +``` + +## Roadmap + +Not currently implemented, but planned in the future (after initial beta version): + +- Approximate nearest neighbors search (IVF and HNSW) +- Metadata filtering + custom internal partitioning +- More vector types (float16, int16, sparse, etc.) and distance functions + +Additionally, there will be pre-compiled and pre-packaged packages of `sqlite-vec` for the following platforms: + +- `pip` for Python +- `npm` for Node.js / Deno / Bun +- `gem` for Ruby +- `cargo` for Rust +- A single `.c` and `.h` amalgammation for C/C++ +- Go module for Golang (requires CGO) +- Datasette and sqlite-utils plugins +- Pre-compiled loadable extensions on Github releases + + +## Support + +Is your company interested in sponsoring `sqlite-vec` development? Send me an email to get more info: https://alexgarcia.xyz diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..68c2f72 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.0.1-alpha.0 \ No newline at end of file diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/exhaustive-memory/README.md b/benchmarks/exhaustive-memory/README.md new file mode 100644 index 0000000..5336d30 --- /dev/null +++ b/benchmarks/exhaustive-memory/README.md @@ -0,0 +1,17 @@ +``` +python3 bench/bench.py \ + -n "sift1m" \ + -i sift/sift_base.fvecs \ + -q sift/sift_query.fvecs \ + --sample 10000 --qsample 100 \ + -k 10 +``` + +``` +python3 bench/bench.py \ + -n "sift1m" \ + -i sift/sift_base.fvecs \ + -q sift/sift_query.fvecs \ + --sample 10000 --qsample 100 \ + -k 10 +``` diff --git a/benchmarks/exhaustive-memory/bench.py b/benchmarks/exhaustive-memory/bench.py new file mode 100644 index 0000000..ffa3443 --- /dev/null +++ b/benchmarks/exhaustive-memory/bench.py @@ -0,0 +1,403 @@ +import numpy as np +import numpy.typing as npt +import time +import hnswlib +import sqlite3 +import faiss +import lancedb +import pandas as pd + +# import chromadb +from usearch.index import Index, search, MetricKind + +from dataclasses import dataclass + +from typing import List + + +@dataclass +class BenchResult: + tool: str + build_time_ms: float + query_times_ms: List[float] + + +def duration(seconds: float): + ms = seconds * 1000 + return f"{int(ms)}ms" + + +def cosine_similarity( + vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32], do_norm: bool = True +) -> npt.NDArray[np.float32]: + sim = vec @ mat.T + if do_norm: + sim /= np.linalg.norm(vec) * np.linalg.norm(mat, axis=1) + return sim + + +def topk( + vec: npt.NDArray[np.float32], + mat: npt.NDArray[np.float32], + k: int = 5, + do_norm: bool = True, +) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.float32]]: + sim = cosine_similarity(vec, mat, do_norm=do_norm) + # Rather than sorting all similarities and taking the top K, it's faster to + # argpartition and then just sort the top K. + # The difference is O(N logN) vs O(N + k logk) + indices = np.argpartition(-sim, kth=k)[:k] + top_indices = np.argsort(-sim[indices]) + return indices[top_indices], sim[top_indices] + + +def ivecs_read(fname): + a = np.fromfile(fname, dtype="int32") + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + + +def fvecs_read(fname): + return ivecs_read(fname).view("float32") + + +def bench_hnsw(base, query): + t0 = time.time() + p = hnswlib.Index(space="ip", dim=128) # possible options are l2, cosine or ip + + # NOTE: Use default settings from the README. + print("buildings hnsw") + p.init_index(max_elements=base.shape[0], ef_construction=200, M=16) + ids = np.arange(base.shape[0]) + p.add_items(base, ids) + p.set_ef(50) + + print("build time", time.time() - t0) + + results = [] + times = [] + t = time.time() + for idx, q in enumerate(query): + t0 = time.time() + result = p.knn_query(q, k=5) + if idx < 5: + print(result[0]) + results.append(result) + times.append(time.time() - t0) + print(time.time() - t) + print("hnsw avg", np.mean(times)) + return results + + +def bench_hnsw_bf(base, query, k) -> BenchResult: + print("hnswlib-bf") + dimensions = base.shape[1] + t0 = time.time() + p = hnswlib.BFIndex(space="l2", dim=dimensions) + + p.init_index(max_elements=base.shape[0]) + ids = np.arange(base.shape[0]) + p.add_items(base, ids) + + build_time = time.time() - t0 + + results = [] + times = [] + t = time.time() + for idx, q in enumerate(query): + t0 = time.time() + result = p.knn_query(q, k=k) + results.append(result) + times.append(time.time() - t0) + return BenchResult("hnswlib-bf", build_time, times) + + +def bench_numpy(base, query, k) -> BenchResult: + print("numpy") + times = [] + results = [] + for idx, q in enumerate(query): + t0 = time.time() + result = topk(q, base, k=k) + results.append(result) + times.append(time.time() - t0) + return BenchResult("numpy", 0, times) + + +def bench_sqlite_vec(base, query, page_size, chunk_size, k) -> BenchResult: + dimensions = base.shape[1] + print(f"sqlite-vec {page_size} {chunk_size}") + + db = sqlite3.connect(":memory:") + db.execute(f"PRAGMA page_size = {page_size}") + db.enable_load_extension(True) + db.load_extension("./dist/vec0") + db.execute( + f""" + create virtual table vec_sift1m using vec0( + chunk_size={chunk_size}, + vector float[{dimensions}] + ) + """ + ) + + t = time.time() + with db: + db.executemany( + "insert into vec_sift1m(vector) values (?)", + list(map(lambda x: [x.tobytes()], base)), + ) + build_time = time.time() - t + times = [] + results = [] + for ( + idx, + q, + ) in enumerate(query): + t0 = time.time() + result = db.execute( + """ + select + rowid, + distance + from vec_sift1m + where vector match ? + and k = ? + order by distance + """, + [q.tobytes(), k], + ).fetchall() + times.append(time.time() - t0) + return BenchResult(f"sqlite-vec vec0 ({page_size}|{chunk_size})", build_time, times) + + +def bench_sqlite_normal(base, query, page_size, k) -> BenchResult: + print(f"sqlite-normal") + + db = sqlite3.connect(":memory:") + db.enable_load_extension(True) + db.load_extension("./dist/vec0") + db.execute(f"PRAGMA page_size={page_size}") + db.execute(f"create table sift1m(vector);") + + t = time.time() + with db: + db.executemany( + "insert into sift1m(vector) values (?)", + list(map(lambda x: [x.tobytes()], base)), + ) + build_time = time.time() - t + times = [] + results = [] + t = time.time() + for ( + idx, + q, + ) in enumerate(query): + t0 = time.time() + result = db.execute( + """ + select + rowid, + vec_distance_l2(?, vector) as distance + from sift1m + order by distance + limit ? + """, + [q.tobytes(), k], + ).fetchall() + times.append(time.time() - t0) + return BenchResult(f"sqlite-vec normal ({page_size})", build_time, times) + + +def bench_faiss(base, query, k) -> BenchResult: + dimensions = base.shape[1] + print("faiss") + t = time.time() + index = faiss.IndexFlatL2(dimensions) + index.add(base) + build_time = time.time() - t + times = [] + results = [] + t = time.time() + for idx, q in enumerate(query): + t0 = time.time() + distances, rowids = index.search(x=np.array([q]), k=k) + results.append(rowids) + times.append(time.time() - t0) + print("faiss avg", duration(np.mean(times))) + return BenchResult("faiss", build_time, times) + + +def bench_lancedb(base, query, k) -> BenchResult: + dimensions = base.shape[1] + db = lancedb.connect("a") + data = [{"vector": row.reshape(1, -1)[0]} for row in base] + # Create a DataFrame where each row is a 1D array + df = pd.DataFrame(data=data, columns=["vector"]) + t = time.time() + db.create_table("t", data=df) + build_time = time.time() - t + tbl = db.open_table("t") + times = [] + for q in query: + t0 = time.time() + result = tbl.search(q).limit(k).to_arrow() + times.append(time.time() - t0) + return BenchResult("lancedb", build_time, times) + + +# def bench_chroma(base, query, k): +# chroma_client = chromadb.Client() +# collection = chroma_client.create_collection(name="my_collection") +# +# t = time.time() +# # chroma doesn't allow for more than 41666 vectors to be inserted at once (???) +# i = 0 +# collection.add(embeddings=base, ids=[str(x) for x in range(len(base))]) +# print("chroma build time: ", duration(time.time() - t)) +# times = [] +# for q in query: +# t0 = time.time() +# result = collection.query( +# query_embeddings=[q.tolist()], +# n_results=k, +# ) +# print(result) +# times.append(time.time() - t0) +# print("chroma avg", duration(np.mean(times))) + + +def bench_usearch_npy(base, query, k) -> BenchResult: + times = [] + for q in query: + t0 = time.time() + # result = index.search(q, exact=True) + result = search(base, q, k, MetricKind.L2sq, exact=True) + times.append(time.time() - t0) + return BenchResult("usearch numpy exact=True", 0, times) + + +def bench_usearch_special(base, query, k) -> BenchResult: + dimensions = base.shape[1] + index = Index(ndim=dimensions) + t = time.time() + index.add(np.arange(len(base)), base) + build_time = time.time() - t + + times = [] + for q in query: + t0 = time.time() + result = index.search(q, exact=True) + times.append(time.time() - t0) + return BenchResult("usuearch index exact=True", build_time, times) + + +from rich.console import Console +from rich.table import Table + + +def suite(name, base, query, k): + print(f"Starting benchmark suite: {name} {base.shape}, k={k}") + results = [] + # n = bench_chroma(base[:40000], query, k=k) + # n = bench_usearch_npy(base, query, k=k) + # n = bench_usearch_special(base, query, k=k) + results.append(bench_faiss(base, query, k=k)) + results.append(bench_hnsw_bf(base, query, k=k)) + # n = bench_sqlite_vec(base, query, 4096, 1024, k=k) + # n = bench_sqlite_vec(base, query, 32768, 1024, k=k) + results.append(bench_sqlite_vec(base, query, 32768, 256, k=k)) + # n = bench_sqlite_vec(base, query, 16384, 64, k=k) + # n = bench_sqlite_vec(base, query, 16384, 32, k=k) + results.append(bench_sqlite_normal(base, query, 8192, k=k)) + results.append(bench_lancedb(base, query, k=k)) + results.append(bench_numpy(base, query, k=k)) + # h = bench_hnsw(base, query) + + table = Table( + title=f"{name}: {base.shape[0]:,} {base.shape[1]}-dimension vectors, k={k}" + ) + + table.add_column("Tool") + table.add_column("Build Time (ms)", justify="right") + table.add_column("Query time (ms)", justify="right") + for res in results: + table.add_row( + res.tool, duration(res.build_time_ms), duration(np.mean(res.query_times_ms)) + ) + + console = Console() + console.print(table) + + +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser(description="Benchmark processing script.") + # Required arguments + parser.add_argument("-n", "--name", required=True, help="Name of the benchmark.") + parser.add_argument( + "-i", "--input", required=True, help="Path to input file (.npy)." + ) + parser.add_argument( + "-k", type=int, required=True, help="Parameter k to use in benchmark." + ) + + # Optional arguments + parser.add_argument( + "-q", "--query", required=False, help="Path to query file (.npy)." + ) + parser.add_argument( + "--sample", + type=int, + required=False, + help="Number of entries in base to use. Defaults all", + ) + parser.add_argument( + "--qsample", + type=int, + required=False, + help="Number of queries to use. Defaults all", + ) + + args = parser.parse_args() + return args + + +from pathlib import Path + + +def cli_read_input(input): + input_path = Path(input) + if input_path.suffix == ".fvecs": + return fvecs_read(input_path) + if input_path.suffx == ".npy": + return np.fromfile(input_path, dtype="float32") + raise Exception("unknown filetype", input) + + +def cli_read_query(query, base): + if query is None: + return base[np.random.choice(base.shape[0], 100, replace=False), :] + return cli_read_input(query) + + +def main(): + args = parse_args() + base = cli_read_input(args.input)[: args.sample] + queries = cli_read_query(args.query, base)[: args.qsample] + suite(args.name, base, queries, args.k) + + from sys import argv + + # base = fvecs_read("sift/sift_base.fvecs") # [:100000] + # query = fvecs_read("sift/sift_query.fvecs")[:100] + # print(base.shape) + # k = int(argv[1]) if len(argv) > 1 else 5 + # suite("sift1m", base, query, k) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profiling/build-from-npy.sql b/benchmarks/profiling/build-from-npy.sql new file mode 100644 index 0000000..134df70 --- /dev/null +++ b/benchmarks/profiling/build-from-npy.sql @@ -0,0 +1,17 @@ +.timer on +pragma page_size = 32768; +--pragma page_size = 16384; +--pragma page_size = 16384; +--pragma page_size = 4096; + +create virtual table vec_items using vec0( + embedding float[1536] +); + +-- 65s (limit 1e5), ~615MB on disk +insert into vec_items + select + rowid, + vector + from vec_npy_each(vec_npy_file('examples/dbpedia-openai/data/vectors.npy')) + limit 1e5; diff --git a/benchmarks/profiling/query-k.sql b/benchmarks/profiling/query-k.sql new file mode 100644 index 0000000..55a53e1 --- /dev/null +++ b/benchmarks/profiling/query-k.sql @@ -0,0 +1,31 @@ +.timer on + +select rowid, distance +from vec_items +where embedding match (select embedding from vec_items where rowid = 100) + and k = :k +order by distance; + +select rowid, distance +from vec_items +where embedding match (select embedding from vec_items where rowid = 100) + and k = :k +order by distance; + +select rowid, distance +from vec_items +where embedding match (select embedding from vec_items where rowid = 100) + and k = :k +order by distance; + +select rowid, distance +from vec_items +where embedding match (select embedding from vec_items where rowid = 100) + and k = :k +order by distance; + +select rowid, distance +from vec_items +where embedding match (select embedding from vec_items where rowid = 100) + and k = :k +order by distance; diff --git a/benchmarks/self-params/build.py b/benchmarks/self-params/build.py new file mode 100644 index 0000000..bc6e388 --- /dev/null +++ b/benchmarks/self-params/build.py @@ -0,0 +1,85 @@ +import sqlite3 +import time + + +def connect(path): + db = sqlite3.connect(path) + db.enable_load_extension(True) + db.load_extension("../dist/vec0") + db.execute("select load_extension('../dist/vec0', 'sqlite3_vec_fs_read_init')") + db.enable_load_extension(False) + return db + + +page_sizes = [ # 4096, 8192, + 16384, + 32768, +] +chunk_sizes = [128, 256, 1024, 2048] +types = ["f32", "int8", "bit"] + +SRC = "../examples/dbpedia-openai/data/vectors.npy" + +for page_size in page_sizes: + for chunk_size in chunk_sizes: + for t in types: + print(f"{t} page_size={page_size}, chunk_size={chunk_size}") + + t0 = time.time() + db = connect(f"dbs/test.{page_size}.{chunk_size}.{t}.db") + db.execute(f"pragma page_size = {page_size}") + with db: + db.execute( + f""" + create virtual table vec_items using vec0( + embedding {t}[1536], + chunk_size={chunk_size} + ) + """ + ) + func = "vector" + if t == "int8": + func = "vec_quantize_i8(vector, 'unit')" + if t == "bit": + func = "vec_quantize_binary(vector)" + db.execute( + f""" + insert into vec_items + select rowid, {func} + from vec_npy_each(vec_npy_file(?)) + limit 100000 + """, + [SRC], + ) + elapsed = time.time() - t0 + print(elapsed) + +""" + +# for 100_000 + +page_size=4096, chunk_size=256 +3.5894200801849365 +page_size=4096, chunk_size=1024 +60.70046401023865 +page_size=4096, chunk_size=2048 +201.04426288604736 +page_size=8192, chunk_size=256 +7.034514904022217 +page_size=8192, chunk_size=1024 +9.983598947525024 +page_size=8192, chunk_size=2048 +12.318921089172363 +page_size=16384, chunk_size=256 +4.97080397605896 +page_size=16384, chunk_size=1024 +6.051195859909058 +page_size=16384, chunk_size=2048 +8.492683172225952 +page_size=32768, chunk_size=256 +5.906642198562622 +page_size=32768, chunk_size=1024 +5.876632213592529 +page_size=32768, chunk_size=2048 +5.420510292053223 +""" diff --git a/benchmarks/self-params/knn.py b/benchmarks/self-params/knn.py new file mode 100644 index 0000000..a0f5737 --- /dev/null +++ b/benchmarks/self-params/knn.py @@ -0,0 +1,83 @@ +import sqlite3 +import time +from random import randrange +from statistics import mean + + +def connect(path): + print(path) + db = sqlite3.connect(path) + db.enable_load_extension(True) + db.load_extension("../dist/vec0") + db.execute("select load_extension('../dist/vec0', 'sqlite3_vec_fs_read_init')") + db.enable_load_extension(False) + return db + + +page_sizes = [ # 4096, 8192, + 16384, + 32768, +] +chunk_sizes = [128, 256, 1024, 2048] +types = ["f32", "int8", "bit"] + +types.reverse() + +for t in types: + for page_size in page_sizes: + for chunk_size in chunk_sizes: + print(f"page_size={page_size}, chunk_size={chunk_size}") + + func = "embedding" + if t == "int8": + func = "vec_quantize_i8(embedding, 'unit')" + if t == "bit": + func = "vec_quantize_binary(embedding)" + + times = [] + trials = 20 + db = connect(f"dbs/test.{page_size}.{chunk_size}.{t}.db") + + for trial in range(trials): + t0 = time.time() + results = db.execute( + f""" + select rowid + from vec_items + where embedding match (select {func} from vec_items where rowid = ?) + and k = 10 + order by distance + """, + [randrange(100000)], + ).fetchall() + + times.append(time.time() - t0) + print(mean(times)) + +""" + +page_size=4096, chunk_size=256 +0.2635102152824402 +page_size=4096, chunk_size=1024 +0.2609449863433838 +page_size=4096, chunk_size=2048 +0.275589919090271 +page_size=8192, chunk_size=256 +0.18621582984924318 +page_size=8192, chunk_size=1024 +0.20939643383026124 +page_size=8192, chunk_size=2048 +0.22376316785812378 +page_size=16384, chunk_size=256 +0.16012665033340454 +page_size=16384, chunk_size=1024 +0.18346318006515502 +page_size=16384, chunk_size=2048 +0.18224761486053467 +page_size=32768, chunk_size=256 +0.14202518463134767 +page_size=32768, chunk_size=1024 +0.15340715646743774 +page_size=32768, chunk_size=2048 +0.18018823862075806 +""" diff --git a/benchmarks/self-params/test.py b/benchmarks/self-params/test.py new file mode 100644 index 0000000..9d3f7ab --- /dev/null +++ b/benchmarks/self-params/test.py @@ -0,0 +1,24 @@ +import sqlite3 +import time + + +def connect(path): + db = sqlite3.connect(path) + db.enable_load_extension(True) + db.load_extension("../dist/vec0") + db.execute("select load_extension('../dist/vec0', 'sqlite3_vec_fs_read_init')") + db.enable_load_extension(False) + return db + + +page_sizes = [4096, 8192, 16384, 32768] +chunk_sizes = [256, 1024, 2048] + +for page_size in page_sizes: + for chunk_size in chunk_sizes: + print(f"page_size={page_size}, chunk_size={chunk_size}") + + t0 = time.time() + db = connect(f"dbs/test.{page_size}.{chunk_size}.db") + print(db.execute("pragma page_size").fetchone()[0]) + print(db.execute("select count(*) from vec_items_rowids").fetchone()[0]) diff --git a/examples/sqlite3-cli/README.md b/examples/sqlite3-cli/README.md new file mode 100644 index 0000000..de7c92a --- /dev/null +++ b/examples/sqlite3-cli/README.md @@ -0,0 +1,5 @@ +# `sqlite-vec` statically compiled in the SQLite CLI + +You can compile your own version of the `sqlite3` CLI with `sqlite-vec` builtin. The process is not well documented, but the special `SQLITE_EXTRA_INIT` compile option can be used to "inject" code at initialization time. See the `Makefile` at the root of this project for some more info. + +The `core_init.c` file here demonstrates auto-loading the `sqlite-vec` entrypoints at startup. diff --git a/examples/sqlite3-cli/core_init.c b/examples/sqlite3-cli/core_init.c new file mode 100644 index 0000000..4a5bcfd --- /dev/null +++ b/examples/sqlite3-cli/core_init.c @@ -0,0 +1,8 @@ +#include "sqlite3.h" +#include "sqlite-vec.h" +#include +int core_init(const char *dummy) { + int rc = sqlite3_auto_extension((void *)sqlite3_vec_init); + if(rc != SQLITE_OK) return rc; + return sqlite3_auto_extension((void *)sqlite3_vec_fs_read_init); +} diff --git a/examples/wasm/README.md b/examples/wasm/README.md new file mode 100644 index 0000000..5b3bcfe --- /dev/null +++ b/examples/wasm/README.md @@ -0,0 +1,5 @@ +# `sqlite-vec` statically compiled into WASM builds + +You can compile your own version of SQLite's WASM build with `sqlite-vec` builtin. Dynamically loading SQLite extensions is not supported in the official WASM build yet, but you can statically compile extensions in. It's not well documented, but the `sqlite3_wasm_extra_init` option in the SQLite `ext/wasm` Makefile allows you to inject your own code at initialization time. See the `Makefile` at the room of the project for more info. + +The `wasm.c` file here demonstrates auto-loading the `sqlite-vec` entrypoints at startup. diff --git a/examples/wasm/wasm.c b/examples/wasm/wasm.c new file mode 100644 index 0000000..0f95eda --- /dev/null +++ b/examples/wasm/wasm.c @@ -0,0 +1,6 @@ +#include "sqlite3.h" +#include "sqlite-vec.h" + +int sqlite3_wasm_extra_init(const char *) { + sqlite3_auto_extension((void (*)(void)) sqlite3_vec_init); +} diff --git a/site/index.html b/site/index.html new file mode 100644 index 0000000..ff06031 --- /dev/null +++ b/site/index.html @@ -0,0 +1,23 @@ + + +

sqlite-vec

+
+ +
+ + + + diff --git a/sqlite-vec.c b/sqlite-vec.c new file mode 100644 index 0000000..5438da8 --- /dev/null +++ b/sqlite-vec.c @@ -0,0 +1,4607 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sqlite-vec.h" + +#include "sqlite3ext.h" +SQLITE_EXTENSION_INIT1 + +#ifndef UNUSED_PARAMETER +#define UNUSED_PARAMETER(X) (void)(X) +#endif + +#ifndef todo_assert +#define todo_assert(X) assert(X) +#endif + +#define countof(x) (sizeof(x) / sizeof((x)[0])) + +#define todo(msg) \ + do { \ + fprintf(stderr, "TODO: %s\n", msg); \ + exit(1); \ + } while (0) + +enum VectorElementType { + SQLITE_VEC_ELEMENT_TYPE_FLOAT32 = 223 + 0, + SQLITE_VEC_ELEMENT_TYPE_BIT = 223 + 1, + SQLITE_VEC_ELEMENT_TYPE_INT8 = 223 + 2, +}; + +#ifdef SQLITE_VEC_ENABLE_AVX +#include +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#define PORTABLE_ALIGN64 __attribute__((aligned(64))) + +static float l2_sqr_float_avx(const void *pVect1v, const void *pVect2v, + const void *qty_ptr) { + float *pVect1 = (float *)pVect1v; + float *pVect2 = (float *)pVect2v; + size_t qty = *((size_t *)qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + } + + _mm256_store_ps(TmpRes, sum); + return sqrt(TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + + TmpRes[5] + TmpRes[6] + TmpRes[7]); +} +#endif + +#ifdef SQLITE_VEC_ENABLE_NEON +#include + +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) + +// thx https://github.com/nmslib/hnswlib/pull/299/files +static float l2_sqr_float_neon(const void *pVect1v, const void *pVect2v, + const void *qty_ptr) { + float *pVect1 = (float *)pVect1v; + float *pVect2 = (float *)pVect2v; + size_t qty = *((size_t *)qty_ptr); + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + float32x4_t diff, v1, v2; + float32x4_t sum0 = vdupq_n_f32(0); + float32x4_t sum1 = vdupq_n_f32(0); + float32x4_t sum2 = vdupq_n_f32(0); + float32x4_t sum3 = vdupq_n_f32(0); + + while (pVect1 < pEnd1) { + v1 = vld1q_f32(pVect1); + pVect1 += 4; + v2 = vld1q_f32(pVect2); + pVect2 += 4; + diff = vsubq_f32(v1, v2); + sum0 = vfmaq_f32(sum0, diff, diff); + + v1 = vld1q_f32(pVect1); + pVect1 += 4; + v2 = vld1q_f32(pVect2); + pVect2 += 4; + diff = vsubq_f32(v1, v2); + sum1 = vfmaq_f32(sum1, diff, diff); + + v1 = vld1q_f32(pVect1); + pVect1 += 4; + v2 = vld1q_f32(pVect2); + pVect2 += 4; + diff = vsubq_f32(v1, v2); + sum2 = vfmaq_f32(sum2, diff, diff); + + v1 = vld1q_f32(pVect1); + pVect1 += 4; + v2 = vld1q_f32(pVect2); + pVect2 += 4; + diff = vsubq_f32(v1, v2); + sum3 = vfmaq_f32(sum3, diff, diff); + } + + return sqrt( + vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)))); +} +#endif + +static float l2_sqr_float(const void *pVect1v, const void *pVect2v, + const void *qty_ptr) { + float *pVect1 = (float *)pVect1v; + float *pVect2 = (float *)pVect2v; + size_t qty = *((size_t *)qty_ptr); + + float res = 0; + for (size_t i = 0; i < qty; i++) { + float t = *pVect1 - *pVect2; + pVect1++; + pVect2++; + res += t * t; + } + return sqrt(res); +} + +static float l2_sqr_int8(const void *pA, const void *pB, const void *pD) { + int8_t *a = (int8_t *)pA; + int8_t *b = (int8_t *)pB; + size_t d = *((size_t *)pD); + + float res = 0; + for (size_t i = 0; i < d; i++) { + float t = *a - *b; + a++; + b++; + res += t * t; + } + return sqrt(res); +} + +static float distance_l2_sqr_float(const void *a, const void *b, + const void *d) { +#ifdef SQLITE_VEC_ENABLE_NEON + if (((*(const size_t *)d) % 16 == 0)) { + return l2_sqr_float_neon(a, b, d); + } +#endif +#ifdef SQLITE_VEC_ENABLE_AVX + if (((*(const size_t *)d) % 16 == 0)) { + return l2_sqr_float_avx(a, b, d); + } +#endif + return l2_sqr_float(a, b, d); +} + +static float distance_l2_sqr_int8(const void *a, const void *b, const void *d) { + return l2_sqr_int8(a, b, d); +} + +static float distance_cosine_float(const void *pVect1v, const void *pVect2v, + const void *qty_ptr) { + float *pVect1 = (float *)pVect1v; + float *pVect2 = (float *)pVect2v; + size_t qty = *((size_t *)qty_ptr); + + float dot = 0; + float aMag = 0; + float bMag = 0; + for (size_t i = 0; i < qty; i++) { + dot += *pVect1 * *pVect2; + aMag += *pVect1 * *pVect1; + bMag += *pVect2 * *pVect2; + pVect1++; + pVect2++; + } + return 1 - (dot / (sqrt(aMag) * sqrt(bMag))); +} +static float distance_cosine_int8(const void *pA, const void *pB, + const void *pD) { + int8_t *a = (int8_t *)pA; + int8_t *b = (int8_t *)pB; + size_t d = *((size_t *)pD); + + float dot = 0; + float aMag = 0; + float bMag = 0; + for (size_t i = 0; i < d; i++) { + dot += *a * *b; + aMag += *a * *a; + bMag += *b * *b; + a++; + b++; + } + return 1 - (dot / (sqrt(aMag) * sqrt(bMag))); +} + +// https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34 +static uint8_t hamdist_table[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; + +static float distance_hamming_u8(uint8_t *a, uint8_t *b, size_t n) { + int same = 0; + for (unsigned long i = 0; i < n; i++) { + same += hamdist_table[a[i] ^ b[i]]; + } + return (float)same; +} +static float distance_hamming_u64(uint64_t *a, uint64_t *b, size_t n) { + int same = 0; + for (unsigned long i = 0; i < n; i++) { + same += __builtin_popcountl(a[i] ^ b[i]); + } + return (float)same; +} + +static float distance_hamming(const void *a, const void *b, const void *d) { + size_t dimensions = *((size_t *)d); + todo_assert((dimensions % CHAR_BIT) == 0); + + if ((dimensions % 64) == 0) { + return distance_hamming_u64((uint64_t *)a, (uint64_t *)b, + dimensions / 8 / CHAR_BIT); + } + return distance_hamming_u8((uint8_t *)a, (uint8_t *)b, dimensions / CHAR_BIT); +} + +// from SQLite source: +// https://github.com/sqlite/sqlite/blob/a509a90958ddb234d1785ed7801880ccb18b497e/src/json.c#L153 +static const char jsonIsSpace[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; +#define jsonIsspace(x) (jsonIsSpace[(unsigned char)x]) + +typedef void (*vector_cleanup)(void *p); + +void vector_cleanup_noop(void *_) { UNUSED_PARAMETER(_); } + +#define JSON_SUBTYPE 74 + +struct Array { + size_t element_size; + size_t length; + size_t capacity; + void *z; +}; + +int array_init(struct Array *array, size_t element_size, size_t init_capacity) { + void *z = sqlite3_malloc(element_size * init_capacity); + if (!z) { + return SQLITE_NOMEM; + } + array->element_size = element_size; + array->length = 0; + array->capacity = init_capacity; + array->z = z; + return SQLITE_OK; +} + +int array_append(struct Array *array, const void *element) { + if (array->length == array->capacity) { + size_t new_capacity = array->capacity * 2 + 100; + void *z = sqlite3_realloc64(array->z, array->element_size * new_capacity); + if (z) { + array->capacity = new_capacity; + array->z = z; + } else { + return SQLITE_NOMEM; + } + } + memcpy(&array->z[array->length * array->element_size], element, + array->element_size); + array->length++; + return SQLITE_OK; +} + +void array_cleanup(struct Array *array) { + array->element_size = 0; + array->length = 0; + array->capacity = 0; + sqlite3_free(array->z); + array->z = NULL; +} + +typedef void (*fvec_cleanup)(float *vector); + +void fvec_cleanup_noop(float *_) { UNUSED_PARAMETER(_); } + +static int fvec_from_value(sqlite3_value *value, float **vector, + size_t *dimensions, fvec_cleanup *cleanup, + char **pzErr) { + int value_type = sqlite3_value_type(value); + if (value_type == SQLITE_BLOB) { + const void *blob = sqlite3_value_blob(value); + int bytes = sqlite3_value_bytes(value); + if (bytes == 0) { + *pzErr = sqlite3_mprintf("zero-length vectors are not supported."); + return SQLITE_ERROR; + } + if ((bytes % sizeof(float)) != 0) { + *pzErr = sqlite3_mprintf("invalid float32 vector BLOB length. Must be " + "divisible by %d, found %d", + sizeof(float), bytes); + return SQLITE_ERROR; + } + *vector = (float *)blob; + *dimensions = bytes / sizeof(float); + *cleanup = fvec_cleanup_noop; + return SQLITE_OK; + } + + if (value_type == SQLITE_TEXT) { + const char *source = (const char *)sqlite3_value_text(value); + int source_len = sqlite3_value_bytes(value); + int i = 0; + + struct Array x; + int rc = array_init(&x, sizeof(float), ceil(source_len / 2.0)); + todo_assert(rc == SQLITE_OK); + + // advance leading whitespace to first '[' + while (i < source_len) { + if (jsonIsspace(source[i])) { + i++; + continue; + } + if (source[i] == '[') { + break; + } + + *pzErr = sqlite3_mprintf( + "JSON array parsing error: Input does not start with '['"); + array_cleanup(&x); + return SQLITE_ERROR; + } + if (source[i] != '[') { + *pzErr = sqlite3_mprintf( + "JSON array parsing error: Input does not start with '['"); + array_cleanup(&x); + return SQLITE_ERROR; + } + int offset = i + 1; + + while (offset < source_len) { + char *ptr = (char *)&source[offset]; + char *endptr; + + errno = 0; + double result = strtod(ptr, &endptr); + if ((errno != 0 && result == 0) // some interval error? + || (errno == ERANGE && + (result == HUGE_VAL || result == -HUGE_VAL)) // too big / smalls + ) { + sqlite3_free(x.z); + *pzErr = sqlite3_mprintf("JSON parsing error"); + return SQLITE_ERROR; + } + + if (endptr == ptr) { + if (*ptr != ']') { + sqlite3_free(x.z); + *pzErr = sqlite3_mprintf("JSON parsing error"); + return SQLITE_ERROR; + } + goto done; + } + + float res = (float)result; + array_append(&x, (const void *)&res); + + offset += (endptr - ptr); + while (offset < source_len) { + if (jsonIsspace(source[offset])) { + offset++; + continue; + } + if (source[offset] == ',') { + offset++; + continue; + } // TODO multiple commas in a row without digits? + if (source[offset] == ']') + goto done; + break; + } + } + + done: + + if (x.length > 0) { + *vector = (float *)x.z; + *dimensions = x.length; + *cleanup = (fvec_cleanup)sqlite3_free; + return SQLITE_OK; + } + sqlite3_free(x.z); + *pzErr = sqlite3_mprintf("zero-length vectors are not supported."); + return SQLITE_ERROR; + } + + *pzErr = sqlite3_mprintf( + "Input must have type BLOB (compact format) or TEXT (JSON)"); + return SQLITE_ERROR; +} + +static int bitvec_from_value(sqlite3_value *value, uint8_t **vector, + size_t *dimensions, vector_cleanup *cleanup, + char **pzErr) { + int value_type = sqlite3_value_type(value); + if (value_type == SQLITE_BLOB) { + const void *blob = sqlite3_value_blob(value); + int bytes = sqlite3_value_bytes(value); + if (bytes == 0) { + *pzErr = sqlite3_mprintf("zero-length vectors are not supported."); + return SQLITE_ERROR; + } + *vector = (uint8_t *)blob; + *dimensions = bytes * CHAR_BIT; + *cleanup = vector_cleanup_noop; + return SQLITE_OK; + } + *pzErr = sqlite3_mprintf("Unknown type for bitvector."); + return SQLITE_ERROR; +} + +static int int8_vec_from_value(sqlite3_value *value, int8_t **vector, + size_t *dimensions, vector_cleanup *cleanup, + char **pzErr) { + int value_type = sqlite3_value_type(value); + if (value_type == SQLITE_BLOB) { + const void *blob = sqlite3_value_blob(value); + int bytes = sqlite3_value_bytes(value); + if (bytes == 0) { + *pzErr = sqlite3_mprintf("zero-length vectors are not supported."); + return SQLITE_ERROR; + } + *vector = (int8_t *)blob; + *dimensions = bytes; + *cleanup = vector_cleanup_noop; + return SQLITE_OK; + } + *pzErr = sqlite3_mprintf("Unknown type for int8 vector."); + return SQLITE_ERROR; +} + +/** + * @brief Extract a vector from a sqlite3_value. Can be a float32, int8, or bit + * vector. + * + * @param value: the sqlite3_value to read from. + * @param vector: Output pointer to vector data. + * @param dimensions: Output number of dimensions + * @param dimensions: Output vector element type + * @param cleanup + * @param pzErrorMessage + * @return int SQLITE_OK on success, error code otherwise + */ +int vector_from_value(sqlite3_value *value, void **vector, size_t *dimensions, + enum VectorElementType *element_type, + vector_cleanup *cleanup, char **pzErrorMessage) { + int subtype = sqlite3_value_subtype(value); + if (!subtype || (subtype == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) || + (subtype == JSON_SUBTYPE)) { + int rc = fvec_from_value(value, (float **)vector, dimensions, + (fvec_cleanup *)cleanup, pzErrorMessage); + if (rc == SQLITE_OK) { + *element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32; + } + return rc; + } + + if (subtype == SQLITE_VEC_ELEMENT_TYPE_BIT) { + int rc = bitvec_from_value(value, (uint8_t **)vector, dimensions, cleanup, + pzErrorMessage); + if (rc == SQLITE_OK) { + *element_type = SQLITE_VEC_ELEMENT_TYPE_BIT; + } + return rc; + } + if (subtype == SQLITE_VEC_ELEMENT_TYPE_INT8) { + int rc = int8_vec_from_value(value, (int8_t **)vector, dimensions, cleanup, + pzErrorMessage); + if (rc == SQLITE_OK) { + *element_type = SQLITE_VEC_ELEMENT_TYPE_INT8; + } + return rc; + } + *pzErrorMessage = sqlite3_mprintf("Unknown subtype: %d", subtype); + return SQLITE_ERROR; +} + +char *vector_subtype_name(int subtype) { + switch (subtype) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + return "float32"; + case SQLITE_VEC_ELEMENT_TYPE_INT8: + return "int8"; + case SQLITE_VEC_ELEMENT_TYPE_BIT: + return "bit"; + } + return ""; +} +int ensure_vector_match(sqlite3_value *aValue, sqlite3_value *bValue, void **a, + void **b, enum VectorElementType *element_type, + size_t *dimensions, vector_cleanup *outACleanup, + vector_cleanup *outBCleanup, char **outError) { + int rc; + enum VectorElementType aType, bType; + size_t aDims, bDims; + char *error; + vector_cleanup aCleanup, bCleanup; + + rc = vector_from_value(aValue, a, &aDims, &aType, &aCleanup, &error); + if (rc != SQLITE_OK) { + *outError = sqlite3_mprintf("Error reading 1st vector: %s", error); + sqlite3_free(error); + return SQLITE_ERROR; + } + + rc = vector_from_value(bValue, b, &bDims, &bType, &bCleanup, &error); + if (rc != SQLITE_OK) { + *outError = sqlite3_mprintf("Error reading 2nd vector: %s", error); + sqlite3_free(error); + aCleanup(a); + return SQLITE_ERROR; + } + + if (aType != bType) { + *outError = + sqlite3_mprintf("Vector type mistmatch. First vector has type %s, " + "while the second has type %s.", + vector_subtype_name(aType), vector_subtype_name(bType)); + aCleanup(a); + bCleanup(b); + return SQLITE_ERROR; + } + if (aDims != bDims) { + *outError = sqlite3_mprintf( + "Vector dimension mistmatch. First vector has %ld dimensions, " + "while the second has %ld dimensions.", + aDims, bDims); + aCleanup(a); + bCleanup(b); + return SQLITE_ERROR; + } + *element_type = aType; + *dimensions = aDims; + *outACleanup = aCleanup; + *outBCleanup = bCleanup; + return SQLITE_OK; +} + +int _cmp(const void *a, const void *b) { + return (*(sqlite3_int64 *)a - *(sqlite3_int64 *)b); +} + +struct VecNpyFile { + char *path; + size_t pathLength; +}; +#define SQLITE_VEC_NPY_FILE_NAME "vec0-npy-file" + +static void vec_npy_file(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 1); + char *path = (char *)sqlite3_value_text(argv[0]); + size_t pathLength = sqlite3_value_bytes(argv[0]); + struct VecNpyFile *f = sqlite3_malloc(sizeof(struct VecNpyFile)); + f->path = path; + f->pathLength = pathLength; + sqlite3_result_pointer(context, f, SQLITE_VEC_NPY_FILE_NAME, sqlite3_free); +} + +static void vec_f32(sqlite3_context *context, int argc, sqlite3_value **argv) { + todo_assert(argc == 1); + int rc; + float *vector; + size_t dimensions; + fvec_cleanup cleanup; + char *errmsg; + rc = fvec_from_value(argv[0], &vector, &dimensions, &cleanup, &errmsg); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, errmsg, -1); + sqlite3_free(errmsg); + return; + } + sqlite3_result_blob(context, vector, dimensions * sizeof(float), + SQLITE_TRANSIENT); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); + cleanup(vector); +} +static void vec_bit(sqlite3_context *context, int argc, sqlite3_value **argv) { + todo_assert(argc == 1); + int rc; + uint8_t *vector; + size_t dimensions; + vector_cleanup cleanup; + char *errmsg; + rc = bitvec_from_value(argv[0], &vector, &dimensions, &cleanup, &errmsg); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, errmsg, -1); + sqlite3_free(errmsg); + return; + } + sqlite3_result_blob(context, vector, dimensions / CHAR_BIT, SQLITE_TRANSIENT); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); + cleanup(vector); +} +static void vec_int8(sqlite3_context *context, int argc, sqlite3_value **argv) { + todo_assert(argc == 1); + int rc; + int8_t *vector; + size_t dimensions; + vector_cleanup cleanup; + char *errmsg; + rc = int8_vec_from_value(argv[0], &vector, &dimensions, &cleanup, &errmsg); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, errmsg, -1); + sqlite3_free(errmsg); + return; + } + sqlite3_result_blob(context, vector, dimensions, SQLITE_TRANSIENT); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); + cleanup(vector); +} + +static void vec_length(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 1); + int rc; + void *vector; + size_t dimensions; + vector_cleanup cleanup; + char *errmsg; + enum VectorElementType elementType; + rc = vector_from_value(argv[0], &vector, &dimensions, &elementType, &cleanup, + &errmsg); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, errmsg, -1); + sqlite3_free(errmsg); + return; + } + sqlite3_result_int64(context, dimensions); + cleanup(vector); +} + +static void vec_distance_cosine(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 2); + int rc; + void *a, *b; + size_t dimensions; + vector_cleanup aCleanup, bCleanup; + char *error; + enum VectorElementType elementType; + rc = ensure_vector_match(argv[0], argv[1], &a, &b, &elementType, &dimensions, + &aCleanup, &bCleanup, &error); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, error, -1); + sqlite3_free(error); + return; + } + + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + sqlite3_result_error( + context, "Cannot calculate cosine distance between two bitvectors.", + -1); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + float result = distance_cosine_float(a, b, &dimensions); + sqlite3_result_double(context, result); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + float result = distance_cosine_int8(a, b, &dimensions); + sqlite3_result_double(context, result); + goto finish; + } + } + +finish: + aCleanup(a); + bCleanup(b); + return; +} + +static void vec_distance_l2(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 2); + int rc; + void *a, *b; + size_t dimensions; + vector_cleanup aCleanup, bCleanup; + char *error; + enum VectorElementType elementType; + rc = ensure_vector_match(argv[0], argv[1], &a, &b, &elementType, &dimensions, + &aCleanup, &bCleanup, &error); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, error, -1); + sqlite3_free(error); + return; + } + + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + sqlite3_result_error( + context, "Cannot calculate L2 distance between two bitvectors.", -1); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + float result = distance_l2_sqr_float(a, b, &dimensions); + sqlite3_result_double(context, result); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + float result = distance_l2_sqr_int8(a, b, &dimensions); + sqlite3_result_double(context, result); + goto finish; + } + } + +finish: + aCleanup(a); + bCleanup(b); + return; +} +static void vec_distance_hamming(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 2); + int rc; + void *a, *b; + size_t dimensions; + vector_cleanup aCleanup, bCleanup; + char *error; + enum VectorElementType elementType; + rc = ensure_vector_match(argv[0], argv[1], &a, &b, &elementType, &dimensions, + &aCleanup, &bCleanup, &error); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, error, -1); + sqlite3_free(error); + return; + } + + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + sqlite3_result_double(context, distance_hamming(a, b, &dimensions)); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + sqlite3_result_error( + context, + "Cannot calculate hamming distance between two float32 vectors.", -1); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + sqlite3_result_error( + context, "Cannot calculate hamming distance between two int8 vectors.", + -1); + goto finish; + } + } + +finish: + aCleanup(a); + bCleanup(b); + return; +} + +static void vec_quantize_i8(sqlite3_context *context, int argc, + sqlite3_value **argv) { + float *srcVector; + size_t dimensions; + fvec_cleanup cleanup; + char *err; + int rc = fvec_from_value(argv[0], &srcVector, &dimensions, &cleanup, &err); + assert(rc == SQLITE_OK); + int8_t *out = sqlite3_malloc(dimensions * sizeof(int8_t)); + assert(out); + + if (argc == 2) { + if ((sqlite3_value_type(argv[1]) != SQLITE_TEXT) || + (sqlite3_value_bytes(argv[1]) != strlen("unit")) || + (sqlite3_stricmp((const char *)sqlite3_value_text(argv[1]), "unit") != + 0)) { + sqlite3_result_error(context, + "2nd argument to vec_quantize_i8() must be 'unit', " + "or ranges must be provided.", + -1); + cleanup(srcVector); + sqlite3_free(out); + return; + } + float step = (1.0 - (-1.0)) / 255; + for (size_t i = 0; i < dimensions; i++) { + out[i] = ((srcVector[i] - (-1.0)) / step) - 128; + } + } else if (argc == 3) { + // float * minVector, maxVector; + // size_t d; + // fvec_cleanup minCleanup, maxCleanup; + // int rc = fvec_from_value(argv[1], ) + todo("ranges"); + } + + cleanup(srcVector); + sqlite3_result_blob(context, out, dimensions * sizeof(int8_t), sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); + return; +} + +static void vec_quantize_binary(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 1); + void *vector; + size_t dimensions; + vector_cleanup cleanup; + char *pzError; + enum VectorElementType elementType; + int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType, + &cleanup, &pzError); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, pzError, -1); + sqlite3_free(pzError); + return; + } + + if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + uint8_t *out = sqlite3_malloc(dimensions / CHAR_BIT); + todo_assert(out); + for (size_t i = 0; i < dimensions; i++) { + int res = ((float *)vector)[i] > 0.0; + out[i / 8] |= (res << (i % 8)); + } + sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); + } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { + uint8_t *out = sqlite3_malloc(dimensions / CHAR_BIT); + todo_assert(out); + for (size_t i = 0; i < dimensions; i++) { + int res = ((int8_t *)vector)[i] > 0; + out[i / 8] |= (res << (i % 8)); + } + sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); + } else { + todo("wut"); + } +} + +static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) { + todo_assert(argc == 2); + int rc; + void *a, *b; + size_t dimensions; + vector_cleanup aCleanup, bCleanup; + char *error; + enum VectorElementType elementType; + rc = ensure_vector_match(argv[0], argv[1], &a, &b, &elementType, &dimensions, + &aCleanup, &bCleanup, &error); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, error, -1); + sqlite3_free(error); + return; + } + + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + sqlite3_result_error(context, "Cannot add two bitvectors together.", -1); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + size_t outSize = dimensions * sizeof(float); + float *out = sqlite3_malloc(outSize); + if (!out) { + sqlite3_result_error_nomem(context); + goto finish; + } + for (size_t i = 0; i < dimensions; i++) { + out[i] = ((float *)a)[i] + ((float *)b)[i]; + } + sqlite3_result_blob(context, out, outSize, sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + size_t outSize = dimensions * sizeof(int8_t); + int8_t *out = sqlite3_malloc(outSize); + if (!out) { + sqlite3_result_error_nomem(context); + goto finish; + } + for (size_t i = 0; i < dimensions; i++) { + out[i] = ((int8_t *)a)[i] + ((int8_t *)b)[i]; + } + sqlite3_result_blob(context, out, outSize, sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); + goto finish; + } + } +finish: + aCleanup(a); + bCleanup(b); + return; +} +static void vec_sub(sqlite3_context *context, int argc, sqlite3_value **argv) { + todo_assert(argc == 2); + int rc; + void *a, *b; + size_t dimensions; + vector_cleanup aCleanup, bCleanup; + char *error; + enum VectorElementType elementType; + rc = ensure_vector_match(argv[0], argv[1], &a, &b, &elementType, &dimensions, + &aCleanup, &bCleanup, &error); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, error, -1); + sqlite3_free(error); + return; + } + + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + sqlite3_result_error(context, "Cannot subtract two bitvectors together.", + -1); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + size_t outSize = dimensions * sizeof(float); + float *out = sqlite3_malloc(outSize); + if (!out) { + sqlite3_result_error_nomem(context); + goto finish; + } + for (size_t i = 0; i < dimensions; i++) { + out[i] = ((float *)a)[i] - ((float *)b)[i]; + } + sqlite3_result_blob(context, out, outSize, sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + size_t outSize = dimensions * sizeof(int8_t); + int8_t *out = sqlite3_malloc(outSize); + if (!out) { + sqlite3_result_error_nomem(context); + goto finish; + } + for (size_t i = 0; i < dimensions; i++) { + out[i] = ((int8_t *)a)[i] - ((int8_t *)b)[i]; + } + sqlite3_result_blob(context, out, outSize, sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); + goto finish; + } + } +finish: + aCleanup(a); + bCleanup(b); + return; +} +static void vec_slice(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 3); + + void *vector; + size_t dimensions; + vector_cleanup cleanup; + char *err; + enum VectorElementType elementType; + + int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType, + &cleanup, &err); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, err, -1); + sqlite3_free(err); + return; + } + + int start = sqlite3_value_int(argv[1]); + int end = sqlite3_value_int(argv[2]); + if (start < 0) { + sqlite3_result_error(context, + "slice 'start' index must be a postive number.", -1); + goto done; + } + if (end < 0) { + sqlite3_result_error(context, "slice 'end' index must be a postive number.", + -1); + goto done; + } + if (((size_t)start) > dimensions) { + sqlite3_result_error( + context, "slice 'start' index is greater than the number of dimensions", + -1); + goto done; + } + if (((size_t)end) > dimensions) { + sqlite3_result_error( + context, "slice 'end' index is greater than the number of dimensions", + -1); + goto done; + } + if (start > end) { + sqlite3_result_error(context, + "slice 'start' index is greater than 'end' index", -1); + goto done; + } + // TODO check start == end + size_t n = end - start; + + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + float *out = sqlite3_malloc(n * sizeof(float)); + if (!out) { + sqlite3_result_error_nomem(context); + return; + } + for (size_t i = 0; i < n; i++) { + out[i] = ((float *)vector)[start + i]; + } + sqlite3_result_blob(context, out, n * sizeof(float), sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); + goto done; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + int8_t *out = sqlite3_malloc(n * sizeof(int8_t)); + if (!out) { + sqlite3_result_error_nomem(context); + return; + } + for (size_t i = 0; i < n; i++) { + out[i] = ((int8_t *)vector)[start + i]; + } + sqlite3_result_blob(context, out, n * sizeof(int8_t), sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); + goto done; + } + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + if ((start % CHAR_BIT) != 0) { + sqlite3_result_error(context, "start index must be divisible by 8.", -1); + goto done; + } + if ((end % CHAR_BIT) != 0) { + sqlite3_result_error(context, "end index must be divisible by 8.", -1); + goto done; + } + + uint8_t *out = sqlite3_malloc(n / CHAR_BIT); + if (!out) { + sqlite3_result_error_nomem(context); + return; + } + for (size_t i = 0; i < n / CHAR_BIT; i++) { + out[i] = ((uint8_t *)vector)[(start / CHAR_BIT) + i]; + } + sqlite3_result_blob(context, out, n / CHAR_BIT, sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); + goto done; + } + } +done: + cleanup(vector); +} + +static void vec_to_json(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 1); + void *vector; + size_t dimensions; + vector_cleanup cleanup; + char *err; + enum VectorElementType elementType; + + int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType, + &cleanup, &err); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, err, -1); + sqlite3_free(err); + return; + } + + sqlite3_str *str = sqlite3_str_new(sqlite3_context_db_handle(context)); + sqlite3_str_appendall(str, "["); + for (size_t i = 0; i < dimensions; i++) { + if (i != 0) { + sqlite3_str_appendall(str, ","); + } + if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + sqlite3_str_appendf(str, "%f", ((float *)vector)[i]); + } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { + sqlite3_str_appendf(str, "%d", ((int8_t *)vector)[i]); + } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_BIT) { + uint8_t b = (((uint8_t *)vector)[i / 8] >> (i % CHAR_BIT)) & 1; + sqlite3_str_appendf(str, "%d", b); + } + } + sqlite3_str_appendall(str, "]"); + int len = sqlite3_str_length(str); + char *s = sqlite3_str_finish(str); + if (s) { + sqlite3_result_text(context, s, len, sqlite3_free); + } else { + sqlite3_result_error_nomem(context); + } + cleanup(vector); +} + +static void vec_normalize(sqlite3_context *context, int argc, + sqlite3_value **argv) { + todo_assert(argc == 1); + void *vector; + size_t dimensions; + vector_cleanup cleanup; + char *err; + enum VectorElementType elementType; + + int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType, + &cleanup, &err); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, err, -1); + sqlite3_free(err); + return; + } + + if (elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + sqlite3_result_error( + context, "only float32 vectors are supported when normalizing", -1); + cleanup(vector); + return; + } + + float *out = sqlite3_malloc(dimensions * sizeof(float)); + todo_assert(out); + float *v = (float *)vector; + + float norm = 0; + for (size_t i = 0; i < dimensions; i++) { + norm += v[i] * v[i]; + } + norm = sqrt(norm); + for (size_t i = 0; i < dimensions; i++) { + out[i] = v[i] / norm; + } + + sqlite3_result_blob(context, out, dimensions * sizeof(float), sqlite3_free); + sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); +} + +static void _static_text_func(sqlite3_context *context, int argc, + sqlite3_value **argv) { + UNUSED_PARAMETER(argc); + UNUSED_PARAMETER(argv); + sqlite3_result_text(context, sqlite3_user_data(context), -1, SQLITE_STATIC); +} + +enum Vec0TokenType { + TOKEN_TYPE_IDENTIFIER, + TOKEN_TYPE_DIGIT, + TOKEN_TYPE_LBRACKET, + TOKEN_TYPE_RBRACKET, + TOKEN_TYPE_EQ, +}; +struct Vec0Token { + enum Vec0TokenType token_type; + char *start; + char *end; +}; + +int is_alpha(char x) { + return (x >= 'a' && x <= 'z') || (x >= 'A' && x <= 'Z'); +} +int is_digit(char x) { return (x >= '0' && x <= '9'); } +int is_whitespace(char x) { + return x == ' ' || x == '\t' || x == '\n' || x == '\r'; +} + +#define VEC0_TOKEN_RESULT_EOF 1 +#define VEC0_TOKEN_RESULT_SOME 2 +#define VEC0_TOKEN_RESULT_ERROR 3 + +int vec0_token_next(char *start, char *end, struct Vec0Token *out) { + char *ptr = start; + while (ptr < end) { + char curr = *ptr; + if (is_whitespace(curr)) { + ptr++; + continue; + } else if (curr == '[') { + ptr++; + out->start = ptr; + out->end = ptr; + out->token_type = TOKEN_TYPE_LBRACKET; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == ']') { + ptr++; + out->start = ptr; + out->end = ptr; + out->token_type = TOKEN_TYPE_RBRACKET; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == '=') { + ptr++; + out->start = ptr; + out->end = ptr; + out->token_type = TOKEN_TYPE_EQ; + return VEC0_TOKEN_RESULT_SOME; + } else if (is_alpha(curr)) { + char *start = ptr; + while (ptr < end && (is_alpha(*ptr) || is_digit(*ptr) || *ptr == '_')) { + ptr++; + } + out->start = start; + out->end = ptr; + out->token_type = TOKEN_TYPE_IDENTIFIER; + return VEC0_TOKEN_RESULT_SOME; + } else if (is_digit(curr)) { + char *start = ptr; + while (ptr < end && (is_digit(*ptr))) { + ptr++; + } + out->start = start; + out->end = ptr; + out->token_type = TOKEN_TYPE_DIGIT; + return VEC0_TOKEN_RESULT_SOME; + } else { + return VEC0_TOKEN_RESULT_ERROR; + } + } + return VEC0_TOKEN_RESULT_EOF; +} + +struct Vec0Scanner { + char *start; + char *end; + char *ptr; +}; + +void vec0_scanner_init(struct Vec0Scanner *scanner, const char *source, + int source_length) { + scanner->start = (char *)source; + scanner->end = (char *)source + source_length; + scanner->ptr = (char *)source; +} +int vec0_scanner_next(struct Vec0Scanner *scanner, struct Vec0Token *out) { + int rc = vec0_token_next(scanner->start, scanner->end, out); + if (rc == VEC0_TOKEN_RESULT_SOME) { + scanner->start = out->end; + } + return rc; +} + +int vec0_parse_table_option(const char *source, int source_length, + char **out_key, int *out_key_length, + char **out_value, int *out_value_length) { + int rc; + struct Vec0Scanner scanner; + struct Vec0Token token; + char *key; + char *value; + int keyLength, valueLength; + + vec0_scanner_init(&scanner, source, source_length); + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_EMPTY; + } + key = token.start; + keyLength = token.end - token.start; + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && token.token_type != TOKEN_TYPE_EQ) { + return SQLITE_EMPTY; + } + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + !((token.token_type == TOKEN_TYPE_IDENTIFIER) || + (token.token_type == TOKEN_TYPE_DIGIT))) { + return SQLITE_EMPTY; + } + value = token.start; + valueLength = token.end - token.start; + + rc = vec0_scanner_next(&scanner, &token); + if (rc == VEC0_TOKEN_RESULT_EOF) { + *out_key = key; + *out_key_length = keyLength; + *out_value = value; + *out_value_length = valueLength; + return SQLITE_OK; + } + return SQLITE_ERROR; +} +/** + * @brief Parse an argv[i] entry of a vec0 virtual table definition, and see if + * it's a PRIMARY KEY definition. + * + * @param source: argv[i] source string + * @param source_length: length of the source string + * @param out_column_name: If it is a PK, the output column name. Same lifetime + * as source, points to specific char * + * @param out_column_name_length: Length of out_column_name in bytes + * @param out_column_type: SQLITE_TEXT or SQLITE_INTEGER. + * @return int: SQLITE_EMPTY if not a PK, SQLITE_OK if it is. + */ +int parse_primary_key_definition(const char *source, int source_length, + char **out_column_name, + int *out_column_name_length, + int *out_column_type) { + struct Vec0Scanner scanner; + struct Vec0Token token; + char *column_name; + int column_name_length; + int column_type; + vec0_scanner_init(&scanner, source, source_length); + + // Check first token is identifier, will be the column name + int rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_EMPTY; + } + + column_name = token.start; + column_name_length = token.end - token.start; + + // Check the next token matches "text" or "integer", as column type + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_EMPTY; + } + if (sqlite3_strnicmp(token.start, "text", token.end - token.start) == 0) { + column_type = SQLITE_TEXT; + } else if (sqlite3_strnicmp(token.start, "int", token.end - token.start) == + 0 || + sqlite3_strnicmp(token.start, "integer", + token.end - token.start) == 0) { + column_type = SQLITE_INTEGER; + } else { + return SQLITE_EMPTY; + } + + // Check the next token is identifier and matches "primary" + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_EMPTY; + } + if (sqlite3_strnicmp(token.start, "primary", token.end - token.start) != 0) { + return SQLITE_EMPTY; + } + + // Check the next token is identifier and matches "key" + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_EMPTY; + } + if (sqlite3_strnicmp(token.start, "key", token.end - token.start) != 0) { + return SQLITE_EMPTY; + } + + *out_column_name = column_name; + *out_column_name_length = column_name_length; + *out_column_type = column_type; + + return SQLITE_OK; +} + +enum Vec0DistanceMetrics { + VEC0_DISTANCE_METRIC_L2 = 1, + VEC0_DISTANCE_METRIC_COSINE = 2, +}; + +struct VectorColumnDefinition { + char *name; + int name_length; + size_t dimensions; + enum VectorElementType element_type; + enum Vec0DistanceMetrics distance_metric; +}; + +size_t vector_column_byte_size(struct VectorColumnDefinition column) { + switch (column.element_type) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + return column.dimensions * sizeof(float); + case SQLITE_VEC_ELEMENT_TYPE_INT8: + return column.dimensions * sizeof(int8_t); + case SQLITE_VEC_ELEMENT_TYPE_BIT: + return column.dimensions / CHAR_BIT; + } +} + +int parse_vector_column(const char *source, int source_length, + struct VectorColumnDefinition *column_def) { + // parses a vector column definition like so: + // "abc float[123]", "abc_123 bit[1234]", eetc. + struct Vec0Scanner scanner; + struct Vec0Token token; + + vec0_scanner_init(&scanner, source, source_length); + + int rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_ERROR; + } + + column_def->name = token.start; + column_def->name_length = token.end - token.start; + column_def->distance_metric = VEC0_DISTANCE_METRIC_L2; + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_ERROR; + } + if (sqlite3_strnicmp(token.start, "float", token.end - token.start) == 0 || + sqlite3_strnicmp(token.start, "f32", token.end - token.start) == 0) { + column_def->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32; + } else if (sqlite3_strnicmp(token.start, "int8", token.end - token.start) == + 0 || + sqlite3_strnicmp(token.start, "i8", token.end - token.start) == + 0) { + column_def->element_type = SQLITE_VEC_ELEMENT_TYPE_INT8; + } else if (sqlite3_strnicmp(token.start, "bit", token.end - token.start) == + 0) { + column_def->element_type = SQLITE_VEC_ELEMENT_TYPE_BIT; + } else { + return SQLITE_ERROR; + } + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && token.token_type != TOKEN_TYPE_LBRACKET) { + return SQLITE_ERROR; + } + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && token.token_type != TOKEN_TYPE_DIGIT) { + return SQLITE_ERROR; + } + column_def->dimensions = atoi(token.start); + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && token.token_type != TOKEN_TYPE_RBRACKET) { + return SQLITE_ERROR; + } + + // any other tokens left should be column-level options , ex `key=value` + // TODO make sure options are defined only once. ex `distance_metric=L2 + // distance_metric=cosine` should error + while (1) { + rc = vec0_scanner_next(&scanner, &token); + if (rc == VEC0_TOKEN_RESULT_EOF) { + return SQLITE_OK; + } + + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_ERROR; + } + + char *key = token.start; + int keyLength = token.end - token.start; + + if (sqlite3_strnicmp(key, "distance_metric", keyLength) == 0) { + + if (column_def->element_type == SQLITE_VEC_ELEMENT_TYPE_BIT) { + return SQLITE_ERROR; + } + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && token.token_type != TOKEN_TYPE_EQ) { + return SQLITE_ERROR; + } + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_ERROR; + } + + char *value = token.start; + int valueLength = token.end - token.start; + if (sqlite3_strnicmp(value, "l2", valueLength) == 0) { + column_def->distance_metric = VEC0_DISTANCE_METRIC_L2; + } else if (sqlite3_strnicmp(value, "cosine", valueLength) == 0) { + column_def->distance_metric = VEC0_DISTANCE_METRIC_COSINE; + } else { + return SQLITE_ERROR; + } + } + // unknown option key + else { + return SQLITE_ERROR; + } + } +} + +#pragma region vec_each table function + +typedef struct vec_each_vtab vec_each_vtab; +struct vec_each_vtab { + sqlite3_vtab base; +}; + +typedef struct vec_each_cursor vec_each_cursor; +struct vec_each_cursor { + sqlite3_vtab_cursor base; + sqlite3_int64 iRowid; + enum VectorElementType vector_type; + void *vector; + size_t dimensions; + vector_cleanup cleanup; +}; + +static int vec_eachConnect(sqlite3 *db, void *pAux, int argc, + const char *const *argv, sqlite3_vtab **ppVtab, + char **pzErr) { + UNUSED_PARAMETER(pAux); + UNUSED_PARAMETER(argc); + UNUSED_PARAMETER(argv); + UNUSED_PARAMETER(pzErr); // TODO use + vec_each_vtab *pNew; + int rc; + + rc = sqlite3_declare_vtab(db, "CREATE TABLE x(value, vector hidden)"); +#define VEC_EACH_COLUMN_VALUE 0 +#define VEC_EACH_COLUMN_VECTOR 1 + if (rc == SQLITE_OK) { + pNew = sqlite3_malloc(sizeof(*pNew)); + *ppVtab = (sqlite3_vtab *)pNew; + if (pNew == 0) + return SQLITE_NOMEM; + memset(pNew, 0, sizeof(*pNew)); + } + return rc; +} + +static int vec_eachDisconnect(sqlite3_vtab *pVtab) { + vec_each_vtab *p = (vec_each_vtab *)pVtab; + sqlite3_free(p); + return SQLITE_OK; +} + +static int vec_eachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { + UNUSED_PARAMETER(p); + vec_each_cursor *pCur; + pCur = sqlite3_malloc(sizeof(*pCur)); + if (pCur == 0) + return SQLITE_NOMEM; + memset(pCur, 0, sizeof(*pCur)); + *ppCursor = &pCur->base; + return SQLITE_OK; +} + +static int vec_eachClose(sqlite3_vtab_cursor *cur) { + vec_each_cursor *pCur = (vec_each_cursor *)cur; + sqlite3_free(pCur); + return SQLITE_OK; +} + +static int vec_eachBestIndex(sqlite3_vtab *pVTab, + sqlite3_index_info *pIdxInfo) { + int hasVector; + for (int i = 0; i < pIdxInfo->nConstraint; i++) { + const struct sqlite3_index_constraint *pCons = &pIdxInfo->aConstraint[i]; + // printf("i=%d iColumn=%d, op=%d, usable=%d\n", i, pCons->iColumn, + // pCons->op, pCons->usable); + switch (pCons->iColumn) { + case VEC_EACH_COLUMN_VECTOR: { + if (pCons->op == SQLITE_INDEX_CONSTRAINT_EQ && pCons->usable) { + hasVector = 1; + pIdxInfo->aConstraintUsage[i].argvIndex = 1; + pIdxInfo->aConstraintUsage[i].omit = 1; + } + break; + } + } + } + if (!hasVector) { + pVTab->zErrMsg = sqlite3_mprintf("vector argument is required"); + return SQLITE_ERROR; + } + + pIdxInfo->estimatedCost = (double)100000; + pIdxInfo->estimatedRows = 100000; + + return SQLITE_OK; +} + +static int vec_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, + const char *idxStr, int argc, sqlite3_value **argv) { + UNUSED_PARAMETER(idxNum); + UNUSED_PARAMETER(idxStr); + todo_assert(argc == 1); + vec_each_cursor *pCur = (vec_each_cursor *)pVtabCursor; + + char *pzErrMsg; + int rc = vector_from_value(argv[0], &pCur->vector, &pCur->dimensions, + &pCur->vector_type, &pCur->cleanup, &pzErrMsg); + if (rc != SQLITE_OK) { + return SQLITE_ERROR; + } + pCur->iRowid = 0; + return SQLITE_OK; +} + +static int vec_eachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { + vec_each_cursor *pCur = (vec_each_cursor *)cur; + *pRowid = pCur->iRowid; + return SQLITE_OK; +} + +static int vec_eachEof(sqlite3_vtab_cursor *cur) { + vec_each_cursor *pCur = (vec_each_cursor *)cur; + return pCur->iRowid >= (sqlite3_int64)pCur->dimensions; +} + +static int vec_eachNext(sqlite3_vtab_cursor *cur) { + vec_each_cursor *pCur = (vec_each_cursor *)cur; + pCur->iRowid++; + return SQLITE_OK; +} + +static int vec_eachColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context, + int i) { + vec_each_cursor *pCur = (vec_each_cursor *)cur; + switch (i) { + case VEC_EACH_COLUMN_VALUE: + switch (pCur->vector_type) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + sqlite3_result_double(context, ((float *)pCur->vector)[pCur->iRowid]); + break; + } + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + uint8_t x = ((uint8_t *)pCur->vector)[pCur->iRowid / CHAR_BIT]; + sqlite3_result_int(context, + (x & (0b10000000 >> ((pCur->iRowid % CHAR_BIT)))) > 0); + break; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + sqlite3_result_int(context, ((int8_t *)pCur->vector)[pCur->iRowid]); + break; + } + } + + break; + } + return SQLITE_OK; +} + +static sqlite3_module vec_eachModule = { + /* iVersion */ 0, + /* xCreate */ 0, + /* xConnect */ vec_eachConnect, + /* xBestIndex */ vec_eachBestIndex, + /* xDisconnect */ vec_eachDisconnect, + /* xDestroy */ 0, + /* xOpen */ vec_eachOpen, + /* xClose */ vec_eachClose, + /* xFilter */ vec_eachFilter, + /* xNext */ vec_eachNext, + /* xEof */ vec_eachEof, + /* xColumn */ vec_eachColumn, + /* xRowid */ vec_eachRowid, + /* xUpdate */ 0, + /* xBegin */ 0, + /* xSync */ 0, + /* xCommit */ 0, + /* xRollback */ 0, + /* xFindMethod */ 0, + /* xRename */ 0, + /* xSavepoint */ 0, + /* xRelease */ 0, + /* xRollbackTo */ 0, + /* xShadowName */ 0, + /* xIntegrity */ 0}; + +#pragma endregion + +#pragma region vec_npy_each table function + +static unsigned char NPY_MAGIC[6] = "\x93NUMPY"; + +enum NpyTokenType { + NPY_TOKEN_TYPE_IDENTIFIER, + NPY_TOKEN_TYPE_NUMBER, + NPY_TOKEN_TYPE_LPAREN, + NPY_TOKEN_TYPE_RPAREN, + NPY_TOKEN_TYPE_LBRACE, + NPY_TOKEN_TYPE_RBRACE, + NPY_TOKEN_TYPE_COLON, + NPY_TOKEN_TYPE_COMMA, + NPY_TOKEN_TYPE_STRING, + NPY_TOKEN_TYPE_FALSE, +}; + +struct NpyToken { + enum NpyTokenType token_type; + unsigned char *start; + unsigned char *end; +}; + +int npy_token_next(unsigned char *start, unsigned char *end, + struct NpyToken *out) { + unsigned char *ptr = start; + while (ptr < end) { + unsigned char curr = *ptr; + if (is_whitespace(curr)) { + ptr++; + continue; + } else if (curr == '(') { + out->start = ptr++; + out->end = ptr; + out->token_type = NPY_TOKEN_TYPE_LPAREN; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == ')') { + out->start = ptr++; + out->end = ptr; + out->token_type = NPY_TOKEN_TYPE_RPAREN; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == '{') { + out->start = ptr++; + out->end = ptr; + out->token_type = NPY_TOKEN_TYPE_LBRACE; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == '}') { + out->start = ptr++; + out->end = ptr; + out->token_type = NPY_TOKEN_TYPE_RBRACE; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == ':') { + out->start = ptr++; + out->end = ptr; + out->token_type = NPY_TOKEN_TYPE_COLON; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == ',') { + out->start = ptr++; + out->end = ptr; + out->token_type = NPY_TOKEN_TYPE_COMMA; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == '\'') { + unsigned char *start = ptr; + ptr++; + while (ptr < end) { + if ((*ptr) == '\'') { + break; + } + ptr++; + } + if ((*ptr) != '\'') { + return VEC0_TOKEN_RESULT_ERROR; + } + out->start = start; + out->end = ++ptr; + out->token_type = NPY_TOKEN_TYPE_STRING; + return VEC0_TOKEN_RESULT_SOME; + } else if (curr == 'F' && + strncmp((char *)ptr, "False", strlen("False")) == 0) { + out->start = ptr; + out->end = (ptr + (int)strlen("False")); + ptr = out->end; + out->token_type = NPY_TOKEN_TYPE_FALSE; + return VEC0_TOKEN_RESULT_SOME; + } else if (is_digit(curr)) { + unsigned char *start = ptr; + while (ptr < end && (is_digit(*ptr))) { + ptr++; + } + out->start = start; + out->end = ptr; + out->token_type = NPY_TOKEN_TYPE_NUMBER; + return VEC0_TOKEN_RESULT_SOME; + } else { + return VEC0_TOKEN_RESULT_ERROR; + } + } + return VEC0_TOKEN_RESULT_ERROR; +} + +struct NpyScanner { + unsigned char *start; + unsigned char *end; + unsigned char *ptr; +}; + +void npy_scanner_init(struct NpyScanner *scanner, const unsigned char *source, + int source_length) { + scanner->start = (unsigned char *)source; + scanner->end = (unsigned char *)source + source_length; + scanner->ptr = (unsigned char *)source; +} + +int npy_scanner_next(struct NpyScanner *scanner, struct NpyToken *out) { + int rc = npy_token_next(scanner->start, scanner->end, out); + if (rc == VEC0_TOKEN_RESULT_SOME) { + scanner->start = out->end; + } + return rc; +} + +int parse_npy_header(const unsigned char *header, size_t headerLength, + enum VectorElementType *out_element_type, + int *fortran_order, size_t *numElements, + size_t *numDimensions) { + + struct NpyScanner scanner; + struct NpyToken token; + int rc; + npy_scanner_init(&scanner, header, headerLength); + + if (npy_scanner_next(&scanner, &token) != VEC0_TOKEN_RESULT_SOME && + token.token_type != NPY_TOKEN_TYPE_LBRACE) { + return SQLITE_ERROR; + } + while (1) { + rc = npy_scanner_next(&scanner, &token); + todo_assert(rc == VEC0_TOKEN_RESULT_SOME); + if (token.token_type == NPY_TOKEN_TYPE_RBRACE) { + break; + } + todo_assert(token.token_type == NPY_TOKEN_TYPE_STRING); + unsigned char *key = token.start; + // TODO use this in strncmp()? + // int keyLength = token.end - token.start; + + rc = npy_scanner_next(&scanner, &token); + todo_assert(rc == VEC0_TOKEN_RESULT_SOME); + todo_assert(token.token_type == NPY_TOKEN_TYPE_COLON); + + // TODO: strcmp safe? + if (strncmp((char *)key, "'descr'", strlen("'descr'")) == 0) { + rc = npy_scanner_next(&scanner, &token); + todo_assert(rc == VEC0_TOKEN_RESULT_SOME); + todo_assert(token.token_type == NPY_TOKEN_TYPE_STRING); + todo_assert(strncmp((char *)token.start, "' 10); + for (size_t i = 0; i < sizeof(NPY_MAGIC); i++) { + todo_assert(NPY_MAGIC[i] == buffer[i]); + } + uint8_t major = buffer[6]; + uint8_t minor = buffer[7]; + uint16_t headerLength = 0; + memcpy(&headerLength, &buffer[8], sizeof(uint16_t)); + + const unsigned char *header = &buffer[10]; + + // printf("npy: headerLength=%zu major=%d minor=%d headerLen=%d\n", + // bufferLength, major, minor, headerLength); + size_t totalHeaderLength = sizeof(NPY_MAGIC) + sizeof(major) + sizeof(minor) + + sizeof(headerLength) + headerLength; + size_t dataSize = bufferLength - totalHeaderLength; + todo_assert(dataSize > 0); + + int fortran_order; + + int rc = parse_npy_header(header, headerLength, element_type, &fortran_order, + numElements, numDimensions); + todo_assert(rc == SQLITE_OK); + + int element_size = 0; + // TODO bit + if (*element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + element_size = sizeof(float); + } + todo_assert((*numElements * *numDimensions * element_size) == dataSize); + + *data = (void *)&buffer[totalHeaderLength]; + return SQLITE_OK; +} + +typedef struct vec_npy_each_vtab vec_npy_each_vtab; +struct vec_npy_each_vtab { + sqlite3_vtab base; +}; + +typedef enum { + VEC_NPY_EACH_INPUT_BUFFER, + VEC_NPY_EACH_INPUT_FILE, +} vec_npy_each_input_type; + +typedef struct vec_npy_each_cursor vec_npy_each_cursor; +struct vec_npy_each_cursor { + sqlite3_vtab_cursor base; + sqlite3_int64 iRowid; + // sqlite-vec compatible type of vector + enum VectorElementType elementType; + // number of vectors in the npy array + size_t nElements; + // number of dimensions each vector has + size_t nDimensions; + vec_npy_each_input_type input_type; + + // TODO enum this + + // when input_type == VEC_NPY_EACH_INPUT_BUFFER + + // Buffer containing the vector data, when reading from an in-memory buffer. + // Size: nElements * nDimensions * element_size + // Clean up with sqlite3_free() once complete + void *vector; + + // when input_type == VEC_NPY_EACH_INPUT_FILE + + // Opened npy file, when reading from a file. + // fclose() when complete. + FILE *file; + // an in-memory buffer containing a portion of the npy array. + // Used for faster reading, instead of calling fread() a lot. + // Will have a byte-size of fileBufferSize + void *fileBuffer; + // size of allocated fileBuffer in bytes + size_t fileBufferSize; + // Counter index of the current vector into of fileBuffer to yield. + // Starts at 0 once fileBuffer is read, and iterates to bufferLength. + // Resets to 0 once that "buffer" is yielded and a new one is read. + size_t bufferIndex; + // Maximum length of the buffer, in terms of number of vectors. + size_t bufferLength; + // Size of each element inside the vector. + // Ex: 4 for floats, ex. + int elementSize; + // 0 when there are still more elements to read/yield, 1 when complete. + int eof; +}; + +static int vec_npy_eachConnect(sqlite3 *db, void *pAux, int argc, + const char *const *argv, sqlite3_vtab **ppVtab, + char **pzErr) { + UNUSED_PARAMETER(pAux); + UNUSED_PARAMETER(argc); + UNUSED_PARAMETER(argv); + UNUSED_PARAMETER(pzErr); // TODO use + vec_npy_each_vtab *pNew; + int rc; + + rc = sqlite3_declare_vtab(db, "CREATE TABLE x(vector, input hidden)"); +#define VEC_NPY_EACH_COLUMN_VECTOR 0 +#define VEC_NPY_EACH_COLUMN_INPUT 1 + if (rc == SQLITE_OK) { + pNew = sqlite3_malloc(sizeof(*pNew)); + *ppVtab = (sqlite3_vtab *)pNew; + if (pNew == 0) + return SQLITE_NOMEM; + memset(pNew, 0, sizeof(*pNew)); + } + return rc; +} + +static int vec_npy_eachDisconnect(sqlite3_vtab *pVtab) { + vec_npy_each_vtab *p = (vec_npy_each_vtab *)pVtab; + sqlite3_free(p); + return SQLITE_OK; +} + +static int vec_npy_eachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { + UNUSED_PARAMETER(p); + vec_npy_each_cursor *pCur; + pCur = sqlite3_malloc(sizeof(*pCur)); + if (pCur == 0) + return SQLITE_NOMEM; + memset(pCur, 0, sizeof(*pCur)); + *ppCursor = &pCur->base; + return SQLITE_OK; +} + +static int vec_npy_eachClose(sqlite3_vtab_cursor *cur) { + vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; + if (pCur->file) { + fclose(pCur->file); + pCur->file = NULL; + } + if (pCur->fileBuffer) { + sqlite3_free(pCur->fileBuffer); + pCur->fileBuffer = NULL; + } + if (pCur->vector) { + // sqlite3_free(pCur->vector); + pCur->vector = NULL; + } + sqlite3_free(pCur); + return SQLITE_OK; +} + +static int vec_npy_eachBestIndex(sqlite3_vtab *pVTab, + sqlite3_index_info *pIdxInfo) { + int hasInput; + for (int i = 0; i < pIdxInfo->nConstraint; i++) { + const struct sqlite3_index_constraint *pCons = &pIdxInfo->aConstraint[i]; + // printf("i=%d iColumn=%d, op=%d, usable=%d\n", i, pCons->iColumn, + // pCons->op, pCons->usable); + switch (pCons->iColumn) { + case VEC_NPY_EACH_COLUMN_INPUT: { + if (pCons->op == SQLITE_INDEX_CONSTRAINT_EQ && pCons->usable) { + hasInput = 1; + pIdxInfo->aConstraintUsage[i].argvIndex = 1; + pIdxInfo->aConstraintUsage[i].omit = 1; + } + break; + } + } + } + if (!hasInput) { + pVTab->zErrMsg = sqlite3_mprintf("input argument is required"); + return SQLITE_ERROR; + } + + pIdxInfo->estimatedCost = (double)100000; + pIdxInfo->estimatedRows = 100000; + + return SQLITE_OK; +} + +static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, + const char *idxStr, int argc, + sqlite3_value **argv) { + UNUSED_PARAMETER(idxNum); + UNUSED_PARAMETER(idxStr); + todo_assert(argc == 1); + vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)pVtabCursor; + + if (pCur->file) { + fclose(pCur->file); + pCur->file = NULL; + } + if (pCur->fileBuffer) { + sqlite3_free(pCur->fileBuffer); + pCur->fileBuffer = NULL; + } + if (pCur->vector) { + // sqlite3_free(pCur->vector); TODO don't need to free this?? + pCur->vector = NULL; + } + + struct VecNpyFile *f = NULL; + + if ((f = sqlite3_value_pointer(argv[0], SQLITE_VEC_NPY_FILE_NAME))) { + int n; + FILE *file = fopen(f->path, "r"); + todo_assert(file); + + fseek(file, 0, SEEK_END); + long fileSize = ftell(file); + + fseek(file, 0L, SEEK_SET); + + unsigned char header[10]; + n = fread(&header, sizeof(unsigned char), 10, file); + todo_assert(n == 10); + + for (size_t i = 0; i < countof(NPY_MAGIC); i++) { + todo_assert(NPY_MAGIC[i] == header[i]); + } + uint8_t major = header[6]; + uint8_t minor = header[7]; + + uint16_t headerLength = 0; + memcpy(&headerLength, &header[8], sizeof(uint16_t)); + + size_t totalHeaderLength = sizeof(NPY_MAGIC) + sizeof(major) + + sizeof(minor) + sizeof(headerLength) + + headerLength; + size_t dataSize = fileSize - totalHeaderLength; + todo_assert(dataSize > 0); + + unsigned char *headerX = sqlite3_malloc(headerLength); + todo_assert(headerX); + n = fread(headerX, sizeof(char), headerLength, file); + todo_assert(n == headerLength); + + int fortran_order; + enum VectorElementType element_type; + size_t numElements; + size_t numDimensions; + int rc = parse_npy_header(headerX, headerLength, &element_type, + &fortran_order, &numElements, &numDimensions); + sqlite3_free(headerX); + todo_assert(rc == SQLITE_OK); + + int element_size = 0; + if (element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + element_size = sizeof(float); + } else { + todo("non-f32 numpy array"); + } + + todo_assert((numElements * numDimensions * element_size) == dataSize); + + pCur->bufferIndex = 0; + pCur->bufferLength = 1024; + pCur->elementSize = element_size; + pCur->elementType = element_type; + pCur->nElements = numElements; + pCur->nDimensions = numDimensions; + pCur->fileBufferSize = numDimensions * element_size * pCur->bufferLength; + pCur->fileBuffer = sqlite3_malloc(pCur->fileBufferSize); + todo_assert(pCur->fileBuffer); + pCur->input_type = VEC_NPY_EACH_INPUT_FILE; + n = fread(pCur->fileBuffer, 1, pCur->fileBufferSize, file); + todo_assert((size_t)n == pCur->fileBufferSize); // TODO may be smaller + + pCur->eof = 0; + pCur->file = file; + + } else { + + const unsigned char *input = sqlite3_value_blob(argv[0]); + size_t inputLength = sqlite3_value_bytes(argv[0]); + int rc; + void *data; + size_t numElements; + size_t numDimensions; + enum VectorElementType element_type; + + rc = parse_npy(input, inputLength, &data, &numElements, &numDimensions, + &element_type); + todo_assert(rc == SQLITE_OK); + + pCur->vector = data; + pCur->elementType = element_type; + pCur->nElements = numElements; + pCur->nDimensions = numDimensions; + pCur->input_type = VEC_NPY_EACH_INPUT_BUFFER; + } + + pCur->iRowid = 0; + return SQLITE_OK; +} + +static int vec_npy_eachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { + vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; + *pRowid = pCur->iRowid; + return SQLITE_OK; +} + +static int vec_npy_eachEof(sqlite3_vtab_cursor *cur) { + vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; + if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) { + return (size_t)pCur->iRowid >= pCur->nElements; + } + return pCur->eof; +} + +static int vec_npy_eachNext(sqlite3_vtab_cursor *cur) { + vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; + pCur->iRowid++; + if (pCur->input_type == VEC_NPY_EACH_INPUT_FILE) { + pCur->bufferIndex++; + if (pCur->bufferIndex >= pCur->bufferLength) { + int n = fread(pCur->fileBuffer, 1, pCur->fileBufferSize, pCur->file); + if (!n) { + pCur->eof = 1; + } + pCur->bufferIndex = 0; + pCur->bufferLength = n / pCur->nDimensions / pCur->elementSize; + } + } + return SQLITE_OK; +} + +static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur, + sqlite3_context *context, int i) { + vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; + switch (i) { + case VEC_NPY_EACH_COLUMN_VECTOR: { + if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) { + switch (pCur->elementType) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + sqlite3_result_blob( + context, + &pCur->vector[pCur->iRowid * pCur->nDimensions * sizeof(float)], + pCur->nDimensions * sizeof(float), SQLITE_STATIC); + break; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + todo("bit array npy column"); + break; + } + } + } else if (pCur->input_type == VEC_NPY_EACH_INPUT_FILE) { + switch (pCur->elementType) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + sqlite3_result_blob( + context, + &pCur->fileBuffer[pCur->bufferIndex * pCur->nDimensions * + sizeof(float)], + pCur->nDimensions * sizeof(float), SQLITE_TRANSIENT); + break; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + todo("bit array npy column"); + break; + } + } + } + + break; + } + } + return SQLITE_OK; +} + +static sqlite3_module vec_npy_eachModule = { + /* iVersion */ 0, + /* xCreate */ 0, + /* xConnect */ vec_npy_eachConnect, + /* xBestIndex */ vec_npy_eachBestIndex, + /* xDisconnect */ vec_npy_eachDisconnect, + /* xDestroy */ 0, + /* xOpen */ vec_npy_eachOpen, + /* xClose */ vec_npy_eachClose, + /* xFilter */ vec_npy_eachFilter, + /* xNext */ vec_npy_eachNext, + /* xEof */ vec_npy_eachEof, + /* xColumn */ vec_npy_eachColumn, + /* xRowid */ vec_npy_eachRowid, + /* xUpdate */ 0, + /* xBegin */ 0, + /* xSync */ 0, + /* xCommit */ 0, + /* xRollback */ 0, + /* xFindMethod */ 0, + /* xRename */ 0, + /* xSavepoint */ 0, + /* xRelease */ 0, + /* xRollbackTo */ 0, + /* xShadowName */ 0, + /* xIntegrity */ 0}; + +#pragma endregion + +#pragma region vec0 virtual table + +#define VEC0_COLUMN_ID 0 +#define VEC0_COLUMN_VECTORN_START 1 +#define VEC0_COLUMN_OFFSET_DISTANCE 1 +#define VEC0_COLUMN_OFFSET_K 2 + +#define VEC0_SHADOW_CHUNKS_NAME "\"%w\".\"%w_chunks\"" +/// 1) schema, 2) original vtab table name +#define VEC0_SHADOW_CHUNKS_CREATE \ + "CREATE TABLE " VEC0_SHADOW_CHUNKS_NAME "(" \ + "chunk_id INTEGER PRIMARY KEY AUTOINCREMENT," \ + "size INTEGER NOT NULL," \ + "validity BLOB NOT NULL," \ + "rowids BLOB NOT NULL" \ + ");" + +#define VEC0_SHADOW_ROWIDS_NAME "\"%w\".\"%w_rowids\"" +/// 1) schema, 2) original vtab table name +#define VEC0_SHADOW_ROWIDS_CREATE_BASIC \ + "CREATE TABLE " VEC0_SHADOW_ROWIDS_NAME "(" \ + "rowid INTEGER PRIMARY KEY AUTOINCREMENT," \ + "id," \ + "chunk_id INTEGER," \ + "chunk_offset INTEGER" \ + ");" + +// vec0 tables with a text primary keys are still backed by int64 primary keys, +// since a fixed-length rowid is required for vec0 chunks. But we add a new 'id +// text unique' column to emulate a text primary key interface. +#define VEC0_SHADOW_ROWIDS_CREATE_PK_TEXT \ + "CREATE TABLE " VEC0_SHADOW_ROWIDS_NAME "(" \ + "rowid INTEGER PRIMARY KEY AUTOINCREMENT," \ + "id TEXT UNIQUE NOT NULL," \ + "chunk_id INTEGER," \ + "chunk_offset INTEGER" \ + ");" + +/// 1) schema, 2) original vtab table name +#define VEC0_SHADOW_VECTOR_N_NAME "\"%w\".\"%w_vector_chunks%02d\"" + +/// 1) schema, 2) original vtab table name +#define VEC0_SHADOW_VECTOR_N_CREATE \ + "CREATE TABLE " VEC0_SHADOW_VECTOR_N_NAME "(" \ + "rowid PRIMARY KEY," \ + "vectors BLOB NOT NULL" \ + ");" + +typedef struct vec0_vtab vec0_vtab; + +#define VEC0_MAX_VECTOR_COLUMNS 16 +struct vec0_vtab { + sqlite3_vtab base; + + // the SQLite connection of the host database + sqlite3 *db; + + // True if the primary key of the vec0 table has a column type TEXT. + // Will change the schema of the _rowids table, and insert/query logic. + int pkIsText; + + // Name of the schema the table exists on. + // Must be freed with sqlite3_free() + char *schemaName; + + // Name of the table the table exists on. + // Must be freed with sqlite3_free() + char *tableName; + + // Name of the _rowids shadow table. + // Must be freed with sqlite3_free() + char *shadowRowidsName; + + // Name of the _chunks shadow table. + // Must be freed with sqlite3_free() + char *shadowChunksName; + + // Name of all the vector chunk shadow tables. + // Only the first numVectorColumns entries will be available. + // The first numVectorColumns entries must be freed with sqlite3_free() + char *shadowVectorChunksNames[VEC0_MAX_VECTOR_COLUMNS]; + + struct VectorColumnDefinition vector_columns[VEC0_MAX_VECTOR_COLUMNS]; + + // number of defined numVectorColumns columns. + int numVectorColumns; + + int chunk_size; + + // select latest chunk from _chunks, getting chunk_id + sqlite3_stmt *stmtLatestChunk; + + /** + * Statement to insert a row into the _rowids table, with a rowid. + * Parameters: + * 1: int64, rowid to insert + * Result columns: none + * SQL: "INSERT INTO _rowids(rowid) VALUES (?)" + * + * Must be cleaned up with sqlite3_finalize(). + */ + sqlite3_stmt *stmtRowidsInsertRowid; + + /** + * Statement to insert a row into the _rowids table, with an id. + * The id column isn't a tradition primary key, but instead a unique + * column to handle "text primary key" vec0 tables. The true int64 rowid + * can be retrieved after inserting with sqlite3_last_rowid(). + * + * Parameters: + * 1: text or null, id to insert + * Result columns: none + * + * Must be cleaned up with sqlite3_finalize(). + */ + sqlite3_stmt *stmtRowidsInsertId; + + /** + * Statement to update the "position" columns chunk_id and chunk_offset for + * a given _rowids row. Used when the "next available" chunk position is found + * for a vector. + * + * Parameters: + * 1: int64, chunk_id value + * 2: int64, chunk_offset value + * 3: int64, rowid value + * Result columns: none + * + * Must be cleaned up with sqlite3_finalize(). + */ + sqlite3_stmt *stmtRowidsUpdatePosition; + + /** + * Statement to quickly find the chunk_id + chunk_offset of a given row. + * Parameters: + * 1: rowid of the row/vector to lookup + * Result columns: + * 0: chunk_id (sqlite3_int64) + * 1: chunk_offset (sqlite3_int64) + * SQL: "SELECT chunk_id, chunk_offset FROM _rowids WHERE rowid = ?"" + * + * Must be cleaned up with sqlite3_finalize(). + */ + sqlite3_stmt *stmtRowidsGetChunkPosition; + + /** + * Cached SQLite BLOBs for every possible vector column for the table. + * Defined for all vectors up to index numVectorColumns (always <= + * VEC0_MAX_VECTOR_COLUMNS). + * + * Defined from: + * db: p->schemaName + * table: p->shadowVectorChunksNames[i] + * column: "vectors" + * + * Opened at vec0_init() time. + * Must be cleaned up with sqlite3_blob_close() at xDisconnect. + * + */ + sqlite3_blob *vectorBlobs[VEC0_MAX_VECTOR_COLUMNS]; +}; + +int vec0_column_distance_idx(vec0_vtab *pVtab) { + return VEC0_COLUMN_VECTORN_START + (pVtab->numVectorColumns - 1) + + VEC0_COLUMN_OFFSET_DISTANCE; +} +int vec0_column_k_idx(vec0_vtab *pVtab) { + return VEC0_COLUMN_VECTORN_START + (pVtab->numVectorColumns - 1) + + VEC0_COLUMN_OFFSET_K; +} + +/** + * Returns 1 if the given column-based index is a valid vector column, + * 0 otherwise. + */ +int vec0_column_idx_is_vector(vec0_vtab *pVtab, int column_idx) { + return column_idx >= VEC0_COLUMN_VECTORN_START && + column_idx <= (VEC0_COLUMN_VECTORN_START + pVtab->numVectorColumns - + 1); // TODO is -1 necessary here? +} + +/** + * Returns the vector index of the given vector column index. + * ONLY call if validated with vec0_column_idx_is_vector before + */ +int vec0_column_idx_to_vector_idx(vec0_vtab *pVtab, int column_idx) { + UNUSED_PARAMETER(pVtab); + return column_idx - VEC0_COLUMN_VECTORN_START; +} + +/** + * @brief Return the id value from the _rowids table where _rowids.rowid = + * rowid. + * + * @param pVtab: vec0 table to query + * @param rowid: rowid of the row to query. + * @param out: A dup'ed sqlite3_value of the id column. Might be null. + * Must be cleaned up with sqlite3_value_free(). + * @returns SQLITE_OK on success, error code on failure + */ +int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, sqlite3_int64 rowid, + sqlite3_value **out) { + // TODO different stmt than stmtRowidsGetChunkPosition? + // TODO return rc instead + sqlite3_reset(pVtab->stmtRowidsGetChunkPosition); + sqlite3_clear_bindings(pVtab->stmtRowidsGetChunkPosition); + sqlite3_bind_int64(pVtab->stmtRowidsGetChunkPosition, 1, rowid); + int rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition); + if (rc == SQLITE_ROW) { + return SQLITE_ERROR; + } + sqlite3_value *value = + sqlite3_column_value(pVtab->stmtRowidsGetChunkPosition, 0); + *out = sqlite3_value_dup(value); + return SQLITE_OK; +} + +// TODO make sure callees use the return value of this function +int vec0_result_id(vec0_vtab *p, sqlite3_context *context, + sqlite3_int64 rowid) { + if (!p->pkIsText) { + sqlite3_result_int64(context, rowid); + return SQLITE_OK; + } + sqlite3_value *valueId; + int rc = vec0_get_id_value_from_rowid(p, rowid, &valueId); + if (rc != SQLITE_OK) { + return rc; + } + if (!valueId) { + sqlite3_result_error_nomem(context); + } else { + sqlite3_result_value(context, valueId); + sqlite3_value_free(valueId); + } + return SQLITE_OK; +} + +/** + * @brief + * + * @param pVtab: virtual table to query + * @param rowid: row to lookup + * @param vector_column_idx: which vector column to query + * @param outVector: Output pointer to the vector buffer. + * Must be sqlite3_free()'ed. + * @param outVectorSize: Pointer to a int where the size of outVector + * will be stored. + * @return int SQLITE_OK on success. + */ +int vec0_get_vector_data(vec0_vtab *pVtab, sqlite3_int64 rowid, + int vector_column_idx, void **outVector, + int *outVectorSize) { + todo_assert((vector_column_idx >= 0) && + (vector_column_idx < pVtab->numVectorColumns)); + + sqlite3_reset(pVtab->stmtRowidsGetChunkPosition); + sqlite3_clear_bindings(pVtab->stmtRowidsGetChunkPosition); + sqlite3_bind_int64(pVtab->stmtRowidsGetChunkPosition, 1, rowid); + int rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition); + todo_assert(rc == SQLITE_ROW); + sqlite3_int64 chunk_id = + sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 1); + sqlite3_int64 chunk_offset = + sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 2); + + rc = sqlite3_blob_reopen(pVtab->vectorBlobs[vector_column_idx], chunk_id); + todo_assert(rc == SQLITE_OK); + size_t size = + vector_column_byte_size(pVtab->vector_columns[vector_column_idx]); + int blobOffset = chunk_offset * size; + + void *buf = sqlite3_malloc(size); + todo_assert(buf); + rc = sqlite3_blob_read(pVtab->vectorBlobs[vector_column_idx], buf, size, + blobOffset); + todo_assert(rc == SQLITE_OK); + + *outVector = buf; + if (outVectorSize) { + *outVectorSize = size; + } + return SQLITE_OK; +} + +/** + * @brief For the given rowid, found the chunk_id and chunk_offset for that row. + * + * @param p: vec0 table + * @param rowid: rowid of row to lookup + * @param chunk_id: Output chunk_id of the row, refs _chunks.rowid + * @param chunk_offset: Output chunk_offset of the row + * @return int: SQLITE_OK on success, error code on failure + */ +int vec0_get_chunk_position(vec0_vtab *p, sqlite3_int64 rowid, + sqlite3_int64 *chunk_id, + sqlite3_int64 *chunk_offset) { + int rc; + sqlite3_reset(p->stmtRowidsGetChunkPosition); + sqlite3_clear_bindings(p->stmtRowidsGetChunkPosition); + sqlite3_bind_int64(p->stmtRowidsGetChunkPosition, 1, rowid); + rc = sqlite3_step(p->stmtRowidsGetChunkPosition); + assert(rc == SQLITE_ROW); + *chunk_id = sqlite3_column_int64(p->stmtRowidsGetChunkPosition, 1); + *chunk_offset = sqlite3_column_int64(p->stmtRowidsGetChunkPosition, 2); + rc = sqlite3_step(p->stmtRowidsGetChunkPosition); + todo_assert(rc == SQLITE_DONE); + return SQLITE_OK; +} + +/** + * @brief Adds a new chunk for the vec0 table, and the cooresponding vector + * chunks. + * + * Inserts a new row into the _chunks table, with blank data, and uses that new + * rowid to insert new blank rows into _vector_chunksXX tables. + * + * @param p: vec0 table to add new chunk + * @param chunk_rowid: Putput pointer, if not NULL, then will be filled with the + * new chunk rowid. + * @return int SQLITE_OK on success, error code otherwise. + */ +int vec0_new_chunk(vec0_vtab *p, sqlite3_int64 *chunk_rowid) { + int rc; + char *zSql; + sqlite3_stmt *stmt; + sqlite3_int64 rowid; + + // Step 1: Insert a new row in _chunks, capture that new rowid + zSql = sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_CHUNKS_NAME + "(size, validity, rowids) " + "VALUES (?, ?, ?);", + p->schemaName, p->tableName); + todo_assert(zSql); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + todo_assert(rc == SQLITE_OK); +#ifdef SQLITE_VEC_THREADSAFE + sqlite3_mutex_enter(sqlite3_db_mutex(p->db)); +#endif + rc = sqlite3_bind_int64(stmt, 1, p->chunk_size); // size + todo_assert(rc == SQLITE_OK); + rc = sqlite3_bind_zeroblob(stmt, 2, + p->chunk_size / CHAR_BIT); // validity bitmap + todo_assert(rc == SQLITE_OK); + rc = sqlite3_bind_zeroblob(stmt, 3, + p->chunk_size * sizeof(sqlite3_int64)); // rowids + todo_assert(rc == SQLITE_OK); + rc = sqlite3_step(stmt); + todo_assert(rc == SQLITE_DONE); + rowid = sqlite3_last_insert_rowid(p->db); +#ifdef SQLITE_VEC_THREADSAFE + sqlite3_mutex_leave(sqlite3_db_mutex(p->db)); +#endif + sqlite3_finalize(stmt); + + // Step 2: Create new vector chunks for each vector column, with + // that new chunk_rowid. + + for (int i = 0; i < p->numVectorColumns; i++) { + + sqlite3_int64 vectorsSize = 0; + switch (p->vector_columns[i].element_type) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + vectorsSize = + p->chunk_size * p->vector_columns[i].dimensions * sizeof(float); + break; + case SQLITE_VEC_ELEMENT_TYPE_INT8: + vectorsSize = + p->chunk_size * p->vector_columns[i].dimensions * sizeof(int8_t); + break; + case SQLITE_VEC_ELEMENT_TYPE_BIT: + vectorsSize = + ceil(p->chunk_size * p->vector_columns[i].dimensions / CHAR_BIT); + break; + } + + zSql = sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_VECTOR_N_NAME + "(rowid, vectors)" + "VALUES (?, ?)", + p->schemaName, p->tableName, i); + todo_assert(zSql); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + todo_assert(rc == SQLITE_OK); + rc = sqlite3_bind_int64(stmt, 1, rowid); + todo_assert(rc == SQLITE_OK); + + rc = sqlite3_bind_zeroblob64(stmt, 2, vectorsSize); + todo_assert(rc == SQLITE_OK); + + rc = sqlite3_step(stmt); + todo_assert(rc == SQLITE_DONE); + sqlite3_finalize(stmt); + } + + if (chunk_rowid) { + *chunk_rowid = rowid; + } + + return SQLITE_OK; +} + +// Possible query plans for xBestIndex on vec0 tables. +typedef enum { + // Full scan, every row is queried. + SQLITE_VEC0_QUERYPLAN_FULLSCAN, + // A single row is queried by rowid/id + SQLITE_VEC0_QUERYPLAN_POINT, + // A KNN-style query is made on a specific vector column. + // Requires 1) a MATCH/compatible distance contraint on + // a single vector column, 2) ORDER BY distance, and 3) + // either a 'LIMIT ?' or 'k=?' contraint + SQLITE_VEC0_QUERYPLAN_KNN, +} vec0_query_plan; + +struct vec0_query_fullscan_data { + sqlite3_stmt *rowids_stmt; + int8_t done; +}; +int vec0_query_fullscan_data_clear( + struct vec0_query_fullscan_data *fullscan_data) { + int rc; + if (fullscan_data->rowids_stmt) { + rc = sqlite3_finalize(fullscan_data->rowids_stmt); + todo_assert(rc == SQLITE_OK); + fullscan_data->rowids_stmt = NULL; + } + return SQLITE_OK; +} + +struct vec0_query_knn_data { + sqlite3_int64 k; + // Array of rowids of size k. Must be freed with sqlite3_freee(). + sqlite3_int64 *rowids; + // Array of distances of size k. Must be freed with sqlite3_freee(). + float *distances; + sqlite3_int64 current_idx; +}; +int vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) { + if (knn_data->rowids) { + sqlite3_free(knn_data->rowids); + knn_data->rowids = NULL; + } + if (knn_data->distances) { + sqlite3_free(knn_data->distances); + knn_data->distances = NULL; + } + return SQLITE_OK; +} + +struct vec0_query_point_data { + sqlite3_int64 rowid; + void *vectors[VEC0_MAX_VECTOR_COLUMNS]; + int done; +}; +void vec0_query_point_data_clear(struct vec0_query_point_data *point_data) { + for (int i = 0; i < VEC0_MAX_VECTOR_COLUMNS; i++) { + sqlite3_free(point_data->vectors[i]); + point_data->vectors[i] = NULL; + } +} + +typedef struct vec0_cursor vec0_cursor; +struct vec0_cursor { + sqlite3_vtab_cursor base; + + vec0_query_plan query_plan; + struct vec0_query_fullscan_data *fullscan_data; + struct vec0_query_knn_data *knn_data; + struct vec0_query_point_data *point_data; +}; + +#define SET_VTAB_ERROR(msg) \ + do { \ + sqlite3_free(pVTab->zErrMsg); \ + pVTab->zErrMsg = sqlite3_mprintf("%s", msg); \ + } while (0) +#define SET_VTAB_CURSOR_ERROR(msg) \ + do { \ + sqlite3_free(pVtabCursor->pVtab->zErrMsg); \ + pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf("%s", msg); \ + } while (0) + +static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, + sqlite3_vtab **ppVtab, char **pzErr, bool isCreate) { + UNUSED_PARAMETER(pAux); + UNUSED_PARAMETER(pzErr); // TODO use! + vec0_vtab *pNew; + int rc; + const char *zSql; + + pNew = sqlite3_malloc(sizeof(*pNew)); + if (pNew == 0) + return SQLITE_NOMEM; + memset(pNew, 0, sizeof(*pNew)); + *ppVtab = (sqlite3_vtab *)pNew; + + int chunk_size = -1; + int numVectorColumns = 0; + + // track if a "primary key" column is defined + char *pkColumnName = NULL; + int pkColumnNameLength; + int pkColumnType; + + for (int i = 3; i < argc; i++) { + todo_assert(numVectorColumns <= VEC0_MAX_VECTOR_COLUMNS); + int rc = parse_vector_column(argv[i], strlen(argv[i]), + &pNew->vector_columns[numVectorColumns]); + if (rc == SQLITE_OK) { + todo_assert(rc == SQLITE_OK); + todo_assert(pNew->vector_columns[numVectorColumns].dimensions > 0); + pNew->vector_columns[numVectorColumns].name = sqlite3_mprintf( + "%.*s", pNew->vector_columns[numVectorColumns].name_length, + pNew->vector_columns[numVectorColumns].name); + assert(pNew->vector_columns[numVectorColumns].name); + numVectorColumns++; + continue; + } + + char *cName = NULL; + int cNameLength; + int cType; + rc = parse_primary_key_definition(argv[i], strlen(argv[i]), &cName, + &cNameLength, &cType); + if (rc == SQLITE_OK) { + todo_assert(!pkColumnName); + pkColumnName = cName; + pkColumnNameLength = cNameLength; + pkColumnType = cType; + continue; + } + char *key; + char *value; + int keyLength, valueLength; + rc = vec0_parse_table_option(argv[i], strlen(argv[i]), &key, &keyLength, + &value, &valueLength); + if (rc == SQLITE_OK) { + if (sqlite3_strnicmp(key, "chunk_size", keyLength) == 0) { + todo_assert(chunk_size < 0); + chunk_size = atoi(value); + if (chunk_size <= 0) { + todo("chunk_size must be positive"); + } + if ((chunk_size % 8) != 0) { + todo("chunk_size must be divisible by 8"); + } + } else { + todo("handle unknown table option"); + } + continue; + } + todo("unparseable constructor"); + } + + if (chunk_size < 0) { + chunk_size = 1024; + } + + todo_assert(numVectorColumns > 0); + todo_assert(numVectorColumns <= VEC0_MAX_VECTOR_COLUMNS); + + sqlite3_str *createStr = sqlite3_str_new(NULL); + sqlite3_str_appendall(createStr, "CREATE TABLE x("); + if (pkColumnName) { + sqlite3_str_appendf(createStr, "\"%.*w\" primary key, ", pkColumnNameLength, + pkColumnName); + } else { + sqlite3_str_appendall(createStr, "rowid, "); + } + for (int i = 0; i < numVectorColumns; i++) { + sqlite3_str_appendf(createStr, "\"%.*w\", ", + pNew->vector_columns[i].name_length, + pNew->vector_columns[i].name); + } + sqlite3_str_appendall(createStr, " distance hidden, k hidden) "); + if (pkColumnName) { + sqlite3_str_appendall(createStr, "without rowid "); + } + zSql = sqlite3_str_finish(createStr); + todo_assert(zSql); + rc = sqlite3_declare_vtab(db, zSql); + sqlite3_free((void *)zSql); + if (rc != SQLITE_OK) { + return rc; + } + + todo_assert(chunk_size > 0); + + const char *schemaName = argv[1]; + const char *tableName = argv[2]; + + pNew->db = db; + pNew->pkIsText = pkColumnType == SQLITE_TEXT; + pNew->schemaName = sqlite3_mprintf("%s", schemaName); + pNew->tableName = sqlite3_mprintf("%s", tableName); + pNew->shadowRowidsName = sqlite3_mprintf("%s_rowids", tableName); + pNew->shadowChunksName = sqlite3_mprintf("%s_chunks", tableName); + pNew->numVectorColumns = numVectorColumns; + for (int i = 0; i < pNew->numVectorColumns; i++) { + pNew->shadowVectorChunksNames[i] = + sqlite3_mprintf("%s_vector_chunks%02d", tableName, i); + } + pNew->chunk_size = chunk_size; + + // if xCreate, then create the necessary shadow tables + if (isCreate) { + sqlite3_stmt *stmt; + int rc; + char *zCreateShadowChunks; + char *zCreateShadowRowids; + + // create the _chunks shadow table + zCreateShadowChunks = sqlite3_mprintf(VEC0_SHADOW_CHUNKS_CREATE, + pNew->schemaName, pNew->tableName); + todo_assert(zCreateShadowChunks); + rc = sqlite3_prepare_v2(db, zCreateShadowChunks, -1, &stmt, 0); + sqlite3_free((void *)zCreateShadowChunks); + todo_assert(rc == SQLITE_OK); + rc = sqlite3_step(stmt); + todo_assert(rc == SQLITE_DONE); + sqlite3_finalize(stmt); + + // create the _rowids shadow table + if (pNew->pkIsText) { + // adds a "text unique not null" constraint to the id column + zCreateShadowRowids = sqlite3_mprintf(VEC0_SHADOW_ROWIDS_CREATE_PK_TEXT, + pNew->schemaName, pNew->tableName); + } else { + zCreateShadowRowids = sqlite3_mprintf(VEC0_SHADOW_ROWIDS_CREATE_BASIC, + pNew->schemaName, pNew->tableName); + } + todo_assert(zCreateShadowRowids); + rc = sqlite3_prepare_v2(db, zCreateShadowRowids, -1, &stmt, 0); + sqlite3_free((void *)zCreateShadowRowids); + todo_assert(rc == SQLITE_OK); + rc = sqlite3_step(stmt); + todo_assert(rc == SQLITE_DONE); + sqlite3_finalize(stmt); + + for (int i = 0; i < pNew->numVectorColumns; i++) { + char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE, + pNew->schemaName, pNew->tableName, i); + todo_assert(zSql); + int rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0); + todo_assert(rc == SQLITE_OK); + rc = sqlite3_step(stmt); + todo_assert(rc == SQLITE_DONE); + sqlite3_finalize(stmt); + sqlite3_free((void *)zSql); + } + + rc = vec0_new_chunk(pNew, NULL); + assert(rc == SQLITE_OK); + } + + // init stmtLatestChunk + { + zSql = sqlite3_mprintf("SELECT max(rowid) FROM " VEC0_SHADOW_CHUNKS_NAME, + pNew->schemaName, pNew->tableName); + todo_assert(zSql); + rc = sqlite3_prepare_v2(pNew->db, zSql, -1, &pNew->stmtLatestChunk, 0); + sqlite3_free((void *)zSql); + todo_assert(rc == SQLITE_OK); + } + + // init stmtRowidsInsertRowid + { + zSql = sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_ROWIDS_NAME "(rowid)" + "VALUES (?);", + pNew->schemaName, pNew->tableName); + todo_assert(zSql); + rc = + sqlite3_prepare_v2(pNew->db, zSql, -1, &pNew->stmtRowidsInsertRowid, 0); + sqlite3_free((void *)zSql); + todo_assert(rc == SQLITE_OK); + } + + // init stmtRowidsInsertId + { + zSql = sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_ROWIDS_NAME "(id)" + "VALUES (?);", + pNew->schemaName, pNew->tableName); + todo_assert(zSql); + rc = sqlite3_prepare_v2(pNew->db, zSql, -1, &pNew->stmtRowidsInsertId, 0); + sqlite3_free((void *)zSql); + todo_assert(rc == SQLITE_OK); + } + + // init stmtRowidsUpdatePosition + { + zSql = sqlite3_mprintf(" UPDATE " VEC0_SHADOW_ROWIDS_NAME + " SET chunk_id = ?, chunk_offset = ?" + " WHERE rowid = ?", + pNew->schemaName, pNew->tableName); + todo_assert(zSql); + rc = sqlite3_prepare_v2(pNew->db, zSql, -1, &pNew->stmtRowidsUpdatePosition, + 0); + sqlite3_free((void *)zSql); + todo_assert(rc == SQLITE_OK); + } + + // init stmtRowidsGetChunkPosition + { + zSql = sqlite3_mprintf("SELECT id, chunk_id, chunk_offset " + "FROM " VEC0_SHADOW_ROWIDS_NAME " WHERE rowid = ?", + pNew->schemaName, pNew->tableName); + todo_assert(zSql); + rc = sqlite3_prepare_v2(pNew->db, zSql, -1, + &pNew->stmtRowidsGetChunkPosition, 0); + sqlite3_free((void *)zSql); + todo_assert(rc == SQLITE_OK); + } + + // init vectorBlobs[..] + for (int i = 0; i < pNew->numVectorColumns; i++) { + // TODO this is assuming there's always a chunk with chunk_id = 1. Is that + // true? + int rc = sqlite3_blob_open(db, pNew->schemaName, + pNew->shadowVectorChunksNames[i], "vectors", 1, + 0, &pNew->vectorBlobs[i]); + todo_assert(rc == SQLITE_OK); + } + + return SQLITE_OK; +} + +static int vec0Create(sqlite3 *db, void *pAux, int argc, + const char *const *argv, sqlite3_vtab **ppVtab, + char **pzErr) { + return vec0_init(db, pAux, argc, argv, ppVtab, pzErr, true); +} +static int vec0Connect(sqlite3 *db, void *pAux, int argc, + const char *const *argv, sqlite3_vtab **ppVtab, + char **pzErr) { + return vec0_init(db, pAux, argc, argv, ppVtab, pzErr, false); +} + +static int vec0Disconnect(sqlite3_vtab *pVtab) { + vec0_vtab *p = (vec0_vtab *)pVtab; + sqlite3_free(p->schemaName); + sqlite3_free(p->tableName); + sqlite3_free(p->shadowChunksName); + sqlite3_free(p->shadowRowidsName); + for (int i = 0; i < p->numVectorColumns; i++) { + sqlite3_free(p->shadowVectorChunksNames[i]); + sqlite3_blob_close(p->vectorBlobs[i]); + } + sqlite3_finalize(p->stmtLatestChunk); + sqlite3_finalize(p->stmtRowidsInsertRowid); + sqlite3_finalize(p->stmtRowidsInsertId); + sqlite3_finalize(p->stmtRowidsUpdatePosition); + sqlite3_finalize(p->stmtRowidsGetChunkPosition); + for (int i = 0; i < p->numVectorColumns; i++) { + sqlite3_free(p->vector_columns[i].name); + p->vector_columns[i].name = NULL; + } + sqlite3_free(p); + return SQLITE_OK; +} +static int vec0Destroy(sqlite3_vtab *pVtab) { + vec0_vtab *p = (vec0_vtab *)pVtab; + sqlite3_stmt *stmt; + const char *zSql = sqlite3_mprintf("TODO", p->schemaName, p->tableName); + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free((void *)zSql); + + if (rc == SQLITE_OK) { + // ignore if there's an error? + sqlite3_step(stmt); + } + + sqlite3_finalize(stmt); + vec0Disconnect(pVtab); + return SQLITE_OK; +} + +static int vec0Open(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { + UNUSED_PARAMETER(p); + vec0_cursor *pCur; + pCur = sqlite3_malloc(sizeof(*pCur)); + if (pCur == 0) + return SQLITE_NOMEM; + memset(pCur, 0, sizeof(*pCur)); + *ppCursor = &pCur->base; + return SQLITE_OK; +} + +static int vec0Close(sqlite3_vtab_cursor *cur) { + int rc; + vec0_cursor *pCur = (vec0_cursor *)cur; + if (pCur->fullscan_data) { + rc = vec0_query_fullscan_data_clear(pCur->fullscan_data); + todo_assert(rc == SQLITE_OK); + sqlite3_free(pCur->fullscan_data); + } + if (pCur->knn_data) { + rc = vec0_query_knn_data_clear(pCur->knn_data); + todo_assert(rc == SQLITE_OK); + sqlite3_free(pCur->knn_data); + } + if (pCur->point_data) { + vec0_query_point_data_clear(pCur->point_data); + sqlite3_free(pCur->point_data); + } + sqlite3_free(pCur); + return SQLITE_OK; +} + +#define VEC0_QUERY_PLAN_FULLSCAN "fullscan" +#define VEC0_QUERY_PLAN_POINT "point" +#define VEC0_QUERY_PLAN_KNN "knn" + +static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { + vec0_vtab *p = (vec0_vtab *)pVTab; + /** + * Possible query plans are: + * 1. KNN when: + * a) An `MATCH` op on vector column + * b) ORDER BY on distance column + * c) LIMIT + * d) rowid in (...) OPTIONAL + * 2. Point when: + * a) An `EQ` op on rowid column + * 3. else: fullscan + * + */ + int iMatchTerm = -1; + int iMatchVectorTerm = -1; + int iLimitTerm = -1; + int iRowidTerm = -1; + int iKTerm = -1; + int iRowidInTerm = -1; + +#ifdef SQLITE_VEC_DEBUG + printf("pIdxInfo->nOrderBy=%d\n", pIdxInfo->nOrderBy); +#endif + + for (int i = 0; i < pIdxInfo->nConstraint; i++) { + uint8_t vtabIn = 0; + // sqlite3_vtab_in() was added in SQLite version 3.38 (2022-02-22) + // ref: https://www.sqlite.org/changes.html#version_3_38_0 + if (sqlite3_libversion_number() >= 3038000) { + vtabIn = sqlite3_vtab_in(pIdxInfo, i, -1); + } +#ifdef SQLITE_VEC_DEBUG + printf("xBestIndex [%d] usable=%d iColumn=%d op=%d vtabin=%d\n", i, + pIdxInfo->aConstraint[i].usable, pIdxInfo->aConstraint[i].iColumn, + pIdxInfo->aConstraint[i].op, vtabIn); +#endif + if (!pIdxInfo->aConstraint[i].usable) + continue; + + int iColumn = pIdxInfo->aConstraint[i].iColumn; + int op = pIdxInfo->aConstraint[i].op; + if (op == SQLITE_INDEX_CONSTRAINT_MATCH && + vec0_column_idx_is_vector(p, iColumn)) { + if (iMatchTerm > -1) { + // TODO only 1 match operator at a time + return SQLITE_ERROR; + } + iMatchTerm = i; + iMatchVectorTerm = vec0_column_idx_to_vector_idx(p, iColumn); + } + if (op == SQLITE_INDEX_CONSTRAINT_LIMIT) { + iLimitTerm = i; + } + if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == VEC0_COLUMN_ID) { + if (vtabIn) { + todo_assert(iRowidInTerm == -1); + iRowidInTerm = i; + + } else { + iRowidTerm = i; + } + } + if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_k_idx(p)) { + iKTerm = i; + } + } + if (iMatchTerm >= 0) { + if (iLimitTerm < 0 && iKTerm < 0) { + // TODO: error, match on vector1 should require a limit for KNN. right? + return SQLITE_ERROR; + } + if (iLimitTerm >= 0 && iKTerm >= 0) { + return SQLITE_ERROR; + } + if (pIdxInfo->nOrderBy < 1) { + // TODO error, `ORDER BY DISTANCE required + SET_VTAB_ERROR("ORDER BY distance required"); + return SQLITE_CONSTRAINT; + } + if (pIdxInfo->nOrderBy > 1) { + // TODO error, orderByConsumed is all or nothing, only 1 order by allowed + SET_VTAB_ERROR("more than 1 ORDER BY clause provided"); + return SQLITE_CONSTRAINT; + } + if (pIdxInfo->aOrderBy[0].iColumn != vec0_column_distance_idx(p)) { + // TODO error, ORDER BY must be on column + SET_VTAB_ERROR("ORDER BY must be on the distance column"); + return SQLITE_CONSTRAINT; + } + if (pIdxInfo->aOrderBy[0].desc) { + // TODO KNN should be ascending, is descending possible? + SET_VTAB_ERROR("Only ascending in ORDER BY distance clause is supported, " + "DESC is not supported yet."); + return SQLITE_CONSTRAINT; + } + + pIdxInfo->orderByConsumed = 1; + pIdxInfo->aConstraintUsage[iMatchTerm].argvIndex = 1; + pIdxInfo->aConstraintUsage[iMatchTerm].omit = 1; + if (iLimitTerm >= 0) { + pIdxInfo->aConstraintUsage[iLimitTerm].argvIndex = 2; + pIdxInfo->aConstraintUsage[iLimitTerm].omit = 1; + } else { + pIdxInfo->aConstraintUsage[iKTerm].argvIndex = 2; + pIdxInfo->aConstraintUsage[iKTerm].omit = 1; + } + + sqlite3_str *idxStr = sqlite3_str_new(NULL); + sqlite3_str_appendall(idxStr, "knn:"); +#define VEC0_IDX_KNN_ROWID_IN 'I' + if (iRowidInTerm >= 0) { + // already validated as >= SQLite 3.38 bc iRowidInTerm is only >= 0 when + // vtabIn == 1 + sqlite3_vtab_in(pIdxInfo, iRowidInTerm, 1); + sqlite3_str_appendchar(idxStr, VEC0_IDX_KNN_ROWID_IN, 1); + pIdxInfo->aConstraintUsage[iRowidInTerm].argvIndex = 3; + pIdxInfo->aConstraintUsage[iRowidInTerm].omit = 1; + } + pIdxInfo->idxNum = iMatchVectorTerm; + pIdxInfo->idxStr = sqlite3_str_finish(idxStr); + if (!pIdxInfo->idxStr) { + return SQLITE_NOMEM; + } + pIdxInfo->needToFreeIdxStr = 1; + pIdxInfo->estimatedCost = 30.0; + pIdxInfo->estimatedRows = 10; + + } else if (iRowidTerm >= 0) { + pIdxInfo->aConstraintUsage[iRowidTerm].argvIndex = 1; + pIdxInfo->aConstraintUsage[iRowidTerm].omit = 1; + pIdxInfo->idxNum = pIdxInfo->colUsed; + pIdxInfo->idxStr = VEC0_QUERY_PLAN_POINT; + pIdxInfo->needToFreeIdxStr = 0; + pIdxInfo->estimatedCost = 10.0; + pIdxInfo->estimatedRows = 1; + } else { + pIdxInfo->idxStr = VEC0_QUERY_PLAN_FULLSCAN; + pIdxInfo->estimatedCost = 3000000.0; + pIdxInfo->estimatedRows = 100000; + } + + return SQLITE_OK; +} + +// forward delcaration bc vec0Filter uses it +static int vec0Next(sqlite3_vtab_cursor *cur); + +void dethrone(int k, float *base_distances, sqlite3_int64 *base_rowids, + size_t chunk_size, int32_t *chunk_top_idx, float *chunk_distances, + sqlite3_int64 *chunk_rowids, + + sqlite3_int64 **out_rowids, float **out_distances) { + *out_rowids = sqlite3_malloc(k * sizeof(sqlite3_int64)); + todo_assert(out_rowids); + *out_distances = sqlite3_malloc(k * sizeof(float)); + todo_assert(out_distances); + + size_t ptrA = 0; + size_t ptrB = 0; + for (int i = 0; i < k; i++) { + if (chunk_distances[chunk_top_idx[ptrA]] < base_distances[ptrB]) { + (*out_rowids)[i] = chunk_rowids[chunk_top_idx[ptrA]]; + (*out_distances)[i] = chunk_distances[chunk_top_idx[ptrA]]; + // TODO if ptrA at chunk_size-1 is always minimum, won't it always repeat? + if (ptrA < (chunk_size - 1)) { + ptrA++; + } + } else { + (*out_rowids)[i] = base_rowids[ptrB]; + (*out_distances)[i] = base_distances[ptrB]; + ptrB++; + } + } +} + +// TODO: Ya this shit is slow + +/** + * @brief Finds the minimum k items in distances, and writes the indicies to + * out. + * + * @param distances input float array of size n, the items to consider. + * @param n: size of distances array. + * @param out: Output array of size k, will contain the minumum k element + * indicies + * @param k: Size of output array + * @return int + */ +int min_idx(const float *distances, int32_t n, int32_t *out, int32_t k) { + todo_assert(k > 0); + todo_assert(k <= n); + + unsigned char *taken = malloc(n * sizeof(unsigned char)); + todo_assert(taken); + memset(taken, 0, n); + + for (int ik = 0; ik < k; ik++) { + int min_idx = 0; + while (min_idx < n && taken[min_idx]) { + min_idx++; + } + todo_assert(min_idx < n); + + for (int i = 0; i < n; i++) { + if (distances[i] < distances[min_idx] && !taken[i]) { + min_idx = i; + } + } + + out[ik] = min_idx; + taken[min_idx] = 1; + } + free(taken); + return SQLITE_OK; +} + +int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, + const char *idxStr, int argc, sqlite3_value **argv) { + UNUSED_PARAMETER(idxNum); + UNUSED_PARAMETER(idxStr); + todo_assert(argc >= 2); + int rc; + pCur->query_plan = SQLITE_VEC0_QUERYPLAN_KNN; + struct vec0_query_knn_data *knn_data = + sqlite3_malloc(sizeof(struct vec0_query_knn_data)); + if (!knn_data) { + return SQLITE_NOMEM; + } + memset(knn_data, 0, sizeof(struct vec0_query_knn_data)); + + int vectorColumnIdx = idxNum; + struct VectorColumnDefinition *vector_column = + &p->vector_columns[vectorColumnIdx]; + + void *queryVector; + size_t dimensions; + enum VectorElementType elementType; + vector_cleanup cleanup; + char * err; + rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, &cleanup, &err); + todo_assert(elementType == vector_column->element_type); + todo_assert(dimensions == vector_column->dimensions); + + + sqlite3_int64 k = sqlite3_value_int64(argv[1]); + todo_assert(k >= 0); + if (k == 0) { + knn_data->k = 0; + pCur->knn_data = knn_data; + return SQLITE_OK; + } + + // handle when a `rowid in (...)` operation was provided + // Array of all the rowids that appear in any `rowid in (...)` constraint. + // NULL if none were provided, which means a "full" scan. + struct Array *arrayRowidsIn = NULL; + if (argc > 2) { + sqlite3_value *item; + int rc; + arrayRowidsIn = sqlite3_malloc(sizeof(struct Array)); + todo_assert(arrayRowidsIn); + rc = array_init(arrayRowidsIn, sizeof(sqlite3_int64), 32); + todo_assert(rc == SQLITE_OK); + for (rc = sqlite3_vtab_in_first(argv[2], &item); rc == SQLITE_OK && item; + rc = sqlite3_vtab_in_next(argv[2], &item)) { + sqlite3_int64 rowid = sqlite3_value_int64(item); + rc = array_append(arrayRowidsIn, &rowid); + todo_assert(rc == SQLITE_OK); + } + todo_assert(rc == SQLITE_DONE); + qsort(arrayRowidsIn->z, arrayRowidsIn->length, arrayRowidsIn->element_size, + _cmp); + } + + sqlite3_int64 *topk_rowids = sqlite3_malloc(k * sizeof(sqlite3_int64)); + todo_assert(topk_rowids); + for (int i = 0; i < k; i++) { + // TODO do we need to ensure that rowid is never -1? + topk_rowids[i] = -1; + } + float *topk_distances = sqlite3_malloc(k * sizeof(float)); + todo_assert(topk_distances); + for (int i = 0; i < k; i++) { + topk_distances[i] = __FLT_MAX__; + } + + // for each chunk, get top min(k, chunk_size) rowid + distances to query vec. + // then reconcile all topk_chunks for a true top k. + // output only rowids + distances for now + + { + sqlite3_blob *blobVectors; + sqlite3_stmt *stmtChunks; + char *zSql; + zSql = sqlite3_mprintf("select chunk_id, validity, rowids " + " from " VEC0_SHADOW_CHUNKS_NAME, + p->schemaName, p->tableName); + todo_assert(zSql); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtChunks, NULL); + sqlite3_free(zSql); + todo_assert(rc == SQLITE_OK); + + void *baseVectors = NULL; + sqlite3_int64 baseVectorsSize = 0; + + while (true) { + rc = sqlite3_step(stmtChunks); + if (rc == SQLITE_DONE) + break; + if (rc != SQLITE_ROW) { + todo("chunks iter error"); + } + sqlite3_int64 chunk_id = sqlite3_column_int64(stmtChunks, 0); + unsigned char *chunkValidity = + (unsigned char *)sqlite3_column_blob(stmtChunks, 1); + sqlite3_int64 validitySize = sqlite3_column_bytes(stmtChunks, 1); + todo_assert(validitySize == p->chunk_size / CHAR_BIT); + sqlite3_int64 *chunkRowids = + (sqlite3_int64 *)sqlite3_column_blob(stmtChunks, 2); + sqlite3_int64 rowidsSize = sqlite3_column_bytes(stmtChunks, 2); + todo_assert(rowidsSize == p->chunk_size * sizeof(sqlite3_int64)); + + // open the vector chunk blob for the current chunk + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowVectorChunksNames[vectorColumnIdx], + "vectors", chunk_id, 0, &blobVectors); + todo_assert(rc == SQLITE_OK); + sqlite3_int64 currentBaseVectorsSize = sqlite3_blob_bytes(blobVectors); + todo_assert((unsigned long)currentBaseVectorsSize == + p->chunk_size * vector_column_byte_size(*vector_column)); + + if (currentBaseVectorsSize > baseVectorsSize) { + if (baseVectors) { + sqlite3_free(baseVectors); + } + baseVectors = sqlite3_malloc(currentBaseVectorsSize); + todo_assert(baseVectors); + baseVectorsSize = currentBaseVectorsSize; + } + rc = sqlite3_blob_read(blobVectors, baseVectors, currentBaseVectorsSize, + 0); + todo_assert(rc == SQLITE_OK); + + // TODO realloc here, like baseVectors + float *chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(float)); + todo_assert(chunk_distances); + + for (int i = 0; i < p->chunk_size; i++) { + + // Ensure the current vector is "valid" in the validity bitmap. + // If not, skip and continue on + if (!(((chunkValidity[i / CHAR_BIT]) >> (i % CHAR_BIT)) & 1)) { + chunk_distances[i] = __FLT_MAX__; + continue; + }; + // If pre-filtering, make sure the rowid appears in the `rowid in (...)` + // list. + if (arrayRowidsIn) { + sqlite3_int64 rowid = chunkRowids[i]; + void *in = bsearch(&rowid, arrayRowidsIn->z, arrayRowidsIn->length, + sizeof(sqlite3_int64), _cmp); + if (!in) { + chunk_distances[i] = __FLT_MAX__; + continue; + } + } + + float result; + switch (vector_column->element_type) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + const float *base_i = + ((float *)baseVectors) + (i * vector_column->dimensions); + switch (vector_column->distance_metric) { + case VEC0_DISTANCE_METRIC_L2: { + result = distance_l2_sqr_float(base_i, (float *)queryVector, + &vector_column->dimensions); + break; + } + case VEC0_DISTANCE_METRIC_COSINE: { + result = distance_cosine_float(base_i, (float *)queryVector, + &vector_column->dimensions); + break; + } + } + + // result = distance_cosine(base_i, (float *) queryVector, & + // vector_column->dimensions); + break; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + const int8_t *base_i = + ((int8_t *)baseVectors) + (i * vector_column->dimensions); + switch (vector_column->distance_metric) { + case VEC0_DISTANCE_METRIC_L2: { + result = distance_l2_sqr_int8(base_i, (int8_t *)queryVector, + &vector_column->dimensions); + + break; + } + case VEC0_DISTANCE_METRIC_COSINE: { + result = distance_cosine_int8(base_i, (int8_t *)queryVector, + &vector_column->dimensions); + break; + } + } + + break; + } + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + const uint8_t *base_i = ((uint8_t *)baseVectors) + + (i * (vector_column->dimensions / CHAR_BIT)); + result = distance_hamming(base_i, (uint8_t *)queryVector, + &vector_column->dimensions); + break; + } + } + + chunk_distances[i] = result; + } + + // now that we have the distances + int32_t *chunk_topk_idxs = sqlite3_malloc(k * sizeof(int32_t)); + todo_assert(chunk_topk_idxs); + min_idx(chunk_distances, p->chunk_size, chunk_topk_idxs, + k <= p->chunk_size ? k : p->chunk_size); + + sqlite3_int64 *out_rowids; + float *out_distances; + dethrone(k, topk_distances, topk_rowids, p->chunk_size, chunk_topk_idxs, + chunk_distances, chunkRowids, + + &out_rowids, &out_distances); + for (int i = 0; i < k; i++) { + topk_rowids[i] = out_rowids[i]; + topk_distances[i] = out_distances[i]; + } + sqlite3_free(out_rowids); + sqlite3_free(out_distances); + sqlite3_free(chunk_distances); + sqlite3_free(chunk_topk_idxs); + + sqlite3_blob_close(blobVectors); + } + + sqlite3_free(baseVectors); + rc = sqlite3_finalize(stmtChunks); + todo_assert(rc == SQLITE_OK); + + if (arrayRowidsIn) { + array_cleanup(arrayRowidsIn); + sqlite3_free(arrayRowidsIn); + } + } + + cleanup(queryVector); + + knn_data->current_idx = 0; + knn_data->k = k; + knn_data->rowids = topk_rowids; + knn_data->distances = topk_distances; + + pCur->knn_data = knn_data; + return SQLITE_OK; +} + +int vec0Filter_fullscan(vec0_cursor *pCur, vec0_vtab *p, int idxNum, + const char *idxStr, int argc, sqlite3_value **argv) { + UNUSED_PARAMETER(idxNum); + UNUSED_PARAMETER(idxStr); + UNUSED_PARAMETER(argc); + UNUSED_PARAMETER(argv); + int rc; + char *zSql; + + pCur->query_plan = SQLITE_VEC0_QUERYPLAN_FULLSCAN; + struct vec0_query_fullscan_data *fullscan_data = + sqlite3_malloc(sizeof(struct vec0_query_fullscan_data)); + if (!fullscan_data) { + return SQLITE_NOMEM; + } + memset(fullscan_data, 0, sizeof(struct vec0_query_fullscan_data)); + zSql = sqlite3_mprintf(" SELECT rowid " + " FROM " VEC0_SHADOW_ROWIDS_NAME + " ORDER by chunk_id, chunk_offset ", + p->schemaName, p->tableName); + todo_assert(zSql); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &fullscan_data->rowids_stmt, NULL); + sqlite3_free(zSql); + todo_assert(rc == SQLITE_OK); + rc = sqlite3_step(fullscan_data->rowids_stmt); + fullscan_data->done = rc == SQLITE_DONE; + if (!(rc == SQLITE_ROW || rc == SQLITE_DONE)) { + vec0_query_fullscan_data_clear(fullscan_data); + return SQLITE_ERROR; + } + pCur->fullscan_data = fullscan_data; + return SQLITE_OK; +} + +int vec0Filter_point(vec0_cursor *pCur, vec0_vtab *p, int idxNum, + const char *idxStr, int argc, sqlite3_value **argv) { + UNUSED_PARAMETER(idxNum); + UNUSED_PARAMETER(idxStr); + int rc; + todo_assert(argc == 1); + sqlite3_int64 rowid = sqlite3_value_int64(argv[0]); + + pCur->query_plan = SQLITE_VEC0_QUERYPLAN_POINT; + struct vec0_query_point_data *point_data = + sqlite3_malloc(sizeof(struct vec0_query_point_data)); + if (!point_data) { + return SQLITE_NOMEM; + } + memset(point_data, 0, sizeof(struct vec0_query_point_data)); + + for (int i = 0; i < p->numVectorColumns; i++) { + rc = vec0_get_vector_data(p, rowid, i, &point_data->vectors[i], NULL); + assert(rc == SQLITE_OK); + } + point_data->rowid = rowid; + point_data->done = 0; + pCur->point_data = point_data; + return SQLITE_OK; +} + +static int vec0Filter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, + const char *idxStr, int argc, sqlite3_value **argv) { + vec0_cursor *pCur = (vec0_cursor *)pVtabCursor; + vec0_vtab *p = (vec0_vtab *)pVtabCursor->pVtab; + if (strcmp(idxStr, VEC0_QUERY_PLAN_FULLSCAN) == 0) { + return vec0Filter_fullscan(pCur, p, idxNum, idxStr, argc, argv); + } else if (strncmp(idxStr, "knn:", 4) == 0) { + return vec0Filter_knn(pCur, p, idxNum, idxStr, argc, argv); + } else if (strcmp(idxStr, VEC0_QUERY_PLAN_POINT) == 0) { + return vec0Filter_point(pCur, p, idxNum, idxStr, argc, argv); + } else { + SET_VTAB_CURSOR_ERROR("unknown idxStr"); + return SQLITE_ERROR; + } +} + +static int vec0Rowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { + UNUSED_PARAMETER(cur); + UNUSED_PARAMETER(pRowid); + vec0_cursor *pCur = (vec0_cursor *)cur; + todo_assert(pCur->query_plan == SQLITE_VEC0_QUERYPLAN_POINT); + todo_assert(pCur->point_data); + *pRowid = pCur->point_data->rowid; + return SQLITE_OK; +} + +static int vec0Next(sqlite3_vtab_cursor *cur) { + vec0_cursor *pCur = (vec0_cursor *)cur; + switch (pCur->query_plan) { + case SQLITE_VEC0_QUERYPLAN_FULLSCAN: { + todo_assert(pCur->fullscan_data); + int rc = sqlite3_step(pCur->fullscan_data->rowids_stmt); + if (rc == SQLITE_DONE) { + pCur->fullscan_data->done = 1; + return SQLITE_OK; + } + if (rc == SQLITE_ROW) { + // TODO error handle + return SQLITE_OK; + } + return SQLITE_ERROR; + } + case SQLITE_VEC0_QUERYPLAN_KNN: { + todo_assert(pCur->knn_data); + pCur->knn_data->current_idx++; + return SQLITE_OK; + } + case SQLITE_VEC0_QUERYPLAN_POINT: { + todo_assert(pCur->point_data); + pCur->point_data->done = 1; + return SQLITE_OK; + } + default: { + todo("point next impl"); + } + } +} + +static int vec0Eof(sqlite3_vtab_cursor *cur) { + vec0_cursor *pCur = (vec0_cursor *)cur; + switch (pCur->query_plan) { + case SQLITE_VEC0_QUERYPLAN_FULLSCAN: { + todo_assert(pCur->fullscan_data); + return pCur->fullscan_data->done; + } + case SQLITE_VEC0_QUERYPLAN_KNN: { + todo_assert(pCur->knn_data); + return (pCur->knn_data->current_idx >= pCur->knn_data->k) || + (pCur->knn_data->distances[pCur->knn_data->current_idx] == + __FLT_MAX__); + } + case SQLITE_VEC0_QUERYPLAN_POINT: { + todo_assert(pCur->point_data); + return pCur->point_data->done; + } + } +} + +static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur, + sqlite3_context *context, int i) { + todo_assert(pCur->fullscan_data); + sqlite3_int64 rowid = + sqlite3_column_int64(pCur->fullscan_data->rowids_stmt, 0); + if (i == VEC0_COLUMN_ID) { + vec0_result_id(pVtab, context, rowid); + } else if (vec0_column_idx_is_vector(pVtab, i)) { + void *v; + int sz; + int vector_idx = vec0_column_idx_to_vector_idx(pVtab, i); + int rc = vec0_get_vector_data(pVtab, rowid, vector_idx, &v, &sz); + todo_assert(rc == SQLITE_OK); + sqlite3_result_blob(context, v, sz, SQLITE_TRANSIENT); + sqlite3_result_subtype(context, + pVtab->vector_columns[vector_idx].element_type); + + sqlite3_free(v); + } else if (i == vec0_column_distance_idx(pVtab)) { + sqlite3_result_null(context); + } else { + sqlite3_result_null(context); + } + return SQLITE_OK; +} + +static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur, + sqlite3_context *context, int i) { + todo_assert(pCur->point_data); + if (i == VEC0_COLUMN_ID) { + vec0_result_id(pVtab, context, pCur->point_data->rowid); + return SQLITE_OK; + } + if (i == vec0_column_distance_idx(pVtab)) { + sqlite3_result_null(context); + return SQLITE_OK; + } + // TODO only have 1st vector data + if (vec0_column_idx_is_vector(pVtab, i)) { + int vector_idx = vec0_column_idx_to_vector_idx(pVtab, i); + sqlite3_result_blob( + context, pCur->point_data->vectors[vector_idx], + vector_column_byte_size(pVtab->vector_columns[vector_idx]), + SQLITE_TRANSIENT); + sqlite3_result_subtype(context, + pVtab->vector_columns[vector_idx].element_type); + return SQLITE_OK; + } + + return SQLITE_OK; +} + +static int vec0Column_knn(vec0_vtab *pVtab, vec0_cursor *pCur, + sqlite3_context *context, int i) { + todo_assert(pCur->knn_data); + if (i == VEC0_COLUMN_ID) { + sqlite3_int64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx]; + vec0_result_id(pVtab, context, rowid); + return SQLITE_OK; + } + if (i == vec0_column_distance_idx(pVtab)) { + sqlite3_result_double( + context, pCur->knn_data->distances[pCur->knn_data->current_idx]); + return SQLITE_OK; + } + if (vec0_column_idx_is_vector(pVtab, i)) { + void *out; + int sz; + int rc = vec0_get_vector_data( + pVtab, pCur->knn_data->rowids[pCur->knn_data->current_idx], + vec0_column_idx_to_vector_idx(pVtab, i), &out, &sz); + todo_assert(rc == SQLITE_OK); + sqlite3_result_blob(context, out, sz, sqlite3_free); + return SQLITE_OK; + } + + return SQLITE_OK; +} + +static int vec0Column(sqlite3_vtab_cursor *cur, sqlite3_context *context, + int i) { + vec0_cursor *pCur = (vec0_cursor *)cur; + vec0_vtab *pVtab = (vec0_vtab *)cur->pVtab; + switch (pCur->query_plan) { + case SQLITE_VEC0_QUERYPLAN_FULLSCAN: { + return vec0Column_fullscan(pVtab, pCur, context, i); + } + case SQLITE_VEC0_QUERYPLAN_KNN: { + return vec0Column_knn(pVtab, pCur, context, i); + } + case SQLITE_VEC0_QUERYPLAN_POINT: { + return vec0Column_point(pVtab, pCur, context, i); + } + } + return SQLITE_OK; +} + +/** + * @brief Handles the "insert rowid" step of a row insert operation of a vec0 + * table. + * + * This function will insert a new row into the _rowids vec0 shadow table. + * + * @param p: virtual table + * @param idValue: Value containing the inserted rowid/id value. + * @param rowid: Output rowid, will point to the "real" sqlite3_int64 rowid + * value that was inserted + * @return int SQLITE_OK on success, error code on failure + */ +int vec0Update_InsertRowidStep(vec0_vtab *p, sqlite3_value *idValue, + sqlite3_int64 *rowid) { + + /** + * An insert into a vec0 table can happen a few different ways: + * 1) With default INTEGER primary key: With a supplied sqlite3_int64 rowid + * 2) With default INTEGER primary key: WITHOUT a supplied rowid + * 3) With TEXT primary key: supplied text rowid + */ + + int rc; + + // Option 3: vtab has a user-defined TEXT primary key, so ensure a text value + // is provided. + if (p->pkIsText) { + todo_assert(sqlite3_value_type(idValue) == SQLITE_TEXT); + +#ifdef SQLITE_VEC_THREADSAFE + sqlite3_mutex_enter(sqlite3_db_mutex(p->db)); +#endif + sqlite3_reset(p->stmtRowidsInsertId); + sqlite3_clear_bindings(p->stmtRowidsInsertId); + sqlite3_bind_value(p->stmtRowidsInsertId, 1, idValue); + rc = sqlite3_step(p->stmtRowidsInsertId); + todo_assert(rc == SQLITE_DONE); + *rowid = sqlite3_last_insert_rowid(p->db); +#ifdef SQLITE_VEC_THREADSAFE + sqlite3_mutex_leave(sqlite3_db_mutex(p->db)); +#endif + + } + // Option 1: User supplied a sqlite3_int64 rowid + else if (sqlite3_value_type(idValue) == SQLITE_INTEGER) { + sqlite3_int64 suppliedRowid = sqlite3_value_int64(idValue); + + sqlite3_reset(p->stmtRowidsInsertRowid); + sqlite3_clear_bindings(p->stmtRowidsInsertRowid); + sqlite3_bind_int64(p->stmtRowidsInsertRowid, 1, suppliedRowid); + rc = sqlite3_step(p->stmtRowidsInsertRowid); + todo_assert(rc == SQLITE_DONE); + *rowid = suppliedRowid; + } + // Option 2: User did not suppled a rowid + else { + todo_assert(sqlite3_value_type(idValue) == SQLITE_NULL); +#ifdef SQLITE_VEC_THREADSAFE + sqlite3_mutex_enter(sqlite3_db_mutex(p->db)); +#endif + sqlite3_reset(p->stmtRowidsInsertId); + sqlite3_clear_bindings(p->stmtRowidsInsertId); + // no need to bind a value to ?1 here: needs to be NULL + // so we can get the next autoincremented rowid value. + rc = sqlite3_step(p->stmtRowidsInsertId); + todo_assert(rc == SQLITE_DONE); + *rowid = sqlite3_last_insert_rowid(p->db); +#ifdef SQLITE_VEC_THREADSAFE + sqlite3_mutex_leave(sqlite3_db_mutex(p->db)); +#endif + } + return SQLITE_OK; +} + +/** + * @brief Determines the "next available" chunk position for a newly inserted + * vec0 row. + * + * This operation may insert a new "blank" chunk the _chunks table, if there is + * no more space in previous chunks. + * + * @param p: virtual table + * @param chunk_rowid: Output rowid of the chunk in the _chunks virtual table + * that has the avialabiity. + * @param chunk_offset: Output the index of the available space insert the + * chunk, based on the index of the first available validity bit. + * @param pBlobValidity: Output blob of the validity column of the available + * chunk. Will be opened with read/write permissions. + * @param pValidity: Output buffer of the original chunk's validity column. + * Needs to be cleaned up with sqlite3_free(). + * @return int SQLITE_OK on success, error code on failure + */ +int vec0Update_InsertNextAvailableStep( + vec0_vtab *p, sqlite3_int64 *chunk_rowid, sqlite3_int64 *chunk_offset, + sqlite3_blob **blobChunksValidity, + const unsigned char **bufferChunksValidity) { + + int rc; + sqlite3_int64 validitySize; + *chunk_offset = -1; + + sqlite3_reset(p->stmtLatestChunk); + rc = sqlite3_step(p->stmtLatestChunk); + todo_assert(rc == SQLITE_ROW); + *chunk_rowid = sqlite3_column_int64(p->stmtLatestChunk, 0); + rc = sqlite3_step(p->stmtLatestChunk); + todo_assert(rc == SQLITE_DONE); + + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "validity", + *chunk_rowid, 1, blobChunksValidity); + todo_assert(rc == SQLITE_OK); + + validitySize = sqlite3_blob_bytes(*blobChunksValidity); + todo_assert(validitySize == p->chunk_size / CHAR_BIT); + + *bufferChunksValidity = sqlite3_malloc(validitySize); + todo_assert(*bufferChunksValidity); + + rc = sqlite3_blob_read(*blobChunksValidity, (void *)*bufferChunksValidity, + validitySize, 0); + todo_assert(rc == SQLITE_OK); + + for (int i = 0; i < validitySize; i++) { + if ((*bufferChunksValidity)[i] == 0b11111111) + continue; + for (int j = 0; j < CHAR_BIT; j++) { + if (((((*bufferChunksValidity)[i] >> j) & 1) == 0)) { + *chunk_offset = (i * CHAR_BIT) + j; + goto done; + } + } + } + +done: + // latest chunk was full, so need to create a new one + if (*chunk_offset == -1) { + int rc = vec0_new_chunk(p, chunk_rowid); + assert(rc == SQLITE_OK); + *chunk_offset = 0; + + // blobChunksValidity and pValidity are stale, pointing to the previous + // (full) chunk. to re-assign them + sqlite3_blob_close(*blobChunksValidity); + sqlite3_free((void *)*bufferChunksValidity); + + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, + "validity", *chunk_rowid, 1, blobChunksValidity); + todo_assert(rc == SQLITE_OK); + validitySize = sqlite3_blob_bytes(*blobChunksValidity); + todo_assert(validitySize == p->chunk_size / CHAR_BIT); + *bufferChunksValidity = sqlite3_malloc(validitySize); + rc = sqlite3_blob_read(*blobChunksValidity, (void *)*bufferChunksValidity, + validitySize, 0); + todo_assert(rc == SQLITE_OK); + } + + return SQLITE_OK; +} + +static int vec0Update_InsertWriteFinalStepVectors( + sqlite3_blob *blobVectors, const void *bVector, sqlite3_int64 chunk_offset, + size_t dimensions, enum VectorElementType element_type) { + int n; + int offset; + + switch (element_type) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + n = dimensions * sizeof(float); + offset = chunk_offset * dimensions * sizeof(float); + break; + case SQLITE_VEC_ELEMENT_TYPE_INT8: + n = dimensions * sizeof(int8_t); + offset = chunk_offset * dimensions * sizeof(int8_t); + break; + case SQLITE_VEC_ELEMENT_TYPE_BIT: + n = dimensions / CHAR_BIT; + offset = chunk_offset * dimensions / CHAR_BIT; + break; + } + + int rc = sqlite3_blob_write(blobVectors, bVector, n, offset); + todo_assert(rc == SQLITE_OK); + return rc; +} + +/** + * @brief + * + * @param p vec0 virtual table + * @param chunk_rowid: which chunk to write to + * @param chunk_offset: the offset inside the chunk to write the vector to. + * @param rowid: the rowid of the inserting row + * @param vectorDatas: array of the vector data to insert + * @param blobValidity: writeable validity blob of the row's assigned chunk. + * @param validity: snapshot buffer of the valdity column from the row's + * assigned chunk. + * @return int SQLITE_OK on success, error code on failure + */ +int vec0Update_InsertWriteFinalStep(vec0_vtab *p, sqlite3_int64 chunk_rowid, + sqlite3_int64 chunk_offset, + sqlite3_int64 rowid, void *vectorDatas[], + sqlite3_blob *blobChunksValidity, + const unsigned char *bufferChunksValidity) { + int rc; + sqlite3_blob *blobChunksRowids; + + // mark the validity bit for this row in the chunk's validity bitmap + // Get the byte offset of the bitmap + char unsigned bx = bufferChunksValidity[chunk_offset / CHAR_BIT]; + // set the bit at the chunk_offset position inside that byte + bx = bx | (1 << (chunk_offset % CHAR_BIT)); + // write that 1 byte + rc = sqlite3_blob_write(blobChunksValidity, &bx, 1, chunk_offset / CHAR_BIT); + todo_assert(rc == SQLITE_OK); + + // Go insert the vector data into the vector chunk shadow tables + for (int i = 0; i < p->numVectorColumns; i++) { + sqlite3_blob *blobVectors; + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i], + "vectors", chunk_rowid, 1, &blobVectors); + todo_assert(rc == SQLITE_OK); + + switch (p->vector_columns[i].element_type) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) == + p->chunk_size * p->vector_columns[i].dimensions * + sizeof(float)); + break; + case SQLITE_VEC_ELEMENT_TYPE_INT8: + todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) == + p->chunk_size * p->vector_columns[i].dimensions * + sizeof(int8_t)); + break; + case SQLITE_VEC_ELEMENT_TYPE_BIT: + todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) == + p->chunk_size * p->vector_columns[i].dimensions / CHAR_BIT); + break; + } + + rc = vec0Update_InsertWriteFinalStepVectors( + blobVectors, vectorDatas[i], chunk_offset, + p->vector_columns[i].dimensions, p->vector_columns[i].element_type); + todo_assert(rc == SQLITE_OK); + sqlite3_blob_close(blobVectors); + } + + // write the new rowid to the rowids column of the _chunks table + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "rowids", + chunk_rowid, 1, &blobChunksRowids); + todo_assert(rc == SQLITE_OK); + todo_assert(sqlite3_blob_bytes(blobChunksRowids) == + p->chunk_size * sizeof(sqlite3_int64)); + rc = sqlite3_blob_write(blobChunksRowids, &rowid, sizeof(sqlite3_int64), + chunk_offset * sizeof(sqlite3_int64)); + todo_assert(rc == SQLITE_OK); + sqlite3_blob_close(blobChunksRowids); + + // Now with all the vectors inserted, go back and update the _rowids table + // with the new chunk_rowid/chunk_offset values + sqlite3_reset(p->stmtRowidsUpdatePosition); + sqlite3_clear_bindings(p->stmtRowidsUpdatePosition); + sqlite3_bind_int64(p->stmtRowidsUpdatePosition, 1, chunk_rowid); + sqlite3_bind_int64(p->stmtRowidsUpdatePosition, 2, chunk_offset); + sqlite3_bind_int64(p->stmtRowidsUpdatePosition, 3, rowid); + rc = sqlite3_step(p->stmtRowidsUpdatePosition); + todo_assert(rc == SQLITE_DONE); + + return SQLITE_OK; +} + +/** + * @brief Handles INSERT INTO operations on a vec0 table. + * + * @return int SQLITE_OK on success, otherwise error code on failure + */ +int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, + sqlite_int64 *pRowid) { + UNUSED_PARAMETER(argc); + vec0_vtab *p = (vec0_vtab *)pVTab; + int rc; + // Rowid for the inserted row, deterimined by the inserted ID + _rowids shadow + // table + sqlite3_int64 rowid; + // Array to hold the vector data of the inserted row. Individual elements will + // have a lifetime bound to the argv[..] values. + void *vectorDatas[VEC0_MAX_VECTOR_COLUMNS]; + + // Rowid of the chunk in the _chunks shadow table that the row will be a part + // of. + sqlite3_int64 chunk_rowid; + // offset within the chunk where the rowid belongs + sqlite3_int64 chunk_offset; + + // a write-able blob of the validity column for the given chunk. Used to mark + // validity bit + sqlite3_blob *blobChunksValidity; + // buffer for the valididty column for the given chunk. TODO maybe not needed + // here? + const unsigned char *bufferChunksValidity; + + vector_cleanup cleanups[VEC0_MAX_VECTOR_COLUMNS]; + // read all the inserted vectors into vectorDatas, validate their lengths. + for (int i = 0; i < p->numVectorColumns; i++) { + sqlite3_value *valueVector = argv[2 + VEC0_COLUMN_VECTORN_START + i]; + size_t dimensions; + + char *pzError; + enum VectorElementType elementType; + int rc = vector_from_value(valueVector, &vectorDatas[i], &dimensions, + &elementType, &cleanups[i], &pzError); + todo_assert(rc == SQLITE_OK); + printf("%d %d\n", elementType, p->vector_columns[i].element_type); + assert(elementType == p->vector_columns[i].element_type); + + if (dimensions != p->vector_columns[i].dimensions) { + sqlite3_free(pVTab->zErrMsg); + pVTab->zErrMsg = sqlite3_mprintf( + "Dimension mismatch for inserted vector for the \"%.*s\" column. " + "Expected %d dimensions but received %d.", + p->vector_columns[i].name_length, p->vector_columns[i].name, + p->vector_columns[i].dimensions, dimensions); + return SQLITE_ERROR; + } + } + + // Cannot insert a value in the hidden "distance" column + if (sqlite3_value_type(argv[2 + vec0_column_distance_idx(p)]) != + SQLITE_NULL) { + SET_VTAB_ERROR("TODO distance provided in INSERT operation."); + return SQLITE_ERROR; + } + // Cannot insert a value in the hidden "k" column + if (sqlite3_value_type(argv[2 + vec0_column_k_idx(p)]) != SQLITE_NULL) { + SET_VTAB_ERROR("TODO k provided in INSERT operation."); + return SQLITE_ERROR; + } + + // Step #1: Insert/get a rowid for this row, from the _rowids table. + rc = vec0Update_InsertRowidStep(p, argv[2 + VEC0_COLUMN_ID], &rowid); + todo_assert(rc == SQLITE_OK); + + // Step #2: Find the next "available" position in the _chunks table for this + // row. + rc = vec0Update_InsertNextAvailableStep(p, &chunk_rowid, &chunk_offset, + &blobChunksValidity, + &bufferChunksValidity); + todo_assert(rc == SQLITE_OK); + + // Step #3: With the next available chunk position, write out all the vectors + // to their specified location. + rc = vec0Update_InsertWriteFinalStep(p, chunk_rowid, chunk_offset, rowid, + vectorDatas, blobChunksValidity, + bufferChunksValidity); + todo_assert(rc == SQLITE_OK); + + for (int i = 0; i < p->numVectorColumns; i++) { + cleanups[i](vectorDatas[i]); + } + + sqlite3_blob_close(blobChunksValidity); + sqlite3_free((void *)bufferChunksValidity); + *pRowid = rowid; + + return SQLITE_OK; +} + +int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite_int64 rowid) { + vec0_vtab *p = (vec0_vtab *)pVTab; + int rc; + sqlite3_int64 chunk_id; + sqlite3_int64 chunk_offset; + sqlite3_blob *blobChunksValidity = NULL; + + // 1. get chunk_id and chunk_offset from _rowids + rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset); + todo_assert(rc == SQLITE_OK); + + // 2. ensure chunks.validity bit is 1, then set to 0 + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "validity", + chunk_id, 1, &blobChunksValidity); + assert(rc == SQLITE_OK); + char unsigned bx; + rc = sqlite3_blob_read(blobChunksValidity, &bx, sizeof(bx), + chunk_offset / CHAR_BIT); + todo_assert(rc == SQLITE_OK); + todo_assert(bx >> (chunk_offset % CHAR_BIT)); + char unsigned mask = ~(1 << (chunk_offset % CHAR_BIT)); + char result = bx & mask; + rc = sqlite3_blob_write(blobChunksValidity, &result, sizeof(bx), + chunk_offset / CHAR_BIT); + todo_assert(rc == SQLITE_OK); + sqlite3_blob_close(blobChunksValidity); + + // 3. zero out rowid in chunks.rowids TODO + + // 4. zero out any data in vector chunks tables TODO + + // 5. delete from _rowids table + char *zSql = + sqlite3_mprintf("DELETE FROM " VEC0_SHADOW_ROWIDS_NAME " WHERE rowid = ?", + p->schemaName, p->tableName); + todo_assert(zSql); + sqlite3_stmt *stmt; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + todo_assert(rc == SQLITE_OK); + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + todo_assert(SQLITE_DONE); + sqlite3_finalize(stmt); + + return SQLITE_OK; +} + +int vec0Update_UpdateOnRowid(sqlite3_vtab *pVTab, int argc, + sqlite3_value **argv) { + UNUSED_PARAMETER(argc); + vec0_vtab *p = (vec0_vtab *)pVTab; + int rc; + sqlite3_int64 chunk_id; + sqlite3_int64 chunk_offset; + sqlite3_int64 rowid = sqlite3_value_int64(argv[0]); + + // 1. get chunk_id and chunk_offset from _rowids + rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset); + todo_assert(rc == SQLITE_OK); + + // 2) iterate over all new vectors, update the vectors + + // read all the inserted vectors into vectorDatas, validate their lengths. + for (int i = 0; i < p->numVectorColumns; i++) { + sqlite3_value *valueVector = argv[2 + VEC0_COLUMN_VECTORN_START + i]; + size_t dimensions; + void *vector = (void *)sqlite3_value_blob(valueVector); + switch (p->vector_columns[i].element_type) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + dimensions = sqlite3_value_bytes(valueVector) / sizeof(float); + break; + case SQLITE_VEC_ELEMENT_TYPE_INT8: + dimensions = sqlite3_value_bytes(valueVector) * sizeof(int8_t); + break; + case SQLITE_VEC_ELEMENT_TYPE_BIT: + dimensions = sqlite3_value_bytes(valueVector) * CHAR_BIT; + break; + } + if (dimensions != p->vector_columns[i].dimensions) { + SET_VTAB_ERROR("TODO vector length dont make sense."); + sqlite3_free(pVTab->zErrMsg); + pVTab->zErrMsg = + sqlite3_mprintf("Vector length mismatch on '%s' column: Expected %d " + "dimensions, found %d", + p->vector_columns[i].name, + p->vector_columns[i].dimensions, dimensions); + return SQLITE_ERROR; + } + + sqlite3_blob *blobVectors; + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i], + "vectors", chunk_id, 1, &blobVectors); + todo_assert(rc == SQLITE_OK); + // TODO rename this functions + rc = vec0Update_InsertWriteFinalStepVectors( + blobVectors, vector, chunk_offset, p->vector_columns[i].dimensions, + p->vector_columns[i].element_type); + todo_assert(rc == SQLITE_OK); + sqlite3_blob_close(blobVectors); + } + + return SQLITE_OK; +} + +static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, + sqlite_int64 *pRowid) { + // DELETE operation + if (argc == 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) { + return vec0Update_Delete(pVTab, sqlite3_value_int64(argv[0])); + } + // INSERT operation + else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) { + return vec0Update_Insert(pVTab, argc, argv, pRowid); + } + // UPDATE operation + else if (argc > 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) { + if ((sqlite3_value_type(argv[0]) == SQLITE_INTEGER) && + (sqlite3_value_type(argv[1]) == SQLITE_INTEGER) && + (sqlite3_value_int64(argv[0]) == sqlite3_value_int64(argv[1]))) { + return vec0Update_UpdateOnRowid(pVTab, argc, argv); + } + + SET_VTAB_ERROR("UPDATE operation on rowids with vec0 is not supported."); + return SQLITE_ERROR; + } + // unknown operation + else { + SET_VTAB_ERROR("Unrecognized xUpdate operation provided for vec0."); + return SQLITE_ERROR; + } +} + +static int vec0ShadowName(const char *zName) { + // TODO multiple vector_chunk tables + static const char *azName[] = {"rowids", "chunks", "vector_chunks"}; + + for (size_t i = 0; i < sizeof(azName) / sizeof(azName[0]); i++) { + if (sqlite3_stricmp(zName, azName[i]) == 0) + return 1; + } + return 0; +} + +static sqlite3_module vec0Module = { + /* iVersion */ 3, + /* xCreate */ vec0Create, + /* xConnect */ vec0Connect, + /* xBestIndex */ vec0BestIndex, + /* xDisconnect */ vec0Disconnect, + /* xDestroy */ vec0Destroy, + /* xOpen */ vec0Open, + /* xClose */ vec0Close, + /* xFilter */ vec0Filter, + /* xNext */ vec0Next, + /* xEof */ vec0Eof, + /* xColumn */ vec0Column, + /* xRowid */ vec0Rowid, + /* xUpdate */ vec0Update, + /* xBegin */ 0, + /* xSync */ 0, + /* xCommit */ 0, + /* xRollback */ 0, + /* xFindFunction */ 0, + /* xRename */ 0, // TODO + /* xSavepoint */ 0, + /* xRelease */ 0, + /* xRollbackTo */ 0, + /* xShadowName */ vec0ShadowName, + /* xIntegrity */ 0, // TODO +}; +#pragma endregion + +int sqlite3_mmap_warm(sqlite3 *db, const char *zDb) { + int rc = SQLITE_OK; + char *zSql = 0; + int pgsz = 0; + unsigned int nTotal = 0; + + if (0 == sqlite3_get_autocommit(db)) + return SQLITE_MISUSE; + + /* Open a read-only transaction on the file in question */ + zSql = sqlite3_mprintf("BEGIN; SELECT * FROM %s%q%ssqlite_schema", + (zDb ? "'" : ""), (zDb ? zDb : ""), (zDb ? "'." : "")); + if (zSql == 0) + return SQLITE_NOMEM; + rc = sqlite3_exec(db, zSql, 0, 0, 0); + sqlite3_free(zSql); + + /* Find the SQLite page size of the file */ + if (rc == SQLITE_OK) { + zSql = sqlite3_mprintf("PRAGMA %s%q%spage_size", (zDb ? "'" : ""), + (zDb ? zDb : ""), (zDb ? "'." : "")); + if (zSql == 0) { + rc = SQLITE_NOMEM; + } else { + sqlite3_stmt *pPgsz = 0; + rc = sqlite3_prepare_v2(db, zSql, -1, &pPgsz, 0); + sqlite3_free(zSql); + if (rc == SQLITE_OK) { + if (sqlite3_step(pPgsz) == SQLITE_ROW) { + pgsz = sqlite3_column_int(pPgsz, 0); + } + rc = sqlite3_finalize(pPgsz); + } + if (rc == SQLITE_OK && pgsz == 0) { + rc = SQLITE_ERROR; + } + } + } + + /* Touch each mmap'd page of the file */ + if (rc == SQLITE_OK) { + int rc2; + sqlite3_file *pFd = 0; + rc = sqlite3_file_control(db, zDb, SQLITE_FCNTL_FILE_POINTER, &pFd); + if (rc == SQLITE_OK && pFd->pMethods && pFd->pMethods->iVersion >= 3) { + sqlite3_int64 iPg = 1; + sqlite3_io_methods const *p = pFd->pMethods; + while (1) { + unsigned char *pMap; + rc = p->xFetch(pFd, pgsz * iPg, pgsz, (void **)&pMap); + if (rc != SQLITE_OK || pMap == 0) + break; + + nTotal += (unsigned int)pMap[0]; + nTotal += (unsigned int)pMap[pgsz - 1]; + + rc = p->xUnfetch(pFd, pgsz * iPg, (void *)pMap); + if (rc != SQLITE_OK) + break; + iPg++; + } + sqlite3_log(SQLITE_OK, + "sqlite3_mmap_warm_cache: Warmed up %d pages of %s", + iPg == 1 ? 0 : iPg, sqlite3_db_filename(db, zDb)); + } + + rc2 = sqlite3_exec(db, "END", 0, 0, 0); + if (rc == SQLITE_OK) + rc = rc2; + } + + (void)nTotal; + return rc; +} + +#ifdef _WIN32 +__declspec(dllexport) +#endif + int sqlite3_vec_warm_mmap(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi) { + UNUSED_PARAMETER(pzErrMsg); + SQLITE_EXTENSION_INIT2(pApi); + return sqlite3_mmap_warm(db, NULL); +} + +#ifdef SQLITE_VEC_ENABLE_AVX +#define SQLITE_VEC_DEBUG_BUILD_AVX "avx" +#else +#define SQLITE_VEC_DEBUG_BUILD_AVX "" +#endif +#ifdef SQLITE_VEC_ENABLE_NEON +#define SQLITE_VEC_DEBUG_BUILD_NEON "neon" +#else +#define SQLITE_VEC_DEBUG_BUILD_NEON "" +#endif + +#define SQLITE_VEC_DEBUG_BUILD \ + SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON + +#define SQLITE_VEC_DEBUG_STRING \ + "Version: " SQLITE_VEC_VERSION "\n" \ + "Date: " SQLITE_VEC_DATE "\n" \ + "Commit: " SQLITE_VEC_SOURCE "\n" \ + "Build flags: " SQLITE_VEC_DEBUG_BUILD + +#ifndef SQLITE_SUBTYPE +#define SQLITE_SUBTYPE 0x000100000 +#endif + +#ifndef SQLITE_RESULT_SUBTYPE +#define SQLITE_RESULT_SUBTYPE 0x001000000 +#endif + +#ifdef _WIN32 +__declspec(dllexport) +#endif + int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi) { + SQLITE_EXTENSION_INIT2(pApi); + int rc = SQLITE_OK; + const int DEFAULT_FLAGS = + SQLITE_UTF8 | SQLITE_INNOCUOUS | SQLITE_DETERMINISTIC; + + static const struct { + char *zFName; + void (*xFunc)(sqlite3_context *, int, sqlite3_value **); + int nArg; + int flags; + void *p; + } aFunc[] = { + // clang-format off + {"vec_version", _static_text_func, 0, DEFAULT_FLAGS, SQLITE_VEC_VERSION }, + {"vec_debug", _static_text_func, 0, DEFAULT_FLAGS, SQLITE_VEC_DEBUG_STRING }, + {"vec_distance_l2", vec_distance_l2, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, NULL }, + {"vec_distance_hamming",vec_distance_hamming, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, NULL }, + {"vec_distance_cosine", vec_distance_cosine, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, NULL }, + {"vec_length", vec_length, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE, NULL }, + {"vec_to_json", vec_to_json, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_add", vec_add, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_sub", vec_sub, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_slice", vec_slice, 3, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_normalize", vec_normalize, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_f32", vec_f32, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_bit", vec_bit, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_int8", vec_int8, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_quantize_i8", vec_quantize_i8, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_quantize_i8", vec_quantize_i8, 3, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_quantize_binary", vec_quantize_binary, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + // clang-format on + }; + + static const struct { + char *name; + const sqlite3_module *module; + } aMod[] = { + // clang-format off + {"vec0", &vec0Module}, + {"vec_each", &vec_eachModule}, + {"vec_npy_each", &vec_npy_eachModule}, + // clang-format on + }; + + for (unsigned long i = 0; + i < sizeof(aFunc) / sizeof(aFunc[0]) && rc == SQLITE_OK; i++) { + rc = sqlite3_create_function_v2(db, aFunc[i].zFName, aFunc[i].nArg, + aFunc[i].flags, aFunc[i].p, aFunc[i].xFunc, + NULL, NULL, NULL); + if (rc != SQLITE_OK) { + *pzErrMsg = sqlite3_mprintf("Error creating function %s: %s", + aFunc[i].zFName, sqlite3_errmsg(db)); + return rc; + } + } + + for (unsigned long i = 0; i < countof(aMod) && rc == SQLITE_OK; i++) { + rc = sqlite3_create_module_v2(db, aMod[i].name, aMod[i].module, NULL, NULL); + if (rc != SQLITE_OK) { + *pzErrMsg = sqlite3_mprintf("Error creating module %s: %s", aMod[i].name, + sqlite3_errmsg(db)); + return rc; + } + } + + return SQLITE_OK; +} + +#ifdef _WIN32 +__declspec(dllexport) +#endif + int sqlite3_vec_fs_read_init(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi) { + UNUSED_PARAMETER(pzErrMsg); + SQLITE_EXTENSION_INIT2(pApi); + int rc = SQLITE_OK; + rc = sqlite3_create_function_v2(db, "vec_npy_file", 1, SQLITE_RESULT_SUBTYPE, + NULL, vec_npy_file, NULL, NULL, NULL); + return rc; +} + +#ifdef SQLITE_VEC_ENABLE_TRACE_ENTRYPOINT + +int trace(unsigned int x, void *p1, void *p2, void *p3) { + if (x == SQLITE_TRACE_STMT) { + sqlite3_stmt *stmt = (sqlite3_stmt *)p2; + char *zSql = sqlite3_expanded_sql(stmt); + printf("%s\n", zSql); + } +} +#ifdef _WIN32 +__declspec(dllexport) +#endif + int trace_debug(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi) { + UNUSED_PARAMETER(pzErrMsg); + SQLITE_EXTENSION_INIT2(pApi); + sqlite3_trace_v2(db, SQLITE_TRACE_STMT, trace, NULL); + return SQLITE_OK; +} +#endif diff --git a/sqlite-vec.h.tmpl b/sqlite-vec.h.tmpl new file mode 100644 index 0000000..a644326 --- /dev/null +++ b/sqlite-vec.h.tmpl @@ -0,0 +1,11 @@ +#include "sqlite3ext.h" + +#define SQLITE_VEC_VERSION "v${VERSION}" +#define SQLITE_VEC_DATE "${DATE}" +#define SQLITE_VEC_SOURCE "${SOURCE}" + + +int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi); +int sqlite3_vec_fs_read_init(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi); diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000..2f7896d --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +target/ diff --git a/tests/Cargo.lock b/tests/Cargo.lock new file mode 100644 index 0000000..cd9b518 --- /dev/null +++ b/tests/Cargo.lock @@ -0,0 +1,16 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "cc" +version = "1.0.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" + +[[package]] +name = "tests" +version = "0.1.0" +dependencies = [ + "cc", +] diff --git a/tests/Cargo.toml b/tests/Cargo.toml new file mode 100644 index 0000000..32b675d --- /dev/null +++ b/tests/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "tests" +version = "0.1.0" +edition = "2021" + +[dependencies] + +[build-dependencies] +cc = "1.0" + +[[bin]] +name = "unittest" +path = "unittest.rs" + + diff --git a/tests/build.rs b/tests/build.rs new file mode 100644 index 0000000..842cf2e --- /dev/null +++ b/tests/build.rs @@ -0,0 +1,13 @@ +use std::env; +use std::path::{Path, PathBuf}; +use std::process::Command; + +fn main() { + cc::Build::new() + .file("../sqlite-vec.c") + .include(".") + .static_flag(true) + .compile("sqlite-vec-internal"); + println!("cargo:rerun-if-changed=usleep.c"); + println!("cargo:rerun-if-changed=build.rs"); +} diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h new file mode 100644 index 0000000..d81ab7a --- /dev/null +++ b/tests/sqlite-vec-internal.h @@ -0,0 +1,12 @@ +#include + +int min_idx( + // list of distances, size n + const float *distances, + // number of entries in distances + int32_t n, + // output array of size k, the indicies of the lowest k values in distances + int32_t *out, + // output number of elements + int32_t k +); diff --git a/tests/test-correctness.py b/tests/test-correctness.py new file mode 100644 index 0000000..045ca4d --- /dev/null +++ b/tests/test-correctness.py @@ -0,0 +1,49 @@ +import sqlite3 +import json + +db = sqlite3.connect("test2.db") +db.enable_load_extension(True) +db.load_extension("dist/vec0") +db.enable_load_extension(False) +db.row_factory = sqlite3.Row +db.execute('attach database "sift1m-base.db" as sift1m') + + +#def test_sift1m(): +rows = db.execute( + ''' + with q as ( + select rowid, vector, k100 from sift1m.sift1m_query limit 10 + ), + results as ( + select + q.rowid as query_rowid, + vec_sift1m.rowid as vec_rowid, + distance, + k100 as k100_groundtruth + from q + join vec_sift1m + where + vec_sift1m.vector match q.vector + and k = 100 + order by distance + ) + select + query_rowid, + json_group_array(vec_rowid order by distance) as topk, + k100_groundtruth, + json_group_array(vec_rowid order by distance) == k100_groundtruth + from results + group by 1; + ''').fetchall() + +results = [] +for row in rows: + actual = json.loads(row["topk"]) + expected = json.loads(row["k100_groundtruth"]) + + ncorrect = sum([x in expected for x in actual]) + results.append(ncorrect / 100.0) + +from statistics import mean +print(mean(results)) diff --git a/tests/test-loadable.py b/tests/test-loadable.py new file mode 100644 index 0000000..7b08911 --- /dev/null +++ b/tests/test-loadable.py @@ -0,0 +1,874 @@ +# ruff: noqa: E731 + +import re +from typing import List +import sqlite3 +import unittest +from random import random +import struct +import inspect +import pytest +import json +import numpy as np +from math import isclose + +EXT_PATH = "./dist/vec0" + + +def bitmap_full(n: int) -> bytearray: + assert (n % 8) == 0 + return bytes([0xFF] * int(n / 8)) + + +def bitmap_zerod(n: int) -> bytearray: + assert (n % 8) == 0 + return bytes([0x00] * int(n / 8)) + + +def f32_zerod(n: int) -> bytearray: + return bytes([0x00, 0x00, 0x00, 0x00] * int(n)) + + +CHAR_BIT = 8 + + +def _f32(list): + return struct.pack("%sf" % len(list), *list) + + +def _int8(list): + return struct.pack("%sb" % len(list), *list) + + +def connect(ext, path=":memory:"): + db = sqlite3.connect(path) + + db.execute( + "create temp table base_functions as select name from pragma_function_list" + ) + db.execute("create temp table base_modules as select name from pragma_module_list") + + db.enable_load_extension(True) + db.load_extension(ext) + + db.execute( + "create temp table loaded_functions as select name from pragma_function_list where name not in (select name from base_functions) order by name" + ) + db.execute( + "create temp table loaded_modules as select name from pragma_module_list where name not in (select name from base_modules) order by name" + ) + + db.row_factory = sqlite3.Row + return db + + +db = connect(EXT_PATH) + +# db.load_extension(EXT_PATH, entrypoint="trace_debug") + + +def explain_query_plan(sql): + return db.execute("explain query plan " + sql).fetchone()["detail"] + + +def execute_all(cursor, sql, args=None): + if args is None: + args = [] + results = cursor.execute(sql, args).fetchall() + return list(map(lambda x: dict(x), results)) + + +def spread_args(args): + return ",".join(["?"] * len(args)) + + +FUNCTIONS = [ + "vec_add", + "vec_bit", + "vec_debug", + "vec_distance_cosine", + "vec_distance_hamming", + "vec_distance_l2", + "vec_f32", + "vec_int8", + "vec_length", + "vec_normalize", + "vec_quantize_binary", + "vec_quantize_i8", + "vec_quantize_i8", + "vec_slice", + "vec_sub", + "vec_to_json", + "vec_version", +] +MODULES = ["vec0", "vec_each", "vec_npy_each"] + + +def test_funcs(): + funcs = list( + map( + lambda a: a[0], + db.execute("select name from loaded_functions").fetchall(), + ) + ) + assert funcs == FUNCTIONS + + +def test_modules(): + modules = list( + map(lambda a: a[0], db.execute("select name from loaded_modules").fetchall()) + ) + assert modules == MODULES + + +def test_vec_version(): + vec_version = lambda *args: db.execute("select vec_version()", args).fetchone()[0] + assert vec_version()[0] == "v" + + +def test_vec_debug(): + vec_debug = lambda *args: db.execute("select vec_debug()", args).fetchone()[0] + d = vec_debug().split("\n") + assert len(d) == 4 + + +def test_vec_bit(): + vec_bit = lambda *args: db.execute("select vec_bit(?)", args).fetchone()[0] + assert vec_bit(b"\xff") == b"\xff" + + assert db.execute("select subtype(vec_bit(X'FF'))").fetchone()[0] == 224 + + with pytest.raises( + sqlite3.OperationalError, match="zero-length vectors are not supported." + ): + db.execute("select vec_bit(X'')").fetchone() + + for x in [None, "text", 1, 1.999]: + with pytest.raises( + sqlite3.OperationalError, match="Unknown type for bitvector." + ): + db.execute("select vec_bit(?)", [x]).fetchone() + + +def test_vec_f32(): + vec_f32 = lambda *args: db.execute("select vec_f32(?)", args).fetchone()[0] + assert vec_f32(b"\x00\x00\x00\x00") == b"\x00\x00\x00\x00" + assert vec_f32("[0.0000]") == b"\x00\x00\x00\x00" + # fmt: off + tests = [ + [0], + [0, 0, 0, 0], + [1, -1, 10, -10], + [-0, 0, .0001, -.0001], + ] + # fmt: on + for test in tests: + assert vec_f32(json.dumps(test)) == _f32(test) + + assert db.execute("select subtype(vec_f32(X'00000000'))").fetchone()[0] == 223 + + with pytest.raises( + sqlite3.OperationalError, match="zero-length vectors are not supported." + ): + vec_f32(b"") + + for invalid in [None, 1, 1.2]: + with pytest.raises( + sqlite3.OperationalError, + match=re.escape( + "Input must have type BLOB (compact format) or TEXT (JSON)", + ), + ): + vec_f32(invalid) + + with pytest.raises( + sqlite3.OperationalError, + match="invalid float32 vector BLOB length. Must be divisible by 4, found 5", + ): + vec_f32(b"aaaaa") + with pytest.raises( + sqlite3.OperationalError, + match=re.escape("JSON array parsing error: Input does not start with '['"), + ): + vec_f32("1]") + # TODO mas tests + + # TODO different error message + with pytest.raises( + sqlite3.OperationalError, + match="zero-length vectors are not supported.", + ): + vec_f32("[") + + # vec_f32("[]") + + +def test_vec_int8(): + vec_int8 = lambda *args: db.execute("select vec_int8(?)", args).fetchone()[0] + assert vec_int8(b"\x00") == _int8([0]) + assert vec_int8(b"\x00\x0f") == _int8([0, 15]) + assert db.execute("select subtype(vec_int8(?))", [b"\x00"]).fetchone()[0] == 225 + + +def npy_cosine(a, b): + return 1 - (np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))) + + +def npy_l2(a, b): + return np.linalg.norm(a - b) + + +def test_vec_distance_cosine(): + vec_distance_cosine = lambda *args, a="?", b="?": db.execute( + f"select vec_distance_cosine({a}, {b})", args + ).fetchone()[0] + + def check(a, b, dtype=np.float32): + if dtype == np.float32: + transform = "?" + elif dtype == np.int8: + transform = "vec_int8(?)" + a = np.array(a, dtype=dtype) + b = np.array(b, dtype=dtype) + + x = vec_distance_cosine(a, b, a=transform, b=transform) + y = npy_cosine(a, b) + assert isclose(x, y, abs_tol=1e-6) + + check([1.2, 0.1], [0.4, -0.4]) + check([-1.2, -0.1], [-0.4, 0.4]) + check([1, 2, 3], [-9, -8, -7], dtype=np.int8) + assert vec_distance_cosine("[1.1, 1.0]", "[1.2, 1.2]") == 0.001131898257881403 + + +def test_vec_distance_hamming(): + vec_distance_hamming = lambda *args: db.execute( + "select vec_distance_hamming(vec_bit(?), vec_bit(?))", args + ).fetchone()[0] + assert vec_distance_hamming(b"\xff", b"\x00") == 8 + assert vec_distance_hamming(b"\xff", b"\x01") == 7 + assert vec_distance_hamming(b"\xab", b"\xab") == 0 + + with pytest.raises( + sqlite3.OperationalError, + match="Cannot calculate hamming distance between two float32 vectors.", + ): + db.execute("select vec_distance_hamming(vec_f32('[1.0]'), vec_f32('[1.0]'))") + + with pytest.raises( + sqlite3.OperationalError, + match="Cannot calculate hamming distance between two int8 vectors.", + ): + db.execute("select vec_distance_hamming(vec_int8(X'FF'), vec_int8(X'FF'))") + + +def test_vec_distance_l2(): + vec_distance_l2 = lambda *args, a="?", b="?": db.execute( + f"select vec_distance_l2({a}, {b})", args + ).fetchone()[0] + + def check(a, b, dtype=np.float32): + if dtype == np.float32: + transform = "?" + elif dtype == np.int8: + transform = "vec_int8(?)" + a = np.array(a, dtype=dtype) + b = np.array(b, dtype=dtype) + + x = vec_distance_l2(a, b, a=transform, b=transform) + y = npy_l2(a, b) + assert isclose(x, y, abs_tol=1e-6) + + check([1.2, 0.1], [0.4, -0.4]) + check([-1.2, -0.1], [-0.4, 0.4]) + check([1, 2, 3], [-9, -8, -7], dtype=np.int8) + + +def test_vec_length(): + def test_f32(): + vec_length = lambda *args: db.execute("select vec_length(?)", args).fetchone()[ + 0 + ] + assert vec_length(b"\xAA\xBB\xCC\xDD") == 1 + assert vec_length(b"\xAA\xBB\xCC\xDD\x01\x02\x03\x04") == 2 + assert vec_length(f32_zerod(1024)) == 1024 + + with pytest.raises( + sqlite3.OperationalError, match="zero-length vectors are not supported." + ): + assert vec_length(b"") == 0 + with pytest.raises( + sqlite3.OperationalError, match="zero-length vectors are not supported." + ): + vec_length("[]") + + def test_int8(): + vec_length_int8 = lambda *args: db.execute( + "select vec_length(vec_int8(?))", args + ).fetchone()[0] + assert vec_length_int8(b"\xAA") == 1 + assert vec_length_int8(b"\xAA\xBB\xCC\xDD") == 4 + assert vec_length_int8(b"\xAA\xBB\xCC\xDD\x01\x02\x03\x04") == 8 + + with pytest.raises( + sqlite3.OperationalError, match="zero-length vectors are not supported." + ): + assert vec_length_int8(b"") == 0 + + def test_bit(): + vec_length_bit = lambda *args: db.execute( + "select vec_length(vec_bit(?))", args + ).fetchone()[0] + assert vec_length_bit(b"\xAA") == 8 + assert vec_length_bit(b"\xAA\xBB\xCC\xDD") == 8 * 4 + assert vec_length_bit(b"\xAA\xBB\xCC\xDD\x01\x02\x03\x04") == 8 * 8 + + with pytest.raises( + sqlite3.OperationalError, match="zero-length vectors are not supported." + ): + assert vec_length_bit(b"") == 0 + + test_f32() + test_int8() + test_bit() + + +def test_vec_normalize(): + vec_normalize = lambda *args: db.execute( + "select vec_normalize(?)", args + ).fetchone()[0] + assert list(struct.unpack_from("4f", vec_normalize(_f32([1, 2, -1, -2])))) == [ + 0.3162277638912201, + 0.6324555277824402, + -0.3162277638912201, + -0.6324555277824402, + ] + + +def test_vec_slice(): + vec_slice = lambda *args, f="?": db.execute( + f"select vec_slice({f}, ?, ?)", args + ).fetchone()[0] + assert vec_slice(_f32([1.1, 2.2, 3.3]), 0, 3) == _f32([1.1, 2.2, 3.3]) + assert vec_slice(_f32([1.1, 2.2, 3.3]), 0, 2) == _f32([1.1, 2.2]) + assert vec_slice(_f32([1.1, 2.2, 3.3]), 0, 1) == _f32([1.1]) + assert vec_slice(_int8([1, 2, 3]), 0, 3, f="vec_int8(?)") == _int8([1, 2, 3]) + assert vec_slice(_int8([1, 2, 3]), 0, 2, f="vec_int8(?)") == _int8([1, 2]) + assert vec_slice(_int8([1, 2, 3]), 0, 1, f="vec_int8(?)") == _int8([1]) + assert vec_slice(b"\xAA\xBB\xCC\xDD", 0, 8, f="vec_bit(?)") == b"\xAA" + assert vec_slice(b"\xAA\xBB\xCC\xDD", 8, 16, f="vec_bit(?)") == b"\xBB" + assert vec_slice(b"\xAA\xBB\xCC\xDD", 8, 24, f="vec_bit(?)") == b"\xBB\xCC" + assert vec_slice(b"\xAA\xBB\xCC\xDD", 0, 32, f="vec_bit(?)") == b"\xAA\xBB\xCC\xDD" + + with pytest.raises( + sqlite3.OperationalError, match="start index must be divisible by 8." + ): + vec_slice(b"\xAA\xBB\xCC\xDD", 2, 32, f="vec_bit(?)") + + with pytest.raises( + sqlite3.OperationalError, match="end index must be divisible by 8." + ): + vec_slice(b"\xAA\xBB\xCC\xDD", 0, 31, f="vec_bit(?)") + + with pytest.raises( + sqlite3.OperationalError, match="slice 'start' index must be a postive number." + ): + vec_slice(b"\xab\xab\xab\xab", -1, 1) + + with pytest.raises( + sqlite3.OperationalError, match="slice 'end' index must be a postive number." + ): + vec_slice(b"\xab\xab\xab\xab", 0, -3) + with pytest.raises( + sqlite3.OperationalError, + match="slice 'start' index is greater than the number of dimensions", + ): + vec_slice(b"\xab\xab\xab\xab", 2, 3) + with pytest.raises( + sqlite3.OperationalError, + match="slice 'end' index is greater than the number of dimensions", + ): + vec_slice(b"\xab\xab\xab\xab", 0, 2) + with pytest.raises( + sqlite3.OperationalError, + match="slice 'start' index is greater than 'end' index", + ): + vec_slice(b"\xab\xab\xab\xab", 1, 0) + + +def test_vec_add(): + vec_add = lambda *args, a="?", b="?": db.execute( + f"select vec_add({a}, {b})", args + ).fetchone()[0] + assert vec_add("[1]", "[2]") == _f32([3]) + assert vec_add("[.1]", "[.2]") == _f32([0.3]) + assert vec_add(_int8([1]), _int8([2]), a="vec_int8(?)", b="vec_int8(?)") == _int8( + [3] + ) + + with pytest.raises( + sqlite3.OperationalError, + match="Cannot add two bitvectors together.", + ): + vec_add(b"0xff", b"0xff", a="vec_bit(?)", b="vec_bit(?)") + + with pytest.raises( + sqlite3.OperationalError, + match="Vector type mistmatch. First vector has type float32, while the second has type int8.", + ): + vec_add(_f32([1]), _int8([2]), b="vec_int8(?)") + with pytest.raises( + sqlite3.OperationalError, + match="Vector type mistmatch. First vector has type int8, while the second has type float32.", + ): + vec_add(_int8([2]), _f32([1]), a="vec_int8(?)") + + +def test_vec_sub(): + vec_sub = lambda *args, a="?", b="?": db.execute( + f"select vec_sub({a}, {b})", args + ).fetchone()[0] + assert vec_sub("[1]", "[2]") == _f32([-1]) + assert vec_sub("[.1]", "[.2]") == _f32([-0.1]) + assert vec_sub(_int8([11]), _int8([2]), a="vec_int8(?)", b="vec_int8(?)") == _int8( + [9] + ) + + with pytest.raises( + sqlite3.OperationalError, + match="Cannot subtract two bitvectors together.", + ): + vec_sub(b"0xff", b"0xff", a="vec_bit(?)", b="vec_bit(?)") + + with pytest.raises( + sqlite3.OperationalError, + match="Vector type mistmatch. First vector has type float32, while the second has type int8.", + ): + vec_sub(_f32([1]), _int8([2]), b="vec_int8(?)") + with pytest.raises( + sqlite3.OperationalError, + match="Vector type mistmatch. First vector has type int8, while the second has type float32.", + ): + vec_sub(_int8([2]), _f32([1]), a="vec_int8(?)") + + +def test_vec_to_json(): + vec_to_json = lambda *args, input="?": db.execute( + f"select vec_to_json({input})", args + ).fetchone()[0] + assert vec_to_json("[1, 2, 3]") == "[1.000000,2.000000,3.000000]" + assert vec_to_json(b"\x00\x00\x00\x00\x00\x00\x80\xbf") == "[0.000000,-1.000000]" + assert vec_to_json(b"\x04", input="vec_int8(?)") == "[4]" + assert vec_to_json(b"\x04\xff", input="vec_int8(?)") == "[4,-1]" + assert vec_to_json(b"\xff", input="vec_bit(?)") == "[1,1,1,1,1,1,1,1]" + assert vec_to_json(b"\x0f", input="vec_bit(?)") == "[1,1,1,1,0,0,0,0]" + + +@pytest.mark.skip(reason="TODO") +def test_vec_quantize_i8(): + vec_quantize_i8 = lambda *args: db.execute( + "select vec_quantize_i8()", args + ).fetchone()[0] + assert vec_quantize_i8() == 111 + + +@pytest.mark.skip(reason="TODO") +def test_vec_quantize_binary(): + vec_quantize_binary = lambda *args: db.execute( + "select vec_quantize_binary()", args + ).fetchone()[0] + assert vec_quantize_binary() == 111 + + +@pytest.mark.skip(reason="TODO") +def test_vec0(): + pass + + +def test_vec0_updates(): + db = connect(EXT_PATH) + db.execute( + """ + create virtual table t using vec0( + aaa float[128], + bbb int8[128], + ccc bit[128] + ); + """ + ) + + db.execute( + "insert into t values (?, ?, vec_int8(?), vec_bit(?))", + [ + 1, + np.full((128,), 0.0001, dtype="float32"), + np.full((128,), 4, dtype="int8"), + bitmap_full(128), + ], + ) + + assert execute_all(db, "select * from t") == [ + { + "rowid": 1, + "aaa": _f32([0.0001] * 128), + "bbb": _int8([4] * 128), + "ccc": bitmap_full(128), + } + ] + db.execute( + "update t set aaa = ? where rowid = ?", + [np.full((128,), 0.00011, dtype="float32"), 1], + ) + assert execute_all(db, "select * from t") == [ + { + "rowid": 1, + "aaa": _f32([0.00011] * 128), + "bbb": _int8([4] * 128), + "ccc": bitmap_full(128), + } + ] + + +def test_vec_each(): + vec_each_f32 = lambda *args: execute_all( + db, "select rowid, * from vec_each(vec_f32(?))", args + ) + assert vec_each_f32(_f32([1.0, 2.0, 3.0])) == [ + {"rowid": 0, "value": 1.0}, + {"rowid": 1, "value": 2.0}, + {"rowid": 2, "value": 3.0}, + ] + + +import io + + +def to_npy(arr): + buf = io.BytesIO() + np.save(buf, arr) + buf.seek(0) + return buf.read() + + +def test_vec_npy_each(): + vec_npy_each = lambda *args: execute_all( + db, "select rowid, * from vec_npy_each(?)", args + ) + assert vec_npy_each(to_npy(np.array([1.1, 2.2, 3.3], dtype=np.float32))) == [ + { + "rowid": 0, + "vector": _f32([1.1, 2.2, 3.3]), + }, + ] + assert vec_npy_each(to_npy(np.array([[1.1, 2.2, 3.3]], dtype=np.float32))) == [ + { + "rowid": 0, + "vector": _f32([1.1, 2.2, 3.3]), + }, + ] + assert vec_npy_each( + to_npy(np.array([[1.1, 2.2, 3.3], [9.9, 8.8, 7.7]], dtype=np.float32)) + ) == [ + { + "rowid": 0, + "vector": _f32([1.1, 2.2, 3.3]), + }, + { + "rowid": 1, + "vector": _f32([9.9, 8.8, 7.7]), + }, + ] + + +def test_smoke(): + db.execute("create virtual table vec_xyz using vec0( a float[2] )") + assert execute_all( + db, + "select name, ncol from pragma_table_list where name like 'vec_xyz%' order by name;", + ) == [ + { + "name": "vec_xyz", + "ncol": 4, + }, + { + "name": "vec_xyz_chunks", + "ncol": 4, + }, + { + "name": "vec_xyz_rowids", + "ncol": 4, + }, + { + "name": "vec_xyz_vector_chunks00", + "ncol": 2, + }, + ] + chunk = db.execute("select * from vec_xyz_chunks").fetchone() + assert chunk["chunk_id"] == 1 + assert chunk["validity"] == bytearray(int(1024 / 8)) + assert chunk["rowids"] == bytearray(int(1024 * 8)) + vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone() + assert vchunk["rowid"] == 1 + assert vchunk["vectors"] == bytearray(int(1024 * 4 * 2)) + + assert ( + explain_query_plan( + "select * from vec_xyz where a match X'' order by distance limit 10" + ) + == "SCAN vec_xyz VIRTUAL TABLE INDEX 0:knn:" + ) + assert ( + explain_query_plan("select * from vec_xyz") + == "SCAN vec_xyz VIRTUAL TABLE INDEX 0:fullscan" + ) + assert ( + explain_query_plan("select * from vec_xyz where rowid = 4") + == "SCAN vec_xyz VIRTUAL TABLE INDEX 3:point" + ) + + db.execute("insert into vec_xyz(rowid, a) select 1, X'000000000000803f'") + chunk = db.execute("select * from vec_xyz_chunks").fetchone() + assert chunk["chunk_id"] == 1 + assert chunk["validity"] == b"\x01" + bytearray(int(1024 / 8) - 1) + assert chunk["rowids"] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + bytearray( + int(1024 * 8) - 8 + ) + vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone() + assert vchunk["rowid"] == 1 + assert vchunk["vectors"] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + bytearray( + int(1024 * 4 * 2) - (2 * 4) + ) + + db.execute("insert into vec_xyz(rowid, a) select 2, X'0000000000000040'") + chunk = db.execute("select * from vec_xyz_chunks").fetchone() + assert chunk[ + "rowids" + ] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + bytearray( + int(1024 * 8) - 8 * 2 + ) + assert chunk["chunk_id"] == 1 + assert chunk["validity"] == b"\x03" + bytearray(int(1024 / 8) - 1) + vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone() + assert vchunk["rowid"] == 1 + assert vchunk[ + "vectors" + ] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + bytearray( + int(1024 * 4 * 2) - (2 * 4 * 2) + ) + + db.execute("insert into vec_xyz(rowid, a) select 3, X'00000000000080bf'") + chunk = db.execute("select * from vec_xyz_chunks").fetchone() + assert chunk["chunk_id"] == 1 + assert chunk["validity"] == b"\x07" + bytearray(int(1024 / 8) - 1) + assert chunk[ + "rowids" + ] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + b"\x03\x00\x00\x00\x00\x00\x00\x00" + bytearray( + int(1024 * 8) - 8 * 3 + ) + vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone() + assert vchunk["rowid"] == 1 + assert vchunk[ + "vectors" + ] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + b"\x00\x00\x00\x00\x00\x00\x80\xbf" + bytearray( + int(1024 * 4 * 2) - (2 * 4 * 3) + ) + + # db.execute("select * from vec_xyz") + assert execute_all(db, "select * from vec_xyz") == [ + {"rowid": 1, "a": b"\x00\x00\x00\x00\x00\x00\x80?"}, + {"rowid": 2, "a": b"\x00\x00\x00\x00\x00\x00\x00@"}, + {"rowid": 3, "a": b"\x00\x00\x00\x00\x00\x00\x80\xbf"}, + ] + + +def test_vec0_stress_small_chunks(): + data = np.zeros((1000, 8), dtype=np.float32) + for i in range(1000): + data[i] = np.array([(i + 1) * 0.1] * 8) + db.execute("create virtual table vec_small using vec0(chunk_size=8, a float[8])") + assert execute_all(db, "select rowid, * from vec_small") == [] + with db: + for row in data: + db.execute("insert into vec_small(a) values (?) ", [row]) + assert execute_all(db, "select rowid, * from vec_small limit 8") == [ + {"rowid": 1, "a": _f32([0.1] * 8)}, + {"rowid": 2, "a": _f32([0.2] * 8)}, + {"rowid": 3, "a": _f32([0.3] * 8)}, + {"rowid": 4, "a": _f32([0.4] * 8)}, + {"rowid": 5, "a": _f32([0.5] * 8)}, + {"rowid": 6, "a": _f32([0.6] * 8)}, + {"rowid": 7, "a": _f32([0.7] * 8)}, + {"rowid": 8, "a": _f32([0.8] * 8)}, + ] + assert db.execute("select count(*) from vec_small").fetchone()[0] == 1000 + assert execute_all( + db, "select rowid, * from vec_small order by rowid desc limit 8" + ) == [ + {"rowid": 1000, "a": _f32([100.0] * 8)}, + {"rowid": 999, "a": _f32([99.9] * 8)}, + {"rowid": 998, "a": _f32([99.8] * 8)}, + {"rowid": 997, "a": _f32([99.7] * 8)}, + {"rowid": 996, "a": _f32([99.6] * 8)}, + {"rowid": 995, "a": _f32([99.5] * 8)}, + {"rowid": 994, "a": _f32([99.4] * 8)}, + {"rowid": 993, "a": _f32([99.3] * 8)}, + ] + assert ( + execute_all( + db, + """ + select rowid, a, distance + from vec_small + where a match ? + and k = 9 + order by distance + """, + [_f32([50.0] * 8)], + ) + == [ + { + "a": _f32([500 * 0.1] * 8), + "distance": 0.0, + "rowid": 500, + }, + { + "a": _f32([499 * 0.1] * 8), + "distance": 0.2828384041786194, + "rowid": 499, + }, + { + "a": _f32([501 * 0.1] * 8), + "distance": 0.2828384041786194, + "rowid": 501, + }, + { + "a": _f32([498 * 0.1] * 8), + "distance": 0.5656875967979431, + "rowid": 498, + }, + { + "a": _f32([502 * 0.1] * 8), + "distance": 0.5656875967979431, + "rowid": 502, + }, + { + "a": _f32([497 * 0.1] * 8), + "distance": 0.8485260009765625, + "rowid": 497, + }, + { + "a": _f32([503 * 0.1] * 8), + "distance": 0.8485260009765625, + "rowid": 503, + }, + { + "a": _f32([496 * 0.1] * 8), + "distance": 1.1313751935958862, + "rowid": 496, + }, + { + "a": _f32([504 * 0.1] * 8), + "distance": 1.1313751935958862, + "rowid": 504, + }, + ] + ) + + +def rowids_value(buffer: bytearray) -> List[int]: + assert (len(buffer) % 8) == 0 + n = int(len(buffer) / 8) + return list(struct.unpack_from(f"<{n}q", buffer)) + + +import numpy.typing as npt + + +def cosine_similarity( + vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32], do_norm: bool = True +) -> npt.NDArray[np.float32]: + sim = vec @ mat.T + if do_norm: + sim /= np.linalg.norm(vec) * np.linalg.norm(mat, axis=1) + return sim + + +def topk( + vec: npt.NDArray[np.float32], + mat: npt.NDArray[np.float32], + k: int = 5, + do_norm: bool = True, +) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.float32]]: + sim = cosine_similarity(vec, mat, do_norm=do_norm) + # Rather than sorting all similarities and taking the top K, it's faster to + # argpartition and then just sort the top K. + # The difference is O(N logN) vs O(N + k logk) + indices = np.argpartition(-sim, kth=k)[:k] + top_indices = np.argsort(-sim[indices]) + return indices[top_indices], sim[top_indices] + + +def test_stress1(): + np.random.seed(1234) + data = np.random.uniform(-1.0, 1.0, (8000, 128)).astype(np.float32) + db.execute( + "create virtual table vec_stress1 using vec0( a float[128] distance_metric=cosine)" + ) + with db: + for idx, row in enumerate(data): + db.execute("insert into vec_stress1 values (?, ?)", [idx, row]) + queries = np.random.uniform(-1.0, 1.0, (100, 128)).astype(np.float32) + for q in queries: + ids, distances = topk(q, data, k=10) + rows = db.execute( + """ + select rowid, distance + from vec_stress1 + where a match ? and k = ? + order by distance + """, + [q, 10], + ).fetchall() + assert len(ids) == 10 + assert len(rows) == 10 + vec_ids = [row[0] for row in rows] + assert ids.tolist() == vec_ids + + +@pytest.mark.skip(reason="slow") +def test_stress(): + db.execute("create virtual table vec_t1 using vec0( a float[1536])") + + def rand_vec(n): + return struct.pack("%sf" % n, *list(map(lambda x: random(), range(n)))) + + for i in range(1025): + db.execute("insert into vec_t1(a) values (?)", [rand_vec(1536)]) + rows = db.execute("select validity, rowids from vec_t1_chunks").fetchall() + assert len(rows) == 2 + + assert len(rows[0]["validity"]) == 1024 / CHAR_BIT + assert len(rows[0]["rowids"]) == 1024 * CHAR_BIT + assert rows[0]["validity"] == bitmap_full(1024) + assert rowids_value(rows[0]["rowids"]) == [x + 1 for x in range(1024)] + + assert len(rows[1]["validity"]) == 1024 / CHAR_BIT + assert len(rows[1]["rowids"]) == 1024 * CHAR_BIT + assert rows[1]["validity"] == bytes([0b0000_0001]) + bitmap_zerod(1024)[1:] + assert rowids_value(rows[1]["rowids"])[0] == 1025 + + +def test_coverage(): + current_module = inspect.getmodule(inspect.currentframe()) + test_methods = [ + member[0] + for member in inspect.getmembers(current_module) + if member[0].startswith("test_") + ] + funcs_with_tests = set([x.replace("test_", "") for x in test_methods]) + for func in [*FUNCTIONS, *MODULES]: + assert func in funcs_with_tests, f"{func} is not tested" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest.rs b/tests/unittest.rs new file mode 100644 index 0000000..e95c6c3 --- /dev/null +++ b/tests/unittest.rs @@ -0,0 +1,37 @@ +fn main() { + println!("Hello, world!"); + println!("{:?}", _min_idx(vec![3.0, 2.0, 1.0], 2)); +} + +fn _min_idx(distances: Vec, k: i32) -> Vec { + let mut out: Vec = vec![0; k as usize]; + + unsafe { + min_idx( + distances.as_ptr().cast(), + distances.len() as i32, + out.as_mut_ptr(), + k, + ); + } + out +} + +#[link(name = "sqlite-vec-internal")] +extern "C" { + fn min_idx(distances: *const f32, n: i32, out: *mut i32, k: i32) -> i32; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic() { + assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 3), vec![0, 1, 2]); + assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 3), vec![2, 1, 0]); + + assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 2), vec![0, 1]); + assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 2), vec![2, 1]); + } +} diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..f7f0676 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,22 @@ +import numpy as np +from io import BytesIO + + +def to_npy(arr): + buf = BytesIO() + np.save(buf, arr) + buf.seek(0) + return buf.read() + + +to_npy(np.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=np.float32)) + +print(to_npy(np.array([[1.0, 2.0]], dtype=np.float32))) +print(to_npy(np.array([1.0, 2.0], dtype=np.float32))) + +to_npy( + np.array( + [np.zeros(10), np.zeros(10), np.zeros(10), np.zeros(10), np.zeros(10)], + dtype=np.float32, + ) +)