mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 08:46:49 +02:00
Add DiskANN index for vec0 virtual table
Add DiskANN graph-based index: builds a Vamana graph with configurable R (max degree) and L (search list size, separate for insert/query), supports int8 quantization with rescore, lazy reverse-edge replacement, pre-quantized query optimization, and insert buffer reuse. Includes shadow table management, delete support, KNN integration, compile flag (SQLITE_VEC_ENABLE_DISKANN), release-demo workflow, fuzz targets, and tests. Fixes rescore int8 quantization bug.
This commit is contained in:
parent
e2c38f387c
commit
575371d751
23 changed files with 6550 additions and 135 deletions
118
.github/workflows/release-demo.yml
vendored
Normal file
118
.github/workflows/release-demo.yml
vendored
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
name: "Release Demo (DiskANN)"
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [diskann-yolo2]
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
jobs:
|
||||||
|
build-linux-x86_64-extension:
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- run: ./scripts/vendor.sh
|
||||||
|
- run: make loadable static
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-linux-x86_64-extension
|
||||||
|
path: dist/*
|
||||||
|
build-linux-aarch64-extension:
|
||||||
|
runs-on: ubuntu-22.04-arm
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- run: ./scripts/vendor.sh
|
||||||
|
- run: make loadable static
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-linux-aarch64-extension
|
||||||
|
path: dist/*
|
||||||
|
build-macos-x86_64-extension:
|
||||||
|
runs-on: macos-15-intel
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- run: ./scripts/vendor.sh
|
||||||
|
- run: make loadable static
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-macos-x86_64-extension
|
||||||
|
path: dist/*
|
||||||
|
build-macos-aarch64-extension:
|
||||||
|
runs-on: macos-14
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- run: ./scripts/vendor.sh
|
||||||
|
- run: make loadable static
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-macos-aarch64-extension
|
||||||
|
path: dist/*
|
||||||
|
build-windows-x86_64-extension:
|
||||||
|
runs-on: windows-2022
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: ilammy/msvc-dev-cmd@v1
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- run: ./scripts/vendor.sh
|
||||||
|
shell: bash
|
||||||
|
- run: make sqlite-vec.h
|
||||||
|
- run: mkdir dist
|
||||||
|
- run: cl.exe /fPIC -shared /W4 /Ivendor/ /O2 /LD sqlite-vec.c -o dist/vec0.dll
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-windows-x86_64-extension
|
||||||
|
path: dist/*
|
||||||
|
dist:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs:
|
||||||
|
[
|
||||||
|
build-linux-x86_64-extension,
|
||||||
|
build-linux-aarch64-extension,
|
||||||
|
build-macos-x86_64-extension,
|
||||||
|
build-macos-aarch64-extension,
|
||||||
|
build-windows-x86_64-extension,
|
||||||
|
]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-linux-x86_64-extension
|
||||||
|
path: dist/linux-x86_64
|
||||||
|
- uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-linux-aarch64-extension
|
||||||
|
path: dist/linux-aarch64
|
||||||
|
- uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-macos-x86_64-extension
|
||||||
|
path: dist/macos-x86_64
|
||||||
|
- uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-macos-aarch64-extension
|
||||||
|
path: dist/macos-aarch64
|
||||||
|
- uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: sqlite-vec-windows-x86_64-extension
|
||||||
|
path: dist/windows-x86_64
|
||||||
|
- run: make sqlite-vec.h
|
||||||
|
- run: |
|
||||||
|
./scripts/vendor.sh
|
||||||
|
make amalgamation
|
||||||
|
mkdir -p amalgamation
|
||||||
|
cp dist/sqlite-vec.c sqlite-vec.h amalgamation/
|
||||||
|
rm dist/sqlite-vec.c
|
||||||
|
- uses: asg017/setup-sqlite-dist@73e37b2ffb0b51e64a64eb035da38c958b9ff6c6
|
||||||
|
- run: sqlite-dist build --set-version $(cat VERSION)
|
||||||
|
- name: Create release and upload assets
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ github.token }}
|
||||||
|
run: |
|
||||||
|
SHORT_SHA=$(echo "${{ github.sha }}" | head -c 10)
|
||||||
|
TAG="diskann-${SHORT_SHA}"
|
||||||
|
zip -j "amalgamation/sqlite-vec-amalgamation.zip" amalgamation/sqlite-vec.c amalgamation/sqlite-vec.h
|
||||||
|
gh release create "$TAG" \
|
||||||
|
--title "$TAG" \
|
||||||
|
--target "${{ github.sha }}" \
|
||||||
|
--prerelease \
|
||||||
|
amalgamation/sqlite-vec-amalgamation.zip \
|
||||||
|
.sqlite-dist/pip/*
|
||||||
2
Makefile
2
Makefile
|
|
@ -204,7 +204,7 @@ test-loadable-watch:
|
||||||
watchexec --exts c,py,Makefile --clear -- make test-loadable
|
watchexec --exts c,py,Makefile --clear -- make test-loadable
|
||||||
|
|
||||||
test-unit:
|
test-unit:
|
||||||
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
|
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_ENABLE_DISKANN=1 tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor $(CFLAGS) -o $(prefix)/test-unit && $(prefix)/test-unit
|
||||||
|
|
||||||
# Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking,
|
# Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking,
|
||||||
# profiling (has debug symbols), and scripting without .load_extension.
|
# profiling (has debug symbols), and scripting without .load_extension.
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,16 @@ RESCORE_CONFIGS = \
|
||||||
"rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \
|
"rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \
|
||||||
"rescore-int8-os8:type=rescore,quantizer=int8,oversample=8"
|
"rescore-int8-os8:type=rescore,quantizer=int8,oversample=8"
|
||||||
|
|
||||||
ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS)
|
# --- DiskANN configs ---
|
||||||
|
DISKANN_CONFIGS = \
|
||||||
|
"diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
|
||||||
|
"diskann-R72-binary:type=diskann,R=72,L=128,quantizer=binary" \
|
||||||
|
"diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8" \
|
||||||
|
"diskann-R72-L256:type=diskann,R=72,L=256,quantizer=binary"
|
||||||
|
|
||||||
.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-10k bench-50k bench-100k bench-all \
|
ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) $(DISKANN_CONFIGS)
|
||||||
|
|
||||||
|
.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-diskann bench-10k bench-50k bench-100k bench-all \
|
||||||
report clean
|
report clean
|
||||||
|
|
||||||
# --- Data preparation ---
|
# --- Data preparation ---
|
||||||
|
|
@ -37,7 +44,8 @@ ground-truth: seed
|
||||||
bench-smoke: seed
|
bench-smoke: seed
|
||||||
$(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \
|
$(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \
|
||||||
"brute-float:type=baseline,variant=float" \
|
"brute-float:type=baseline,variant=float" \
|
||||||
"ivf-quick:type=ivf,nlist=16,nprobe=4"
|
"ivf-quick:type=ivf,nlist=16,nprobe=4" \
|
||||||
|
"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"
|
||||||
|
|
||||||
bench-rescore: seed
|
bench-rescore: seed
|
||||||
$(BENCH) --subset-size 10000 -k 10 -o runs/rescore \
|
$(BENCH) --subset-size 10000 -k 10 -o runs/rescore \
|
||||||
|
|
@ -62,6 +70,12 @@ bench-ivf: seed
|
||||||
$(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
$(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
||||||
$(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
$(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
||||||
|
|
||||||
|
# --- DiskANN across sizes ---
|
||||||
|
bench-diskann: seed
|
||||||
|
$(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
|
||||||
|
$(BENCH) --subset-size 50000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
|
||||||
|
$(BENCH) --subset-size 100000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
|
||||||
|
|
||||||
# --- Report ---
|
# --- Report ---
|
||||||
report:
|
report:
|
||||||
@echo "Use: sqlite3 runs/<dir>/results.db 'SELECT * FROM bench_results ORDER BY recall DESC'"
|
@echo "Use: sqlite3 runs/<dir>/results.db 'SELECT * FROM bench_results ORDER BY recall DESC'"
|
||||||
|
|
|
||||||
|
|
@ -6,18 +6,16 @@ across different vec0 configurations.
|
||||||
|
|
||||||
Config format: name:type=<index_type>,key=val,key=val
|
Config format: name:type=<index_type>,key=val,key=val
|
||||||
|
|
||||||
Baseline (brute-force) keys:
|
Available types: none, vec0-flat, rescore, ivf, diskann
|
||||||
type=baseline, variant=float|int8|bit, oversample=8
|
|
||||||
|
|
||||||
Index-specific types can be registered via INDEX_REGISTRY (see below).
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python bench.py --subset-size 10000 \
|
python bench.py --subset-size 10000 \
|
||||||
"brute-float:type=baseline,variant=float" \
|
"raw:type=none" \
|
||||||
"brute-int8:type=baseline,variant=int8" \
|
"flat:type=vec0-flat,variant=float" \
|
||||||
"brute-bit:type=baseline,variant=bit"
|
"flat-int8:type=vec0-flat,variant=int8"
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
from datetime import datetime, timezone
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import statistics
|
import statistics
|
||||||
|
|
@ -56,11 +54,118 @@ INDEX_REGISTRY = {}
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Baseline implementation
|
# "none" — regular table, no vec0, manual KNN via vec_distance_cosine()
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def _baseline_create_table_sql(params):
|
def _none_create_table_sql(params):
|
||||||
|
variant = params["variant"]
|
||||||
|
if variant == "int8":
|
||||||
|
return (
|
||||||
|
"CREATE TABLE vec_items ("
|
||||||
|
" id INTEGER PRIMARY KEY,"
|
||||||
|
" embedding BLOB NOT NULL,"
|
||||||
|
" embedding_int8 BLOB NOT NULL)"
|
||||||
|
)
|
||||||
|
elif variant == "bit":
|
||||||
|
return (
|
||||||
|
"CREATE TABLE vec_items ("
|
||||||
|
" id INTEGER PRIMARY KEY,"
|
||||||
|
" embedding BLOB NOT NULL,"
|
||||||
|
" embedding_bq BLOB NOT NULL)"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
"CREATE TABLE vec_items ("
|
||||||
|
" id INTEGER PRIMARY KEY,"
|
||||||
|
" embedding BLOB NOT NULL)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _none_insert_sql(params):
|
||||||
|
variant = params["variant"]
|
||||||
|
if variant == "int8":
|
||||||
|
return (
|
||||||
|
"INSERT INTO vec_items(id, embedding, embedding_int8) "
|
||||||
|
"SELECT id, vector, vec_quantize_int8(vector, 'unit') "
|
||||||
|
"FROM base.train WHERE id >= :lo AND id < :hi"
|
||||||
|
)
|
||||||
|
elif variant == "bit":
|
||||||
|
return (
|
||||||
|
"INSERT INTO vec_items(id, embedding, embedding_bq) "
|
||||||
|
"SELECT id, vector, vec_quantize_binary(vector) "
|
||||||
|
"FROM base.train WHERE id >= :lo AND id < :hi"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
"INSERT INTO vec_items(id, embedding) "
|
||||||
|
"SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _none_run_query(conn, params, query, k):
|
||||||
|
variant = params["variant"]
|
||||||
|
oversample = params.get("oversample", 8)
|
||||||
|
|
||||||
|
if variant == "int8":
|
||||||
|
q_int8 = conn.execute(
|
||||||
|
"SELECT vec_quantize_int8(:query, 'unit')", {"query": query}
|
||||||
|
).fetchone()[0]
|
||||||
|
return conn.execute(
|
||||||
|
"WITH coarse AS ("
|
||||||
|
" SELECT id, embedding FROM ("
|
||||||
|
" SELECT id, embedding, vec_distance_cosine(vec_int8(:q_int8), vec_int8(embedding_int8)) as dist "
|
||||||
|
" FROM vec_items ORDER BY dist LIMIT :oversample_k"
|
||||||
|
" )"
|
||||||
|
") "
|
||||||
|
"SELECT id, vec_distance_cosine(:query, embedding) as distance "
|
||||||
|
"FROM coarse ORDER BY 2 LIMIT :k",
|
||||||
|
{"q_int8": q_int8, "query": query, "k": k, "oversample_k": k * oversample},
|
||||||
|
).fetchall()
|
||||||
|
elif variant == "bit":
|
||||||
|
q_bit = conn.execute(
|
||||||
|
"SELECT vec_quantize_binary(:query)", {"query": query}
|
||||||
|
).fetchone()[0]
|
||||||
|
return conn.execute(
|
||||||
|
"WITH coarse AS ("
|
||||||
|
" SELECT id, embedding FROM ("
|
||||||
|
" SELECT id, embedding, vec_distance_hamming(vec_bit(:q_bit), vec_bit(embedding_bq)) as dist "
|
||||||
|
" FROM vec_items ORDER BY dist LIMIT :oversample_k"
|
||||||
|
" )"
|
||||||
|
") "
|
||||||
|
"SELECT id, vec_distance_cosine(:query, embedding) as distance "
|
||||||
|
"FROM coarse ORDER BY 2 LIMIT :k",
|
||||||
|
{"q_bit": q_bit, "query": query, "k": k, "oversample_k": k * oversample},
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return conn.execute(
|
||||||
|
"SELECT id, vec_distance_cosine(:query, embedding) as distance "
|
||||||
|
"FROM vec_items ORDER BY 2 LIMIT :k",
|
||||||
|
{"query": query, "k": k},
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
def _none_describe(params):
|
||||||
|
v = params["variant"]
|
||||||
|
if v in ("int8", "bit"):
|
||||||
|
return f"none {v} (os={params['oversample']})"
|
||||||
|
return f"none float"
|
||||||
|
|
||||||
|
|
||||||
|
INDEX_REGISTRY["none"] = {
|
||||||
|
"defaults": {"variant": "float", "oversample": 8},
|
||||||
|
"create_table_sql": _none_create_table_sql,
|
||||||
|
"insert_sql": _none_insert_sql,
|
||||||
|
"post_insert_hook": None,
|
||||||
|
"run_query": _none_run_query,
|
||||||
|
"describe": _none_describe,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# vec0-flat — vec0 virtual table with brute-force MATCH
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _vec0flat_create_table_sql(params):
|
||||||
variant = params["variant"]
|
variant = params["variant"]
|
||||||
extra = ""
|
extra = ""
|
||||||
if variant == "int8":
|
if variant == "int8":
|
||||||
|
|
@ -76,7 +181,7 @@ def _baseline_create_table_sql(params):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _baseline_insert_sql(params):
|
def _vec0flat_insert_sql(params):
|
||||||
variant = params["variant"]
|
variant = params["variant"]
|
||||||
if variant == "int8":
|
if variant == "int8":
|
||||||
return (
|
return (
|
||||||
|
|
@ -93,7 +198,7 @@ def _baseline_insert_sql(params):
|
||||||
return None # use default
|
return None # use default
|
||||||
|
|
||||||
|
|
||||||
def _baseline_run_query(conn, params, query, k):
|
def _vec0flat_run_query(conn, params, query, k):
|
||||||
variant = params["variant"]
|
variant = params["variant"]
|
||||||
oversample = params.get("oversample", 8)
|
oversample = params.get("oversample", 8)
|
||||||
|
|
||||||
|
|
@ -123,20 +228,20 @@ def _baseline_run_query(conn, params, query, k):
|
||||||
return None # use default MATCH
|
return None # use default MATCH
|
||||||
|
|
||||||
|
|
||||||
def _baseline_describe(params):
|
def _vec0flat_describe(params):
|
||||||
v = params["variant"]
|
v = params["variant"]
|
||||||
if v in ("int8", "bit"):
|
if v in ("int8", "bit"):
|
||||||
return f"baseline {v} (os={params['oversample']})"
|
return f"vec0-flat {v} (os={params['oversample']})"
|
||||||
return f"baseline {v}"
|
return f"vec0-flat {v}"
|
||||||
|
|
||||||
|
|
||||||
INDEX_REGISTRY["baseline"] = {
|
INDEX_REGISTRY["vec0-flat"] = {
|
||||||
"defaults": {"variant": "float", "oversample": 8},
|
"defaults": {"variant": "float", "oversample": 8},
|
||||||
"create_table_sql": _baseline_create_table_sql,
|
"create_table_sql": _vec0flat_create_table_sql,
|
||||||
"insert_sql": _baseline_insert_sql,
|
"insert_sql": _vec0flat_insert_sql,
|
||||||
"post_insert_hook": None,
|
"post_insert_hook": None,
|
||||||
"run_query": _baseline_run_query,
|
"run_query": _vec0flat_run_query,
|
||||||
"describe": _baseline_describe,
|
"describe": _vec0flat_describe,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -215,12 +320,64 @@ INDEX_REGISTRY["ivf"] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# DiskANN implementation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _diskann_create_table_sql(params):
|
||||||
|
bt = params["buffer_threshold"]
|
||||||
|
extra = f", buffer_threshold={bt}" if bt > 0 else ""
|
||||||
|
return (
|
||||||
|
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||||
|
f" id integer primary key,"
|
||||||
|
f" embedding float[768] distance_metric=cosine"
|
||||||
|
f" INDEXED BY diskann("
|
||||||
|
f" neighbor_quantizer={params['quantizer']},"
|
||||||
|
f" n_neighbors={params['R']},"
|
||||||
|
f" search_list_size={params['L']}"
|
||||||
|
f" {extra}"
|
||||||
|
f" )"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _diskann_pre_query_hook(conn, params):
|
||||||
|
L_search = params.get("L_search")
|
||||||
|
if L_search:
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO vec_items(id) VALUES (?)",
|
||||||
|
(f"search_list_size_search={L_search}",),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
print(f" Set search_list_size_search={L_search}")
|
||||||
|
|
||||||
|
|
||||||
|
def _diskann_describe(params):
|
||||||
|
desc = f"diskann q={params['quantizer']:<6} R={params['R']:<3} L={params['L']}"
|
||||||
|
L_search = params.get("L_search")
|
||||||
|
if L_search:
|
||||||
|
desc += f" L_search={L_search}"
|
||||||
|
return desc
|
||||||
|
|
||||||
|
|
||||||
|
INDEX_REGISTRY["diskann"] = {
|
||||||
|
"defaults": {"R": 72, "L": 128, "quantizer": "binary", "buffer_threshold": 0},
|
||||||
|
"create_table_sql": _diskann_create_table_sql,
|
||||||
|
"insert_sql": None,
|
||||||
|
"post_insert_hook": None,
|
||||||
|
"pre_query_hook": _diskann_pre_query_hook,
|
||||||
|
"run_query": None,
|
||||||
|
"describe": _diskann_describe,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Config parsing
|
# Config parsing
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
INT_KEYS = {
|
INT_KEYS = {
|
||||||
"R", "L", "buffer_threshold", "nlist", "nprobe", "oversample",
|
"R", "L", "L_search", "buffer_threshold", "nlist", "nprobe", "oversample",
|
||||||
"n_trees", "search_k",
|
"n_trees", "search_k",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -238,7 +395,7 @@ def parse_config(spec):
|
||||||
k, v = kv.split("=", 1)
|
k, v = kv.split("=", 1)
|
||||||
raw[k.strip()] = v.strip()
|
raw[k.strip()] = v.strip()
|
||||||
|
|
||||||
index_type = raw.pop("type", "baseline")
|
index_type = raw.pop("type", "vec0-flat")
|
||||||
if index_type not in INDEX_REGISTRY:
|
if index_type not in INDEX_REGISTRY:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown index type: {index_type}. "
|
f"Unknown index type: {index_type}. "
|
||||||
|
|
@ -289,7 +446,7 @@ def insert_loop(conn, sql, subset_size, label=""):
|
||||||
return time.perf_counter() - t0
|
return time.perf_counter() - t0
|
||||||
|
|
||||||
|
|
||||||
def open_bench_db(db_path, ext_path, base_db):
|
def create_bench_db(db_path, ext_path, base_db):
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
os.remove(db_path)
|
os.remove(db_path)
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
|
|
@ -300,6 +457,19 @@ def open_bench_db(db_path, ext_path, base_db):
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def open_existing_bench_db(db_path, ext_path, base_db):
    """Open an already-built index database for querying.

    Unlike the builder, this never creates or deletes ``db_path``; a missing
    file is reported as an error that points at ``--phase build``.

    Args:
        db_path: path of the previously built index database.
        ext_path: path of the loadable extension to load into the connection.
        base_db: path of the base dataset database, attached as schema ``base``.

    Returns:
        An open ``sqlite3.Connection`` with the extension loaded and
        ``base_db`` attached.

    Raises:
        FileNotFoundError: if ``db_path`` does not exist.
        Exception: whatever extension loading or ATTACH raises; the
            connection is closed before the error propagates.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(
            f"Index DB not found: {db_path}\n"
            f"Build it first with: --phase build"
        )

    conn = sqlite3.connect(db_path)
    try:
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)
        # Bind the path instead of splicing it into the SQL text, so a path
        # containing a single quote cannot break the statement.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))
    except Exception:
        # Do not leak the handle when setup fails part-way through.
        conn.close()
        raise
    return conn
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_INSERT_SQL = (
|
DEFAULT_INSERT_SQL = (
|
||||||
"INSERT INTO vec_items(id, embedding) "
|
"INSERT INTO vec_items(id, embedding) "
|
||||||
"SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi"
|
"SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi"
|
||||||
|
|
@ -313,7 +483,7 @@ DEFAULT_INSERT_SQL = (
|
||||||
|
|
||||||
def build_index(base_db, ext_path, name, params, subset_size, out_dir):
|
def build_index(base_db, ext_path, name, params, subset_size, out_dir):
|
||||||
db_path = os.path.join(out_dir, f"{name}.{subset_size}.db")
|
db_path = os.path.join(out_dir, f"{name}.{subset_size}.db")
|
||||||
conn = open_bench_db(db_path, ext_path, base_db)
|
conn = create_bench_db(db_path, ext_path, base_db)
|
||||||
|
|
||||||
reg = INDEX_REGISTRY[params["index_type"]]
|
reg = INDEX_REGISTRY[params["index_type"]]
|
||||||
|
|
||||||
|
|
@ -364,12 +534,16 @@ def _default_match_query(conn, query, k):
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
|
|
||||||
def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50):
|
def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50,
|
||||||
|
pre_query_hook=None):
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
conn.enable_load_extension(True)
|
conn.enable_load_extension(True)
|
||||||
conn.load_extension(ext_path)
|
conn.load_extension(ext_path)
|
||||||
conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
|
conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
|
||||||
|
|
||||||
|
if pre_query_hook:
|
||||||
|
pre_query_hook(conn, params)
|
||||||
|
|
||||||
query_vectors = load_query_vectors(base_db, n)
|
query_vectors = load_query_vectors(base_db, n)
|
||||||
|
|
||||||
reg = INDEX_REGISTRY[params["index_type"]]
|
reg = INDEX_REGISTRY[params["index_type"]]
|
||||||
|
|
@ -431,6 +605,34 @@ def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50):
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def open_results_db(results_path):
    """Open the results database, creating and migrating its schema as needed.

    The statements in ``schema.sql`` (looked up next to this script via
    ``_SCRIPT_DIR``) are executed on every open. Afterwards the ``runs``
    table is checked for a ``phase`` column, which is added with default
    ``'both'`` when absent.

    Args:
        results_path: path of the results database file.

    Returns:
        An open ``sqlite3.Connection``.

    Raises:
        OSError: if ``schema.sql`` cannot be read (raised before any
            connection is opened).
        sqlite3.Error: if applying the schema or the migration fails; the
            connection is closed before the error propagates.
    """
    # Read the schema first and close the file deterministically instead of
    # relying on garbage collection to release the handle.
    with open(os.path.join(_SCRIPT_DIR, "schema.sql")) as schema_file:
        schema_sql = schema_file.read()

    db = sqlite3.connect(results_path)
    try:
        db.executescript(schema_sql)
        # Migrate existing DBs whose runs table predates the phase column.
        # PRAGMA table_info yields one row per column; index 1 is its name.
        cols = {r[1] for r in db.execute("PRAGMA table_info(runs)").fetchall()}
        if "phase" not in cols:
            db.execute("ALTER TABLE runs ADD COLUMN phase TEXT NOT NULL DEFAULT 'both'")
            db.commit()
    except Exception:
        db.close()
        raise
    return db
|
||||||
|
|
||||||
|
|
||||||
|
def create_run(db, config_name, index_type, subset_size, phase, k=None, n=None):
    """Insert a new row into ``runs`` with status ``'pending'`` and commit it.

    Args:
        db: open connection to the results database.
        config_name: name of the benchmark config.
        index_type: index type of the config.
        subset_size: number of vectors in the benchmarked subset.
        phase: phase label stored with the run.
        k: KNN ``k`` for the run, or ``None``.
        n: number of queries for the run, or ``None``.

    Returns:
        The ``lastrowid`` of the insert, i.e. the new ``run_id``.
    """
    insert_sql = (
        "INSERT INTO runs (config_name, index_type, subset_size, phase, status, k, n) "
        "VALUES (?, ?, ?, ?, 'pending', ?, ?)"
    )
    row = (config_name, index_type, subset_size, phase, k, n)
    cursor = db.execute(insert_sql, row)
    db.commit()
    return cursor.lastrowid
|
||||||
|
|
||||||
|
|
||||||
|
def update_run(db, run_id, **kwargs):
    """Update columns of one ``runs`` row and commit.

    Args:
        db: open connection to the results database.
        run_id: ``run_id`` of the row to update.
        **kwargs: column name -> new value. The names are placed directly
            into the SQL text, so they must be trusted column names; the
            values are bound as parameters.

    Calling this without any keyword arguments is a no-op. Without that
    guard the statement would read ``UPDATE runs SET  WHERE ...`` and fail
    with a syntax error.
    """
    if not kwargs:
        return

    assignments = ", ".join(f"{column} = ?" for column in kwargs)
    values = list(kwargs.values()) + [run_id]
    db.execute(f"UPDATE runs SET {assignments} WHERE run_id = ?", values)
    db.commit()
|
||||||
|
|
||||||
|
|
||||||
def save_results(results_path, rows):
|
def save_results(results_path, rows):
|
||||||
db = sqlite3.connect(results_path)
|
db = sqlite3.connect(results_path)
|
||||||
db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read())
|
db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read())
|
||||||
|
|
@ -500,6 +702,8 @@ def main():
|
||||||
parser.add_argument("--subset-size", type=int, required=True)
|
parser.add_argument("--subset-size", type=int, required=True)
|
||||||
parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
|
parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
|
||||||
parser.add_argument("-n", type=int, default=50, help="number of queries (default 50)")
|
parser.add_argument("-n", type=int, default=50, help="number of queries (default 50)")
|
||||||
|
parser.add_argument("--phase", choices=["build", "query", "both"], default="both",
|
||||||
|
help="build=build only, query=query existing index, both=default")
|
||||||
parser.add_argument("--base-db", default=BASE_DB)
|
parser.add_argument("--base-db", default=BASE_DB)
|
||||||
parser.add_argument("--ext", default=EXT_PATH)
|
parser.add_argument("--ext", default=EXT_PATH)
|
||||||
parser.add_argument("-o", "--out-dir", default="runs")
|
parser.add_argument("-o", "--out-dir", default="runs")
|
||||||
|
|
@ -508,55 +712,164 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
os.makedirs(args.out_dir, exist_ok=True)
|
os.makedirs(args.out_dir, exist_ok=True)
|
||||||
results_db = args.results_db or os.path.join(args.out_dir, "results.db")
|
results_db_path = args.results_db or os.path.join(args.out_dir, "results.db")
|
||||||
configs = [parse_config(c) for c in args.configs]
|
configs = [parse_config(c) for c in args.configs]
|
||||||
|
results_db = open_results_db(results_db_path)
|
||||||
|
|
||||||
all_results = []
|
all_results = []
|
||||||
for i, (name, params) in enumerate(configs, 1):
|
for i, (name, params) in enumerate(configs, 1):
|
||||||
reg = INDEX_REGISTRY[params["index_type"]]
|
reg = INDEX_REGISTRY[params["index_type"]]
|
||||||
desc = reg["describe"](params)
|
desc = reg["describe"](params)
|
||||||
print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()})")
|
print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()}) [phase={args.phase}]")
|
||||||
|
|
||||||
build = build_index(
|
db_path = os.path.join(args.out_dir, f"{name}.{args.subset_size}.db")
|
||||||
args.base_db, args.ext, name, params, args.subset_size, args.out_dir
|
|
||||||
)
|
|
||||||
train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
|
|
||||||
print(
|
|
||||||
f" Build: {build['insert_time_s']}s insert{train_str} "
|
|
||||||
f"{build['file_size_mb']} MB"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f" Measuring KNN (k={args.k}, n={args.n})...")
|
if args.phase == "build":
|
||||||
knn = measure_knn(
|
run_id = create_run(results_db, name, params["index_type"],
|
||||||
build["db_path"], args.ext, args.base_db,
|
args.subset_size, "build")
|
||||||
params, args.subset_size, k=args.k, n=args.n,
|
update_run(results_db, run_id, status="inserting")
|
||||||
)
|
|
||||||
print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
|
|
||||||
|
|
||||||
all_results.append({
|
build = build_index(
|
||||||
"name": name,
|
args.base_db, args.ext, name, params, args.subset_size, args.out_dir
|
||||||
"n_vectors": args.subset_size,
|
)
|
||||||
"index_type": params["index_type"],
|
train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
|
||||||
"config_desc": desc,
|
print(
|
||||||
"db_path": build["db_path"],
|
f" Build: {build['insert_time_s']}s insert{train_str} "
|
||||||
"insert_time_s": build["insert_time_s"],
|
f"{build['file_size_mb']} MB"
|
||||||
"train_time_s": build["train_time_s"],
|
)
|
||||||
"total_time_s": build["total_time_s"],
|
update_run(results_db, run_id,
|
||||||
"insert_per_vec_ms": build["insert_per_vec_ms"],
|
status="built",
|
||||||
"rows": build["rows"],
|
db_path=build["db_path"],
|
||||||
"file_size_mb": build["file_size_mb"],
|
insert_time_s=build["insert_time_s"],
|
||||||
"k": args.k,
|
train_time_s=build["train_time_s"],
|
||||||
"n_queries": args.n,
|
total_build_time_s=build["total_time_s"],
|
||||||
"mean_ms": knn["mean_ms"],
|
rows=build["rows"],
|
||||||
"median_ms": knn["median_ms"],
|
file_size_mb=build["file_size_mb"],
|
||||||
"p99_ms": knn["p99_ms"],
|
finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
"total_ms": knn["total_ms"],
|
print(f" Index DB: {build['db_path']}")
|
||||||
"recall": knn["recall"],
|
|
||||||
})
|
|
||||||
|
|
||||||
print_report(all_results)
|
elif args.phase == "query":
|
||||||
save_results(results_db, all_results)
|
if not os.path.exists(db_path):
|
||||||
print(f"\nResults saved to {results_db}")
|
raise FileNotFoundError(
|
||||||
|
f"Index DB not found: {db_path}\n"
|
||||||
|
f"Build it first with: --phase build"
|
||||||
|
)
|
||||||
|
|
||||||
|
run_id = create_run(results_db, name, params["index_type"],
|
||||||
|
args.subset_size, "query", k=args.k, n=args.n)
|
||||||
|
update_run(results_db, run_id, status="querying")
|
||||||
|
|
||||||
|
pre_hook = reg.get("pre_query_hook")
|
||||||
|
print(f" Measuring KNN (k={args.k}, n={args.n})...")
|
||||||
|
knn = measure_knn(
|
||||||
|
db_path, args.ext, args.base_db,
|
||||||
|
params, args.subset_size, k=args.k, n=args.n,
|
||||||
|
pre_query_hook=pre_hook,
|
||||||
|
)
|
||||||
|
print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
|
||||||
|
|
||||||
|
qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0
|
||||||
|
update_run(results_db, run_id,
|
||||||
|
status="done",
|
||||||
|
db_path=db_path,
|
||||||
|
mean_ms=knn["mean_ms"],
|
||||||
|
median_ms=knn["median_ms"],
|
||||||
|
p99_ms=knn["p99_ms"],
|
||||||
|
total_query_ms=knn["total_ms"],
|
||||||
|
qps=qps,
|
||||||
|
recall=knn["recall"],
|
||||||
|
finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
|
|
||||||
|
file_size_mb = os.path.getsize(db_path) / (1024 * 1024)
|
||||||
|
all_results.append({
|
||||||
|
"name": name,
|
||||||
|
"n_vectors": args.subset_size,
|
||||||
|
"index_type": params["index_type"],
|
||||||
|
"config_desc": desc,
|
||||||
|
"db_path": db_path,
|
||||||
|
"insert_time_s": 0,
|
||||||
|
"train_time_s": 0,
|
||||||
|
"total_time_s": 0,
|
||||||
|
"insert_per_vec_ms": 0,
|
||||||
|
"rows": 0,
|
||||||
|
"file_size_mb": file_size_mb,
|
||||||
|
"k": args.k,
|
||||||
|
"n_queries": args.n,
|
||||||
|
"mean_ms": knn["mean_ms"],
|
||||||
|
"median_ms": knn["median_ms"],
|
||||||
|
"p99_ms": knn["p99_ms"],
|
||||||
|
"total_ms": knn["total_ms"],
|
||||||
|
"recall": knn["recall"],
|
||||||
|
})
|
||||||
|
|
||||||
|
else: # both
|
||||||
|
run_id = create_run(results_db, name, params["index_type"],
|
||||||
|
args.subset_size, "both", k=args.k, n=args.n)
|
||||||
|
update_run(results_db, run_id, status="inserting")
|
||||||
|
|
||||||
|
build = build_index(
|
||||||
|
args.base_db, args.ext, name, params, args.subset_size, args.out_dir
|
||||||
|
)
|
||||||
|
train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
|
||||||
|
print(
|
||||||
|
f" Build: {build['insert_time_s']}s insert{train_str} "
|
||||||
|
f"{build['file_size_mb']} MB"
|
||||||
|
)
|
||||||
|
update_run(results_db, run_id, status="querying",
|
||||||
|
db_path=build["db_path"],
|
||||||
|
insert_time_s=build["insert_time_s"],
|
||||||
|
train_time_s=build["train_time_s"],
|
||||||
|
total_build_time_s=build["total_time_s"],
|
||||||
|
rows=build["rows"],
|
||||||
|
file_size_mb=build["file_size_mb"])
|
||||||
|
|
||||||
|
print(f" Measuring KNN (k={args.k}, n={args.n})...")
|
||||||
|
knn = measure_knn(
|
||||||
|
build["db_path"], args.ext, args.base_db,
|
||||||
|
params, args.subset_size, k=args.k, n=args.n,
|
||||||
|
)
|
||||||
|
print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
|
||||||
|
|
||||||
|
qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0
|
||||||
|
update_run(results_db, run_id,
|
||||||
|
status="done",
|
||||||
|
mean_ms=knn["mean_ms"],
|
||||||
|
median_ms=knn["median_ms"],
|
||||||
|
p99_ms=knn["p99_ms"],
|
||||||
|
total_query_ms=knn["total_ms"],
|
||||||
|
qps=qps,
|
||||||
|
recall=knn["recall"],
|
||||||
|
finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
|
|
||||||
|
all_results.append({
|
||||||
|
"name": name,
|
||||||
|
"n_vectors": args.subset_size,
|
||||||
|
"index_type": params["index_type"],
|
||||||
|
"config_desc": desc,
|
||||||
|
"db_path": build["db_path"],
|
||||||
|
"insert_time_s": build["insert_time_s"],
|
||||||
|
"train_time_s": build["train_time_s"],
|
||||||
|
"total_time_s": build["total_time_s"],
|
||||||
|
"insert_per_vec_ms": build["insert_per_vec_ms"],
|
||||||
|
"rows": build["rows"],
|
||||||
|
"file_size_mb": build["file_size_mb"],
|
||||||
|
"k": args.k,
|
||||||
|
"n_queries": args.n,
|
||||||
|
"mean_ms": knn["mean_ms"],
|
||||||
|
"median_ms": knn["median_ms"],
|
||||||
|
"p99_ms": knn["p99_ms"],
|
||||||
|
"total_ms": knn["total_ms"],
|
||||||
|
"recall": knn["recall"],
|
||||||
|
})
|
||||||
|
|
||||||
|
results_db.close()
|
||||||
|
|
||||||
|
if all_results:
|
||||||
|
print_report(all_results)
|
||||||
|
save_results(results_db_path, all_results)
|
||||||
|
print(f"\nResults saved to {results_db_path}")
|
||||||
|
elif args.phase == "build":
|
||||||
|
print(f"\nBuild complete. Results tracked in {results_db_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,31 @@
|
||||||
-- "baseline"; index-specific branches add their own types (registered
|
-- "baseline"; index-specific branches add their own types (registered
|
||||||
-- via INDEX_REGISTRY in bench.py).
|
-- via INDEX_REGISTRY in bench.py).
|
||||||
|
|
||||||
|
-- One row per benchmark run. bench.py updates the row as the run progresses
-- and marks it status = 'done' together with the query metrics below.
CREATE TABLE IF NOT EXISTS runs (
  run_id             INTEGER PRIMARY KEY AUTOINCREMENT,
  config_name        TEXT    NOT NULL,
  index_type         TEXT    NOT NULL,
  subset_size        INTEGER NOT NULL,
  -- 'build', 'query', or 'both'
  phase              TEXT    NOT NULL DEFAULT 'both',
  status             TEXT    NOT NULL DEFAULT 'pending',

  -- KNN parameters: neighbours requested (k) and number of queries (n)
  k                  INTEGER,
  n                  INTEGER,

  -- build-phase measurements
  db_path            TEXT,
  insert_time_s      REAL,
  train_time_s       REAL,
  total_build_time_s REAL,
  rows               INTEGER,
  file_size_mb       REAL,

  -- query-phase measurements
  mean_ms            REAL,
  median_ms          REAL,
  p99_ms             REAL,
  total_query_ms     REAL,
  -- queries per second: n / (total_query_ms / 1000)
  qps                REAL,
  recall             REAL,

  created_at         TEXT    NOT NULL DEFAULT (datetime('now')),
  -- written by bench.py as a UTC 'YYYY-MM-DD HH:MM:SS' string
  finished_at        TEXT
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS build_results (
|
CREATE TABLE IF NOT EXISTS build_results (
|
||||||
config_name TEXT NOT NULL,
|
config_name TEXT NOT NULL,
|
||||||
index_type TEXT NOT NULL,
|
index_type TEXT NOT NULL,
|
||||||
|
|
|
||||||
1768
sqlite-vec-diskann.c
Normal file
1768
sqlite-vec-diskann.c
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -156,21 +156,11 @@ static void rescore_quantize_float_to_bit(const float *src, uint8_t *dst,
|
||||||
|
|
||||||
static void rescore_quantize_float_to_int8(const float *src, int8_t *dst,
|
static void rescore_quantize_float_to_int8(const float *src, int8_t *dst,
|
||||||
size_t dimensions) {
|
size_t dimensions) {
|
||||||
float vmin = src[0], vmax = src[0];
|
float step = 2.0f / 255.0f;
|
||||||
for (size_t i = 1; i < dimensions; i++) {
|
|
||||||
if (src[i] < vmin) vmin = src[i];
|
|
||||||
if (src[i] > vmax) vmax = src[i];
|
|
||||||
}
|
|
||||||
float range = vmax - vmin;
|
|
||||||
if (range == 0.0f) {
|
|
||||||
memset(dst, 0, dimensions);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
float scale = 255.0f / range;
|
|
||||||
for (size_t i = 0; i < dimensions; i++) {
|
for (size_t i = 0; i < dimensions; i++) {
|
||||||
float v = (src[i] - vmin) * scale - 128.0f;
|
float v = (src[i] - (-1.0f)) / step - 128.0f;
|
||||||
if (v < -128.0f) v = -128.0f;
|
if (!(v <= 127.0f)) v = 127.0f;
|
||||||
if (v > 127.0f) v = 127.0f;
|
if (!(v >= -128.0f)) v = -128.0f;
|
||||||
dst[i] = (int8_t)v;
|
dst[i] = (int8_t)v;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
787
sqlite-vec.c
787
sqlite-vec.c
|
|
@ -61,6 +61,10 @@ SQLITE_EXTENSION_INIT1
|
||||||
#define LONGDOUBLE_TYPE long double
|
#define LONGDOUBLE_TYPE long double
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifndef SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
#define SQLITE_VEC_ENABLE_DISKANN 1
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
#ifndef __EMSCRIPTEN__
|
#ifndef __EMSCRIPTEN__
|
||||||
#ifndef __COSMOPOLITAN__
|
#ifndef __COSMOPOLITAN__
|
||||||
|
|
@ -2544,6 +2548,7 @@ enum Vec0IndexType {
|
||||||
VEC0_INDEX_TYPE_RESCORE = 2,
|
VEC0_INDEX_TYPE_RESCORE = 2,
|
||||||
#endif
|
#endif
|
||||||
VEC0_INDEX_TYPE_IVF = 3,
|
VEC0_INDEX_TYPE_IVF = 3,
|
||||||
|
VEC0_INDEX_TYPE_DISKANN = 4,
|
||||||
};
|
};
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
|
|
@ -2575,6 +2580,75 @@ struct Vec0IvfConfig {
|
||||||
struct Vec0IvfConfig { char _unused; };
|
struct Vec0IvfConfig { char _unused; };
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// DiskANN types and constants
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
#define VEC0_DISKANN_DEFAULT_N_NEIGHBORS 72
|
||||||
|
#define VEC0_DISKANN_MAX_N_NEIGHBORS 256
|
||||||
|
#define VEC0_DISKANN_DEFAULT_SEARCH_LIST_SIZE 128
|
||||||
|
#define VEC0_DISKANN_DEFAULT_ALPHA 1.2f
|
||||||
|
|
||||||
|
/**
 * How the neighbor vectors stored inside a DiskANN graph node are compressed.
 * Selected with the `neighbor_quantizer=` option of `diskann(...)`.
 */
enum Vec0DiskannQuantizerType {
  // 1 bit per dimension (1/32 compression)
  VEC0_DISKANN_QUANTIZER_BINARY = 1,
  // 1 byte per dimension (1/4 compression)
  VEC0_DISKANN_QUANTIZER_INT8 = 2,
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration for a DiskANN index on a single vector column.
|
||||||
|
* Parsed from `INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=72)`.
|
||||||
|
*/
|
||||||
|
struct Vec0DiskannConfig {
|
||||||
|
// Quantizer type for neighbor vectors
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type;
|
||||||
|
|
||||||
|
// Maximum number of neighbors per node (R in the paper). Must be divisible by 8.
|
||||||
|
int n_neighbors;
|
||||||
|
|
||||||
|
// Search list size (L in the paper) — unified default for both insert and query.
|
||||||
|
int search_list_size;
|
||||||
|
|
||||||
|
// Per-path overrides (0 = fall back to search_list_size).
|
||||||
|
int search_list_size_search;
|
||||||
|
int search_list_size_insert;
|
||||||
|
|
||||||
|
// Alpha parameter for RobustPrune (distance scaling factor, typically 1.0-1.5)
|
||||||
|
f32 alpha;
|
||||||
|
|
||||||
|
// Buffer threshold for batched inserts. When > 0, inserts go into a flat
|
||||||
|
// buffer table and are flushed into the graph when the buffer reaches this
|
||||||
|
// size. 0 = disabled (legacy per-row insert behavior).
|
||||||
|
int buffer_threshold;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a single candidate during greedy beam search.
|
||||||
|
* Used in priority queues / sorted arrays during LM-Search.
|
||||||
|
*/
|
||||||
|
struct Vec0DiskannCandidate {
|
||||||
|
i64 rowid;
|
||||||
|
f32 distance;
|
||||||
|
int visited; // 1 if this candidate's neighbors have been explored
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the byte size of a quantized vector for the given quantizer type
|
||||||
|
* and number of dimensions.
|
||||||
|
*/
|
||||||
|
size_t diskann_quantized_vector_byte_size(
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) {
|
||||||
|
switch (quantizer_type) {
|
||||||
|
case VEC0_DISKANN_QUANTIZER_BINARY:
|
||||||
|
return dimensions / CHAR_BIT; // 1 bit per dimension
|
||||||
|
case VEC0_DISKANN_QUANTIZER_INT8:
|
||||||
|
return dimensions * sizeof(i8); // 1 byte per dimension
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
struct VectorColumnDefinition {
|
struct VectorColumnDefinition {
|
||||||
char *name;
|
char *name;
|
||||||
int name_length;
|
int name_length;
|
||||||
|
|
@ -2586,6 +2660,7 @@ struct VectorColumnDefinition {
|
||||||
struct Vec0RescoreConfig rescore;
|
struct Vec0RescoreConfig rescore;
|
||||||
#endif
|
#endif
|
||||||
struct Vec0IvfConfig ivf;
|
struct Vec0IvfConfig ivf;
|
||||||
|
struct Vec0DiskannConfig diskann;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Vec0PartitionColumnDefinition {
|
struct Vec0PartitionColumnDefinition {
|
||||||
|
|
@ -2743,6 +2818,126 @@ static int vec0_parse_ivf_options(struct Vec0Scanner *scanner,
|
||||||
struct Vec0IvfConfig *config);
|
struct Vec0IvfConfig *config);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse the options inside diskann(...) parentheses.
|
||||||
|
* Scanner should be positioned right before the '(' token.
|
||||||
|
*
|
||||||
|
* Recognized options:
|
||||||
|
* neighbor_quantizer = binary | int8 (required)
|
||||||
|
* n_neighbors = <integer> (optional, default 72)
|
||||||
|
* search_list_size = <integer> (optional, default 128)
|
||||||
|
*/
|
||||||
|
static int vec0_parse_diskann_options(struct Vec0Scanner *scanner,
|
||||||
|
struct Vec0DiskannConfig *config) {
|
||||||
|
int rc;
|
||||||
|
struct Vec0Token token;
|
||||||
|
int hasQuantizer = 0;
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
config->n_neighbors = VEC0_DISKANN_DEFAULT_N_NEIGHBORS;
|
||||||
|
config->search_list_size = VEC0_DISKANN_DEFAULT_SEARCH_LIST_SIZE;
|
||||||
|
config->search_list_size_search = 0;
|
||||||
|
config->search_list_size_insert = 0;
|
||||||
|
config->alpha = VEC0_DISKANN_DEFAULT_ALPHA;
|
||||||
|
config->buffer_threshold = 0;
|
||||||
|
int hasSearchListSize = 0;
|
||||||
|
int hasSearchListSizeSplit = 0;
|
||||||
|
|
||||||
|
// Expect '('
|
||||||
|
rc = vec0_scanner_next(scanner, &token);
|
||||||
|
if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_LPAREN) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (1) {
|
||||||
|
// key
|
||||||
|
rc = vec0_scanner_next(scanner, &token);
|
||||||
|
if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) {
|
||||||
|
break; // empty parens or trailing comma
|
||||||
|
}
|
||||||
|
if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_IDENTIFIER) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
char *optKey = token.start;
|
||||||
|
int optKeyLen = token.end - token.start;
|
||||||
|
|
||||||
|
// '='
|
||||||
|
rc = vec0_scanner_next(scanner, &token);
|
||||||
|
if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
// value (identifier or digit)
|
||||||
|
rc = vec0_scanner_next(scanner, &token);
|
||||||
|
if (rc != VEC0_TOKEN_RESULT_SOME) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
char *optVal = token.start;
|
||||||
|
int optValLen = token.end - token.start;
|
||||||
|
|
||||||
|
if (sqlite3_strnicmp(optKey, "neighbor_quantizer", optKeyLen) == 0) {
|
||||||
|
if (sqlite3_strnicmp(optVal, "binary", optValLen) == 0) {
|
||||||
|
config->quantizer_type = VEC0_DISKANN_QUANTIZER_BINARY;
|
||||||
|
} else if (sqlite3_strnicmp(optVal, "int8", optValLen) == 0) {
|
||||||
|
config->quantizer_type = VEC0_DISKANN_QUANTIZER_INT8;
|
||||||
|
} else {
|
||||||
|
return SQLITE_ERROR; // unknown quantizer
|
||||||
|
}
|
||||||
|
hasQuantizer = 1;
|
||||||
|
} else if (sqlite3_strnicmp(optKey, "n_neighbors", optKeyLen) == 0) {
|
||||||
|
config->n_neighbors = atoi(optVal);
|
||||||
|
if (config->n_neighbors <= 0 || (config->n_neighbors % 8) != 0 ||
|
||||||
|
config->n_neighbors > VEC0_DISKANN_MAX_N_NEIGHBORS) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
} else if (sqlite3_strnicmp(optKey, "search_list_size_search", optKeyLen) == 0 && optKeyLen == 23) {
|
||||||
|
config->search_list_size_search = atoi(optVal);
|
||||||
|
if (config->search_list_size_search <= 0) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
hasSearchListSizeSplit = 1;
|
||||||
|
} else if (sqlite3_strnicmp(optKey, "search_list_size_insert", optKeyLen) == 0 && optKeyLen == 23) {
|
||||||
|
config->search_list_size_insert = atoi(optVal);
|
||||||
|
if (config->search_list_size_insert <= 0) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
hasSearchListSizeSplit = 1;
|
||||||
|
} else if (sqlite3_strnicmp(optKey, "search_list_size", optKeyLen) == 0) {
|
||||||
|
config->search_list_size = atoi(optVal);
|
||||||
|
if (config->search_list_size <= 0) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
hasSearchListSize = 1;
|
||||||
|
} else if (sqlite3_strnicmp(optKey, "buffer_threshold", optKeyLen) == 0) {
|
||||||
|
config->buffer_threshold = atoi(optVal);
|
||||||
|
if (config->buffer_threshold < 0) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return SQLITE_ERROR; // unknown option
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect ',' or ')'
|
||||||
|
rc = vec0_scanner_next(scanner, &token);
|
||||||
|
if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_COMMA) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasQuantizer) {
|
||||||
|
return SQLITE_ERROR; // neighbor_quantizer is required
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasSearchListSize && hasSearchListSizeSplit) {
|
||||||
|
return SQLITE_ERROR; // cannot mix search_list_size with search_list_size_search/insert
|
||||||
|
}
|
||||||
|
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int vec0_parse_vector_column(const char *source, int source_length,
|
int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
struct VectorColumnDefinition *outColumn) {
|
struct VectorColumnDefinition *outColumn) {
|
||||||
// parses a vector column definition like so:
|
// parses a vector column definition like so:
|
||||||
|
|
@ -2763,8 +2958,9 @@ int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
#endif
|
#endif
|
||||||
struct Vec0IvfConfig ivfConfig;
|
struct Vec0IvfConfig ivfConfig;
|
||||||
memset(&ivfConfig, 0, sizeof(ivfConfig));
|
memset(&ivfConfig, 0, sizeof(ivfConfig));
|
||||||
|
struct Vec0DiskannConfig diskannConfig;
|
||||||
|
memset(&diskannConfig, 0, sizeof(diskannConfig));
|
||||||
int dimensions;
|
int dimensions;
|
||||||
|
|
||||||
vec0_scanner_init(&scanner, source, source_length);
|
vec0_scanner_init(&scanner, source, source_length);
|
||||||
|
|
||||||
// starts with an identifier
|
// starts with an identifier
|
||||||
|
|
@ -2931,6 +3127,16 @@ int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
return SQLITE_ERROR; // IVF not compiled in
|
return SQLITE_ERROR; // IVF not compiled in
|
||||||
|
#endif
|
||||||
|
} else if (sqlite3_strnicmp(token.start, "diskann", indexNameLen) == 0) {
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
indexType = VEC0_INDEX_TYPE_DISKANN;
|
||||||
|
rc = vec0_parse_diskann_options(&scanner, &diskannConfig);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
return SQLITE_ERROR;
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
// unknown index type
|
// unknown index type
|
||||||
|
|
@ -2956,6 +3162,7 @@ int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
outColumn->rescore = rescoreConfig;
|
outColumn->rescore = rescoreConfig;
|
||||||
#endif
|
#endif
|
||||||
outColumn->ivf = ivfConfig;
|
outColumn->ivf = ivfConfig;
|
||||||
|
outColumn->diskann = diskannConfig;
|
||||||
return SQLITE_OK;
|
return SQLITE_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3154,6 +3361,7 @@ static sqlite3_module vec_eachModule = {
|
||||||
#pragma endregion
|
#pragma endregion
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#pragma region vec0 virtual table
|
#pragma region vec0 virtual table
|
||||||
|
|
||||||
#define VEC0_COLUMN_ID 0
|
#define VEC0_COLUMN_ID 0
|
||||||
|
|
@ -3214,6 +3422,9 @@ static sqlite3_module vec_eachModule = {
|
||||||
#define VEC0_SHADOW_AUXILIARY_NAME "\"%w\".\"%w_auxiliary\""
|
#define VEC0_SHADOW_AUXILIARY_NAME "\"%w\".\"%w_auxiliary\""
|
||||||
|
|
||||||
#define VEC0_SHADOW_METADATA_N_NAME "\"%w\".\"%w_metadatachunks%02d\""
|
#define VEC0_SHADOW_METADATA_N_NAME "\"%w\".\"%w_metadatachunks%02d\""
|
||||||
|
#define VEC0_SHADOW_VECTORS_N_NAME "\"%w\".\"%w_vectors%02d\""
|
||||||
|
#define VEC0_SHADOW_DISKANN_NODES_N_NAME "\"%w\".\"%w_diskann_nodes%02d\""
|
||||||
|
#define VEC0_SHADOW_DISKANN_BUFFER_N_NAME "\"%w\".\"%w_diskann_buffer%02d\""
|
||||||
#define VEC0_SHADOW_METADATA_TEXT_DATA_NAME "\"%w\".\"%w_metadatatext%02d\""
|
#define VEC0_SHADOW_METADATA_TEXT_DATA_NAME "\"%w\".\"%w_metadatatext%02d\""
|
||||||
|
|
||||||
#define VEC_INTERAL_ERROR "Internal sqlite-vec error: "
|
#define VEC_INTERAL_ERROR "Internal sqlite-vec error: "
|
||||||
|
|
@ -3388,6 +3599,24 @@ struct vec0_vtab {
|
||||||
* Must be cleaned up with sqlite3_finalize().
|
* Must be cleaned up with sqlite3_finalize().
|
||||||
*/
|
*/
|
||||||
sqlite3_stmt *stmtRowidsGetChunkPosition;
|
sqlite3_stmt *stmtRowidsGetChunkPosition;
|
||||||
|
|
||||||
|
// === DiskANN additions ===
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
// Shadow table names for DiskANN, per vector column
|
||||||
|
// e.g., "{table}_vectors{00..15}" (built from the table name only; not schema-qualified)
|
||||||
|
char *shadowVectorsNames[VEC0_MAX_VECTOR_COLUMNS];
|
||||||
|
|
||||||
|
// e.g., "{table}_diskann_nodes{00..15}" (built from the table name only; not schema-qualified)
|
||||||
|
char *shadowDiskannNodesNames[VEC0_MAX_VECTOR_COLUMNS];
|
||||||
|
|
||||||
|
// Prepared statements for DiskANN operations (per vector column)
|
||||||
|
// These will be lazily prepared on first use.
|
||||||
|
sqlite3_stmt *stmtDiskannNodeRead[VEC0_MAX_VECTOR_COLUMNS];
|
||||||
|
sqlite3_stmt *stmtDiskannNodeWrite[VEC0_MAX_VECTOR_COLUMNS];
|
||||||
|
sqlite3_stmt *stmtDiskannNodeInsert[VEC0_MAX_VECTOR_COLUMNS];
|
||||||
|
sqlite3_stmt *stmtVectorsRead[VEC0_MAX_VECTOR_COLUMNS];
|
||||||
|
sqlite3_stmt *stmtVectorsInsert[VEC0_MAX_VECTOR_COLUMNS];
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
|
|
@ -3427,6 +3656,13 @@ void vec0_free_resources(vec0_vtab *p) {
|
||||||
sqlite3_finalize(p->stmtIvfRowidMapLookup[i]); p->stmtIvfRowidMapLookup[i] = NULL;
|
sqlite3_finalize(p->stmtIvfRowidMapLookup[i]); p->stmtIvfRowidMapLookup[i] = NULL;
|
||||||
sqlite3_finalize(p->stmtIvfRowidMapDelete[i]); p->stmtIvfRowidMapDelete[i] = NULL;
|
sqlite3_finalize(p->stmtIvfRowidMapDelete[i]); p->stmtIvfRowidMapDelete[i] = NULL;
|
||||||
sqlite3_finalize(p->stmtIvfCentroidsAll[i]); p->stmtIvfCentroidsAll[i] = NULL;
|
sqlite3_finalize(p->stmtIvfCentroidsAll[i]); p->stmtIvfCentroidsAll[i] = NULL;
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
sqlite3_finalize(p->stmtDiskannNodeRead[i]); p->stmtDiskannNodeRead[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtDiskannNodeWrite[i]); p->stmtDiskannNodeWrite[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtDiskannNodeInsert[i]); p->stmtDiskannNodeInsert[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtVectorsRead[i]); p->stmtVectorsRead[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtVectorsInsert[i]); p->stmtVectorsInsert[i] = NULL;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
@ -3464,6 +3700,13 @@ void vec0_free(vec0_vtab *p) {
|
||||||
p->shadowRescoreVectorsNames[i] = NULL;
|
p->shadowRescoreVectorsNames[i] = NULL;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
sqlite3_free(p->shadowVectorsNames[i]);
|
||||||
|
p->shadowVectorsNames[i] = NULL;
|
||||||
|
sqlite3_free(p->shadowDiskannNodesNames[i]);
|
||||||
|
p->shadowDiskannNodesNames[i] = NULL;
|
||||||
|
#endif
|
||||||
|
|
||||||
sqlite3_free(p->vector_columns[i].name);
|
sqlite3_free(p->vector_columns[i].name);
|
||||||
p->vector_columns[i].name = NULL;
|
p->vector_columns[i].name = NULL;
|
||||||
}
|
}
|
||||||
|
|
@ -3484,6 +3727,12 @@ void vec0_free(vec0_vtab *p) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
#include "sqlite-vec-diskann.c"
|
||||||
|
#else
|
||||||
|
static int vec0_all_columns_diskann(vec0_vtab *p) { (void)p; return 0; }
|
||||||
|
#endif
|
||||||
|
|
||||||
int vec0_num_defined_user_columns(vec0_vtab *p) {
|
int vec0_num_defined_user_columns(vec0_vtab *p) {
|
||||||
return p->numVectorColumns + p->numPartitionColumns + p->numAuxiliaryColumns + p->numMetadataColumns;
|
return p->numVectorColumns + p->numPartitionColumns + p->numAuxiliaryColumns + p->numMetadataColumns;
|
||||||
}
|
}
|
||||||
|
|
@ -3753,6 +4002,25 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
|
||||||
void **outVector, int *outVectorSize) {
|
void **outVector, int *outVectorSize) {
|
||||||
vec0_vtab *p = pVtab;
|
vec0_vtab *p = pVtab;
|
||||||
int rc, brc;
|
int rc, brc;
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
// DiskANN fast path: read from _vectors table
|
||||||
|
if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_DISKANN) {
|
||||||
|
void *vec = NULL;
|
||||||
|
int vecSize;
|
||||||
|
rc = diskann_vector_read(p, vector_column_idx, rowid, &vec, &vecSize);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
vtab_set_error(&pVtab->base,
|
||||||
|
"Could not fetch vector data for %lld from DiskANN vectors table",
|
||||||
|
rowid);
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
*outVector = vec;
|
||||||
|
if (outVectorSize) *outVectorSize = vecSize;
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
i64 chunk_id;
|
i64 chunk_id;
|
||||||
i64 chunk_offset;
|
i64 chunk_offset;
|
||||||
|
|
||||||
|
|
@ -4653,6 +4921,26 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
(i64)vecColumn.dimensions, SQLITE_VEC_VEC0_MAX_DIMENSIONS);
|
(i64)vecColumn.dimensions, SQLITE_VEC_VEC0_MAX_DIMENSIONS);
|
||||||
goto error;
|
goto error;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DiskANN validation
|
||||||
|
if (vecColumn.index_type == VEC0_INDEX_TYPE_DISKANN) {
|
||||||
|
if (vecColumn.element_type == SQLITE_VEC_ELEMENT_TYPE_BIT) {
|
||||||
|
sqlite3_free(vecColumn.name);
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"DiskANN index is not supported on bit vector columns");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
if (vecColumn.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY &&
|
||||||
|
(vecColumn.dimensions % CHAR_BIT) != 0) {
|
||||||
|
sqlite3_free(vecColumn.name);
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"DiskANN with binary quantizer requires dimensions divisible by 8");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pNew->user_column_kinds[user_column_idx] = SQLITE_VEC0_USER_COLUMN_KIND_VECTOR;
|
pNew->user_column_kinds[user_column_idx] = SQLITE_VEC0_USER_COLUMN_KIND_VECTOR;
|
||||||
pNew->user_column_idxs[user_column_idx] = numVectorColumns;
|
pNew->user_column_idxs[user_column_idx] = numVectorColumns;
|
||||||
memcpy(&pNew->vector_columns[numVectorColumns], &vecColumn, sizeof(vecColumn));
|
memcpy(&pNew->vector_columns[numVectorColumns], &vecColumn, sizeof(vecColumn));
|
||||||
|
|
@ -4881,6 +5169,31 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DiskANN columns cannot coexist with aux/metadata/partition columns
|
||||||
|
for (int i = 0; i < numVectorColumns; i++) {
|
||||||
|
if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
|
||||||
|
if (numAuxiliaryColumns > 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"Auxiliary columns are not supported with DiskANN-indexed vector columns");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
if (numMetadataColumns > 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"Metadata columns are not supported with DiskANN-indexed vector columns");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
if (numPartitionColumns > 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"Partition key columns are not supported with DiskANN-indexed vector columns");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sqlite3_str *createStr = sqlite3_str_new(NULL);
|
sqlite3_str *createStr = sqlite3_str_new(NULL);
|
||||||
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
|
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
|
||||||
if (pkColumnName) {
|
if (pkColumnName) {
|
||||||
|
|
@ -4984,6 +5297,20 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
goto error;
|
goto error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
|
||||||
|
pNew->shadowVectorsNames[i] =
|
||||||
|
sqlite3_mprintf("%s_vectors%02d", tableName, i);
|
||||||
|
if (!pNew->shadowVectorsNames[i]) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
pNew->shadowDiskannNodesNames[i] =
|
||||||
|
sqlite3_mprintf("%s_diskann_nodes%02d", tableName, i);
|
||||||
|
if (!pNew->shadowDiskannNodesNames[i]) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
||||||
|
|
@ -5060,7 +5387,32 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
sqlite3_finalize(stmt);
|
sqlite3_finalize(stmt);
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
// Seed medoid entries for DiskANN-indexed columns
|
||||||
|
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
||||||
|
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
char *key = sqlite3_mprintf("diskann_medoid_%02d", i);
|
||||||
|
char *zInsert = sqlite3_mprintf(
|
||||||
|
"INSERT INTO " VEC0_SHADOW_INFO_NAME "(key, value) VALUES (?1, ?2)",
|
||||||
|
pNew->schemaName, pNew->tableName);
|
||||||
|
rc = sqlite3_prepare_v2(db, zInsert, -1, &stmt, NULL);
|
||||||
|
sqlite3_free(zInsert);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_free(key);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
sqlite3_bind_text(stmt, 1, key, -1, sqlite3_free);
|
||||||
|
sqlite3_bind_null(stmt, 2); // NULL means empty graph
|
||||||
|
if (sqlite3_step(stmt) != SQLITE_DONE) {
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// create the _chunks shadow table
|
// create the _chunks shadow table
|
||||||
char *zCreateShadowChunks = NULL;
|
char *zCreateShadowChunks = NULL;
|
||||||
|
|
@ -5118,7 +5470,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
|
|
||||||
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
// Rescore and IVF columns don't use _vector_chunks
|
// Non-FLAT columns don't use _vector_chunks
|
||||||
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -5159,6 +5511,84 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
// Create DiskANN shadow tables for indexed vector columns
|
||||||
|
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
||||||
|
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create _vectors{NN} table
|
||||||
|
{
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"CREATE TABLE " VEC0_SHADOW_VECTORS_N_NAME
|
||||||
|
" (rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL);",
|
||||||
|
pNew->schemaName, pNew->tableName, i);
|
||||||
|
if (!zSql) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
"Could not create '_vectors%02d' shadow table: %s", i,
|
||||||
|
sqlite3_errmsg(db));
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create _diskann_nodes{NN} table
|
||||||
|
{
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"CREATE TABLE " VEC0_SHADOW_DISKANN_NODES_N_NAME " ("
|
||||||
|
"rowid INTEGER PRIMARY KEY, "
|
||||||
|
"neighbors_validity BLOB NOT NULL, "
|
||||||
|
"neighbor_ids BLOB NOT NULL, "
|
||||||
|
"neighbor_quantized_vectors BLOB NOT NULL"
|
||||||
|
");",
|
||||||
|
pNew->schemaName, pNew->tableName, i);
|
||||||
|
if (!zSql) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
"Could not create '_diskann_nodes%02d' shadow table: %s", i,
|
||||||
|
sqlite3_errmsg(db));
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create _diskann_buffer{NN} table (for batched inserts)
|
||||||
|
{
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"CREATE TABLE " VEC0_SHADOW_DISKANN_BUFFER_N_NAME " ("
|
||||||
|
"rowid INTEGER PRIMARY KEY, "
|
||||||
|
"vector BLOB NOT NULL"
|
||||||
|
");",
|
||||||
|
pNew->schemaName, pNew->tableName, i);
|
||||||
|
if (!zSql) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
"Could not create '_diskann_buffer%02d' shadow table: %s", i,
|
||||||
|
sqlite3_errmsg(db));
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY"
|
// See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY"
|
||||||
// without INTEGER type issue applies here.
|
// without INTEGER type issue applies here.
|
||||||
for (int i = 0; i < pNew->numMetadataColumns; i++) {
|
for (int i = 0; i < pNew->numMetadataColumns; i++) {
|
||||||
|
|
@ -5293,6 +5723,45 @@ static int vec0Destroy(sqlite3_vtab *pVtab) {
|
||||||
sqlite3_finalize(stmt);
|
sqlite3_finalize(stmt);
|
||||||
|
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
|
||||||
|
// Drop DiskANN shadow tables
|
||||||
|
zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_VECTORS_N_NAME,
|
||||||
|
p->schemaName, p->tableName, i);
|
||||||
|
if (zSql) {
|
||||||
|
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
|
||||||
|
sqlite3_free((void *)zSql);
|
||||||
|
if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
|
||||||
|
rc = SQLITE_ERROR;
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_DISKANN_NODES_N_NAME,
|
||||||
|
p->schemaName, p->tableName, i);
|
||||||
|
if (zSql) {
|
||||||
|
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
|
||||||
|
sqlite3_free((void *)zSql);
|
||||||
|
if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
|
||||||
|
rc = SQLITE_ERROR;
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_DISKANN_BUFFER_N_NAME,
|
||||||
|
p->schemaName, p->tableName, i);
|
||||||
|
if (zSql) {
|
||||||
|
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
|
||||||
|
sqlite3_free((void *)zSql);
|
||||||
|
if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
|
||||||
|
rc = SQLITE_ERROR;
|
||||||
|
goto done;
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -7088,6 +7557,171 @@ cleanup:
|
||||||
#include "sqlite-vec-rescore.c"
|
#include "sqlite-vec-rescore.c"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
/**
|
||||||
|
* Handle a KNN query using the DiskANN graph search.
|
||||||
|
*/
|
||||||
|
static int vec0Filter_knn_diskann(
|
||||||
|
vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
||||||
|
const char *idxStr, int argc, sqlite3_value **argv) {
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
int vectorColumnIdx = idxNum;
|
||||||
|
struct VectorColumnDefinition *vector_column = &p->vector_columns[vectorColumnIdx];
|
||||||
|
struct vec0_query_knn_data *knn_data;
|
||||||
|
|
||||||
|
knn_data = sqlite3_malloc(sizeof(*knn_data));
|
||||||
|
if (!knn_data) return SQLITE_NOMEM;
|
||||||
|
memset(knn_data, 0, sizeof(*knn_data));
|
||||||
|
|
||||||
|
// Parse query_idx and k_idx from idxStr
|
||||||
|
int query_idx = -1;
|
||||||
|
int k_idx = -1;
|
||||||
|
for (int i = 0; i < argc; i++) {
|
||||||
|
if (idxStr[1 + (i * 4)] == VEC0_IDXSTR_KIND_KNN_MATCH) {
|
||||||
|
query_idx = i;
|
||||||
|
}
|
||||||
|
if (idxStr[1 + (i * 4)] == VEC0_IDXSTR_KIND_KNN_K) {
|
||||||
|
k_idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert(query_idx >= 0);
|
||||||
|
assert(k_idx >= 0);
|
||||||
|
|
||||||
|
// Extract query vector
|
||||||
|
void *queryVector;
|
||||||
|
size_t dimensions;
|
||||||
|
enum VectorElementType elementType;
|
||||||
|
vector_cleanup queryVectorCleanup = vector_cleanup_noop;
|
||||||
|
char *pzError;
|
||||||
|
|
||||||
|
rc = vector_from_value(argv[query_idx], &queryVector, &dimensions,
|
||||||
|
&elementType, &queryVectorCleanup, &pzError);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
vtab_set_error(&p->base, "Invalid query vector: %z", pzError);
|
||||||
|
sqlite3_free(knn_data);
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (elementType != vector_column->element_type ||
|
||||||
|
dimensions != vector_column->dimensions) {
|
||||||
|
vtab_set_error(&p->base, "Query vector type/dimension mismatch");
|
||||||
|
queryVectorCleanup(queryVector);
|
||||||
|
sqlite3_free(knn_data);
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
i64 k = sqlite3_value_int64(argv[k_idx]);
|
||||||
|
if (k <= 0) {
|
||||||
|
knn_data->k = 0;
|
||||||
|
knn_data->k_used = 0;
|
||||||
|
pCur->knn_data = knn_data;
|
||||||
|
pCur->query_plan = VEC0_QUERY_PLAN_KNN;
|
||||||
|
queryVectorCleanup(queryVector);
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run DiskANN search
|
||||||
|
i64 *resultRowids = sqlite3_malloc(k * sizeof(i64));
|
||||||
|
f32 *resultDistances = sqlite3_malloc(k * sizeof(f32));
|
||||||
|
if (!resultRowids || !resultDistances) {
|
||||||
|
sqlite3_free(resultRowids);
|
||||||
|
sqlite3_free(resultDistances);
|
||||||
|
queryVectorCleanup(queryVector);
|
||||||
|
sqlite3_free(knn_data);
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
}
|
||||||
|
|
||||||
|
int resultCount;
|
||||||
|
rc = diskann_search(p, vectorColumnIdx, queryVector, dimensions,
|
||||||
|
elementType, (int)k, 0,
|
||||||
|
resultRowids, resultDistances, &resultCount);
|
||||||
|
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
queryVectorCleanup(queryVector);
|
||||||
|
sqlite3_free(resultRowids);
|
||||||
|
sqlite3_free(resultDistances);
|
||||||
|
sqlite3_free(knn_data);
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan _diskann_buffer for any buffered (unflushed) vectors and merge
|
||||||
|
// with graph results. This ensures no recall loss for buffered vectors.
|
||||||
|
{
|
||||||
|
sqlite3_stmt *bufStmt = NULL;
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"SELECT rowid, vector FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME,
|
||||||
|
p->schemaName, p->tableName, vectorColumnIdx);
|
||||||
|
if (!zSql) {
|
||||||
|
queryVectorCleanup(queryVector);
|
||||||
|
sqlite3_free(resultRowids);
|
||||||
|
sqlite3_free(resultDistances);
|
||||||
|
sqlite3_free(knn_data);
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
}
|
||||||
|
int bufRc = sqlite3_prepare_v2(p->db, zSql, -1, &bufStmt, NULL);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if (bufRc == SQLITE_OK) {
|
||||||
|
while (sqlite3_step(bufStmt) == SQLITE_ROW) {
|
||||||
|
i64 bufRowid = sqlite3_column_int64(bufStmt, 0);
|
||||||
|
const void *bufVec = sqlite3_column_blob(bufStmt, 1);
|
||||||
|
f32 dist = vec0_distance_full(
|
||||||
|
queryVector, bufVec, dimensions, elementType,
|
||||||
|
vector_column->distance_metric);
|
||||||
|
|
||||||
|
// Check if this buffer vector should replace the worst graph result
|
||||||
|
if (resultCount < (int)k) {
|
||||||
|
// Still have room, just add it
|
||||||
|
resultRowids[resultCount] = bufRowid;
|
||||||
|
resultDistances[resultCount] = dist;
|
||||||
|
resultCount++;
|
||||||
|
} else {
|
||||||
|
// Find worst (largest distance) in results
|
||||||
|
int worstIdx = 0;
|
||||||
|
for (int wi = 1; wi < resultCount; wi++) {
|
||||||
|
if (resultDistances[wi] > resultDistances[worstIdx]) {
|
||||||
|
worstIdx = wi;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (dist < resultDistances[worstIdx]) {
|
||||||
|
resultRowids[worstIdx] = bufRowid;
|
||||||
|
resultDistances[worstIdx] = dist;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_finalize(bufStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
queryVectorCleanup(queryVector);
|
||||||
|
|
||||||
|
// Sort results by distance (ascending)
|
||||||
|
for (int si = 0; si < resultCount - 1; si++) {
|
||||||
|
for (int sj = si + 1; sj < resultCount; sj++) {
|
||||||
|
if (resultDistances[sj] < resultDistances[si]) {
|
||||||
|
f32 tmpD = resultDistances[si];
|
||||||
|
resultDistances[si] = resultDistances[sj];
|
||||||
|
resultDistances[sj] = tmpD;
|
||||||
|
i64 tmpR = resultRowids[si];
|
||||||
|
resultRowids[si] = resultRowids[sj];
|
||||||
|
resultRowids[sj] = tmpR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
knn_data->k = resultCount;
|
||||||
|
knn_data->k_used = resultCount;
|
||||||
|
knn_data->rowids = resultRowids;
|
||||||
|
knn_data->distances = resultDistances;
|
||||||
|
knn_data->current_idx = 0;
|
||||||
|
|
||||||
|
pCur->knn_data = knn_data;
|
||||||
|
pCur->query_plan = VEC0_QUERY_PLAN_KNN;
|
||||||
|
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
#endif /* SQLITE_VEC_ENABLE_DISKANN */
|
||||||
|
|
||||||
int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
||||||
const char *idxStr, int argc, sqlite3_value **argv) {
|
const char *idxStr, int argc, sqlite3_value **argv) {
|
||||||
assert(argc == (strlen(idxStr)-1) / 4);
|
assert(argc == (strlen(idxStr)-1) / 4);
|
||||||
|
|
@ -7098,6 +7732,13 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
||||||
struct VectorColumnDefinition *vector_column =
|
struct VectorColumnDefinition *vector_column =
|
||||||
&p->vector_columns[vectorColumnIdx];
|
&p->vector_columns[vectorColumnIdx];
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
// DiskANN dispatch
|
||||||
|
if (vector_column->index_type == VEC0_INDEX_TYPE_DISKANN) {
|
||||||
|
return vec0Filter_knn_diskann(pCur, p, idxNum, idxStr, argc, argv);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
struct Array *arrayRowidsIn = NULL;
|
struct Array *arrayRowidsIn = NULL;
|
||||||
sqlite3_stmt *stmtChunks = NULL;
|
sqlite3_stmt *stmtChunks = NULL;
|
||||||
void *queryVector;
|
void *queryVector;
|
||||||
|
|
@ -8567,24 +9208,37 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
||||||
goto cleanup;
|
goto cleanup;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step #2: Find the next "available" position in the _chunks table for this
|
if (!vec0_all_columns_diskann(p)) {
|
||||||
// row.
|
// Step #2: Find the next "available" position in the _chunks table for this
|
||||||
rc = vec0Update_InsertNextAvailableStep(p, partitionKeyValues,
|
// row.
|
||||||
&chunk_rowid, &chunk_offset,
|
rc = vec0Update_InsertNextAvailableStep(p, partitionKeyValues,
|
||||||
&blobChunksValidity,
|
&chunk_rowid, &chunk_offset,
|
||||||
&bufferChunksValidity);
|
&blobChunksValidity,
|
||||||
if (rc != SQLITE_OK) {
|
&bufferChunksValidity);
|
||||||
goto cleanup;
|
if (rc != SQLITE_OK) {
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step #3: With the next available chunk position, write out all the vectors
|
||||||
|
// to their specified location.
|
||||||
|
rc = vec0Update_InsertWriteFinalStep(p, chunk_rowid, chunk_offset, rowid,
|
||||||
|
vectorDatas, blobChunksValidity,
|
||||||
|
bufferChunksValidity);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step #3: With the next available chunk position, write out all the vectors
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
// to their specified location.
|
// Step #4: Insert into DiskANN graph for indexed vector columns
|
||||||
rc = vec0Update_InsertWriteFinalStep(p, chunk_rowid, chunk_offset, rowid,
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
vectorDatas, blobChunksValidity,
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) continue;
|
||||||
bufferChunksValidity);
|
rc = diskann_insert(p, i, rowid, vectorDatas[i]);
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
goto cleanup;
|
goto cleanup;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
rc = rescore_on_insert(p, chunk_rowid, chunk_offset, rowid, vectorDatas);
|
rc = rescore_on_insert(p, chunk_rowid, chunk_offset, rowid, vectorDatas);
|
||||||
|
|
@ -9126,29 +9780,43 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) {
|
||||||
// 4. Zero out vector data in all vector column chunks
|
// 4. Zero out vector data in all vector column chunks
|
||||||
// 5. Delete value in _rowids table
|
// 5. Delete value in _rowids table
|
||||||
|
|
||||||
// 1. get chunk_id and chunk_offset from _rowids
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
rc = vec0_get_chunk_position(p, rowid, NULL, &chunk_id, &chunk_offset);
|
// DiskANN graph deletion for indexed columns
|
||||||
if (rc != SQLITE_OK) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
return rc;
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) continue;
|
||||||
|
rc = diskann_delete(p, i, rowid);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (!vec0_all_columns_diskann(p)) {
|
||||||
|
// 1. get chunk_id and chunk_offset from _rowids
|
||||||
|
rc = vec0_get_chunk_position(p, rowid, NULL, &chunk_id, &chunk_offset);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. clear validity bit
|
||||||
|
rc = vec0Update_Delete_ClearValidity(p, chunk_id, chunk_offset);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. zero out rowid in chunks.rowids
|
||||||
|
rc = vec0Update_Delete_ClearRowid(p, chunk_id, chunk_offset);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. zero out any data in vector chunks tables
|
||||||
|
rc = vec0Update_Delete_ClearVectors(p, chunk_id, chunk_offset);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. clear validity bit
|
|
||||||
rc = vec0Update_Delete_ClearValidity(p, chunk_id, chunk_offset);
|
|
||||||
if (rc != SQLITE_OK) {
|
|
||||||
return rc;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. zero out rowid in chunks.rowids
|
|
||||||
rc = vec0Update_Delete_ClearRowid(p, chunk_id, chunk_offset);
|
|
||||||
if (rc != SQLITE_OK) {
|
|
||||||
return rc;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. zero out any data in vector chunks tables
|
|
||||||
rc = vec0Update_Delete_ClearVectors(p, chunk_id, chunk_offset);
|
|
||||||
if (rc != SQLITE_OK) {
|
|
||||||
return rc;
|
|
||||||
}
|
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
// 4b. zero out quantized data in rescore chunk tables, delete from rescore vectors
|
// 4b. zero out quantized data in rescore chunk tables, delete from rescore vectors
|
||||||
|
|
@ -9172,20 +9840,22 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. delete metadata
|
// 7. delete metadata and reclaim chunk (only when using chunk-based storage)
|
||||||
for(int i = 0; i < p->numMetadataColumns; i++) {
|
if (!vec0_all_columns_diskann(p)) {
|
||||||
rc = vec0Update_Delete_ClearMetadata(p, i, rowid, chunk_id, chunk_offset);
|
for(int i = 0; i < p->numMetadataColumns; i++) {
|
||||||
if (rc != SQLITE_OK) {
|
rc = vec0Update_Delete_ClearMetadata(p, i, rowid, chunk_id, chunk_offset);
|
||||||
return rc;
|
if (rc != SQLITE_OK) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// 8. reclaim chunk if fully empty
|
// 8. reclaim chunk if fully empty
|
||||||
{
|
{
|
||||||
int chunkDeleted;
|
int chunkDeleted;
|
||||||
rc = vec0Update_Delete_DeleteChunkIfEmpty(p, chunk_id, &chunkDeleted);
|
rc = vec0Update_Delete_DeleteChunkIfEmpty(p, chunk_id, &chunkDeleted);
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
return rc;
|
return rc;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -9481,8 +10151,12 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
||||||
const char *cmd = (const char *)sqlite3_value_text(idVal);
|
const char *cmd = (const char *)sqlite3_value_text(idVal);
|
||||||
vec0_vtab *p = (vec0_vtab *)pVTab;
|
vec0_vtab *p = (vec0_vtab *)pVTab;
|
||||||
int cmdRc = ivf_handle_command(p, cmd, argc, argv);
|
int cmdRc = ivf_handle_command(p, cmd, argc, argv);
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
if (cmdRc == SQLITE_EMPTY)
|
||||||
|
cmdRc = diskann_handle_command(p, cmd);
|
||||||
|
#endif
|
||||||
if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error)
|
if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error)
|
||||||
// SQLITE_EMPTY means not an IVF command — fall through to normal insert
|
// SQLITE_EMPTY means not a recognized command — fall through to normal insert
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
return vec0Update_Insert(pVTab, argc, argv, pRowid);
|
return vec0Update_Insert(pVTab, argc, argv, pRowid);
|
||||||
|
|
@ -9638,9 +10312,16 @@ static sqlite3_module vec0Module = {
|
||||||
#define SQLITE_VEC_DEBUG_BUILD_IVF ""
|
#define SQLITE_VEC_DEBUG_BUILD_IVF ""
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
#define SQLITE_VEC_DEBUG_BUILD_DISKANN "diskann"
|
||||||
|
#else
|
||||||
|
#define SQLITE_VEC_DEBUG_BUILD_DISKANN ""
|
||||||
|
#endif
|
||||||
|
|
||||||
#define SQLITE_VEC_DEBUG_BUILD \
|
#define SQLITE_VEC_DEBUG_BUILD \
|
||||||
SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \
|
SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \
|
||||||
SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF
|
SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF " " \
|
||||||
|
SQLITE_VEC_DEBUG_BUILD_DISKANN
|
||||||
|
|
||||||
#define SQLITE_VEC_DEBUG_STRING \
|
#define SQLITE_VEC_DEBUG_STRING \
|
||||||
"Version: " SQLITE_VEC_VERSION "\n" \
|
"Version: " SQLITE_VEC_VERSION "\n" \
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ FUZZ_LDFLAGS ?= $(shell \
|
||||||
echo "-Wl,-ld_classic"; \
|
echo "-Wl,-ld_classic"; \
|
||||||
fi)
|
fi)
|
||||||
|
|
||||||
FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -g $(FUZZ_LDFLAGS)
|
FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -DSQLITE_VEC_ENABLE_DISKANN=1 -g $(FUZZ_LDFLAGS)
|
||||||
FUZZ_SRCS = ../../vendor/sqlite3.c ../../sqlite-vec.c
|
FUZZ_SRCS = ../../vendor/sqlite3.c ../../sqlite-vec.c
|
||||||
|
|
||||||
TARGET_DIR = ./targets
|
TARGET_DIR = ./targets
|
||||||
|
|
@ -115,6 +115,34 @@ $(TARGET_DIR)/ivf_cell_overflow: ivf-cell-overflow.c $(FUZZ_SRCS) | $(TARGET_DIR
|
||||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
$(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
$(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(TARGET_DIR)/diskann_operations: diskann-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_create: diskann-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_graph_corrupt: diskann-graph-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_deep_search: diskann-deep-search.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_blob_truncate: diskann-blob-truncate.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_delete_stress: diskann-delete-stress.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_buffer_flush: diskann-buffer-flush.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_int8_quant: diskann-int8-quant.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_prune_direct: diskann-prune-direct.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/diskann_command_inject: diskann-command-inject.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
FUZZ_TARGETS = vec0_create exec json numpy \
|
FUZZ_TARGETS = vec0_create exec json numpy \
|
||||||
|
|
@ -127,6 +155,11 @@ FUZZ_TARGETS = vec0_create exec json numpy \
|
||||||
ivf_create ivf_operations \
|
ivf_create ivf_operations \
|
||||||
ivf_quantize ivf_kmeans ivf_shadow_corrupt \
|
ivf_quantize ivf_kmeans ivf_shadow_corrupt \
|
||||||
ivf_knn_deep ivf_cell_overflow ivf_rescore
|
ivf_knn_deep ivf_cell_overflow ivf_rescore
|
||||||
|
diskann_operations diskann_create diskann_graph_corrupt \
|
||||||
|
diskann_deep_search diskann_blob_truncate \
|
||||||
|
diskann_delete_stress diskann_buffer_flush \
|
||||||
|
diskann_int8_quant diskann_prune_direct \
|
||||||
|
diskann_command_inject
|
||||||
|
|
||||||
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
|
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
|
||||||
|
|
||||||
|
|
|
||||||
250
tests/fuzz/diskann-blob-truncate.c
Normal file
250
tests/fuzz/diskann-blob-truncate.c
Normal file
|
|
@ -0,0 +1,250 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN shadow table blob size mismatches.
|
||||||
|
*
|
||||||
|
* The critical vulnerability: diskann_node_read() copies whatever blob size
|
||||||
|
* SQLite returns, but diskann_search/insert/delete index into those blobs
|
||||||
|
* using cfg->n_neighbors * sizeof(i64) etc. If the blob is truncated,
|
||||||
|
* extended, or has wrong size, this causes out-of-bounds reads/writes.
|
||||||
|
*
|
||||||
|
* This fuzzer:
|
||||||
|
* 1. Creates a valid DiskANN graph with several nodes
|
||||||
|
* 2. Uses fuzz data to directly write malformed blobs to shadow tables:
|
||||||
|
* - Truncated neighbor_ids (fewer bytes than n_neighbors * 8)
|
||||||
|
* - Truncated validity bitmaps
|
||||||
|
* - Oversized blobs with garbage trailing data
|
||||||
|
* - Zero-length blobs
|
||||||
|
* - Blobs with valid headers but corrupted neighbor rowids
|
||||||
|
* 3. Runs INSERT, DELETE, and KNN operations that traverse the corrupted graph
|
||||||
|
*
|
||||||
|
* Key code paths targeted:
|
||||||
|
* - diskann_node_read with mismatched blob sizes
|
||||||
|
* - diskann_validity_get / diskann_neighbor_id_get on truncated blobs
|
||||||
|
* - diskann_add_reverse_edge reading corrupted neighbor data
|
||||||
|
* - diskann_repair_reverse_edges traversing corrupted neighbor lists
|
||||||
|
* - diskann_search iterating neighbors from corrupted blobs
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Pop the next byte off the fuzz input, or yield `def` once it is exhausted. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t value = def;
  if (*size > 0) {
    const uint8_t *cursor = *data;
    value = cursor[0];
    *data = cursor + 1;
    *size -= 1;
  }
  return value;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 32) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* Use binary quantizer, float[16], n_neighbors=8 for predictable blob sizes:
|
||||||
|
* validity: 8/8 = 1 byte
|
||||||
|
* neighbor_ids: 8 * 8 = 64 bytes
|
||||||
|
* qvecs: 8 * (16/8) = 16 bytes (binary: 2 bytes per qvec)
|
||||||
|
*/
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
/* Insert 12 vectors to create a valid graph structure */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
|
for (int i = 1; i <= 12; i++) {
|
||||||
|
float vec[16];
|
||||||
|
for (int j = 0; j < 16; j++) {
|
||||||
|
vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmt);
|
||||||
|
sqlite3_bind_int64(stmt, 1, i);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Now corrupt shadow table blobs using fuzz data */
|
||||||
|
const char *columns[] = {
|
||||||
|
"neighbors_validity",
|
||||||
|
"neighbor_ids",
|
||||||
|
"neighbor_quantized_vectors"
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Expected sizes for n_neighbors=8, dims=16, binary quantizer */
|
||||||
|
int expected_sizes[] = {1, 64, 16};
|
||||||
|
|
||||||
|
while (size >= 4) {
|
||||||
|
int target_row = (fuzz_byte(&data, &size, 0) % 12) + 1;
|
||||||
|
int col_idx = fuzz_byte(&data, &size, 0) % 3;
|
||||||
|
uint8_t corrupt_mode = fuzz_byte(&data, &size, 0) % 6;
|
||||||
|
uint8_t extra = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
char sqlbuf[256];
|
||||||
|
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||||
|
"UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?",
|
||||||
|
columns[col_idx]);
|
||||||
|
|
||||||
|
sqlite3_stmt *writeStmt;
|
||||||
|
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL);
|
||||||
|
if (rc != SQLITE_OK) continue;
|
||||||
|
|
||||||
|
int expected = expected_sizes[col_idx];
|
||||||
|
unsigned char *blob = NULL;
|
||||||
|
int blob_size = 0;
|
||||||
|
|
||||||
|
switch (corrupt_mode) {
|
||||||
|
case 0: {
|
||||||
|
/* Truncated blob: 0 to expected-1 bytes */
|
||||||
|
blob_size = extra % expected;
|
||||||
|
if (blob_size == 0) blob_size = 0; /* zero-length is interesting */
|
||||||
|
blob = sqlite3_malloc(blob_size > 0 ? blob_size : 1);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
for (int i = 0; i < blob_size; i++) {
|
||||||
|
blob[i] = fuzz_byte(&data, &size, 0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
/* Oversized blob: expected + extra bytes */
|
||||||
|
blob_size = expected + (extra % 64);
|
||||||
|
blob = sqlite3_malloc(blob_size);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
for (int i = 0; i < blob_size; i++) {
|
||||||
|
blob[i] = fuzz_byte(&data, &size, 0xFF);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
/* Zero-length blob */
|
||||||
|
blob_size = 0;
|
||||||
|
blob = NULL;
|
||||||
|
sqlite3_bind_zeroblob(writeStmt, 1, 0);
|
||||||
|
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||||
|
sqlite3_step(writeStmt);
|
||||||
|
sqlite3_finalize(writeStmt);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
/* Correct size but all-ones validity (all slots "valid") with
|
||||||
|
* garbage neighbor IDs -- forces reading non-existent nodes */
|
||||||
|
blob_size = expected;
|
||||||
|
blob = sqlite3_malloc(blob_size);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
memset(blob, 0xFF, blob_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: {
|
||||||
|
/* neighbor_ids with very large rowid values (near INT64_MAX) */
|
||||||
|
blob_size = expected;
|
||||||
|
blob = sqlite3_malloc(blob_size);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
memset(blob, 0x7F, blob_size); /* fills with large positive values */
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: {
|
||||||
|
/* neighbor_ids with negative rowid values (rowid=0 is sentinel) */
|
||||||
|
blob_size = expected;
|
||||||
|
blob = sqlite3_malloc(blob_size);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
memset(blob, 0x80, blob_size); /* fills with large negative values */
|
||||||
|
/* Flip some bytes from fuzz data */
|
||||||
|
for (int i = 0; i < blob_size && size > 0; i++) {
|
||||||
|
blob[i] ^= fuzz_byte(&data, &size, 0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (blob) {
|
||||||
|
sqlite3_bind_blob(writeStmt, 1, blob, blob_size, SQLITE_TRANSIENT);
|
||||||
|
} else {
|
||||||
|
sqlite3_bind_blob(writeStmt, 1, "", 0, SQLITE_STATIC);
|
||||||
|
}
|
||||||
|
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||||
|
sqlite3_step(writeStmt);
|
||||||
|
sqlite3_finalize(writeStmt);
|
||||||
|
sqlite3_free(blob);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Exercise the corrupted graph with various operations */
|
||||||
|
|
||||||
|
/* KNN query */
|
||||||
|
{
|
||||||
|
float qvec[16];
|
||||||
|
for (int j = 0; j < 16; j++) qvec[j] = (float)j * 0.1f;
|
||||||
|
sqlite3_stmt *knnStmt;
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5",
|
||||||
|
-1, &knnStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knnStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Insert into corrupted graph (triggers add_reverse_edge on corrupted nodes) */
|
||||||
|
{
|
||||||
|
float vec[16];
|
||||||
|
for (int j = 0; j < 16; j++) vec[j] = 0.5f;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
|
if (stmt) {
|
||||||
|
sqlite3_bind_int64(stmt, 1, 100);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Delete from corrupted graph (triggers repair_reverse_edges) */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmt, NULL);
|
||||||
|
if (stmt) {
|
||||||
|
sqlite3_bind_int64(stmt, 1, 5);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Another KNN to traverse the post-mutation graph */
|
||||||
|
{
|
||||||
|
float qvec[16];
|
||||||
|
for (int j = 0; j < 16; j++) qvec[j] = -0.5f + (float)j * 0.07f;
|
||||||
|
sqlite3_stmt *knnStmt;
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 12",
|
||||||
|
-1, &knnStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knnStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Full scan */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
164
tests/fuzz/diskann-buffer-flush.c
Normal file
164
tests/fuzz/diskann-buffer-flush.c
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN buffered insert and flush paths.
|
||||||
|
*
|
||||||
|
* When buffer_threshold > 0, inserts go into a flat buffer table and
|
||||||
|
* are flushed into the graph in batch. This fuzzer exercises:
|
||||||
|
*
|
||||||
|
* - diskann_buffer_write / diskann_buffer_delete / diskann_buffer_exists
|
||||||
|
* - diskann_flush_buffer (batch graph insertion)
|
||||||
|
* - diskann_insert with buffer_threshold (batching logic)
|
||||||
|
* - Buffer-graph merge in vec0Filter_knn_diskann (unflushed vectors
|
||||||
|
* must be scanned during KNN and merged with graph results)
|
||||||
|
* - Delete of a buffered (not yet flushed) vector
|
||||||
|
* - Delete of a graph vector while buffer has pending inserts
|
||||||
|
* - Interaction: insert to buffer, query (triggers buffer scan), flush,
|
||||||
|
* query again (now from graph)
|
||||||
|
*
|
||||||
|
* The buffer merge path in vec0Filter_knn_diskann is particularly
|
||||||
|
* interesting because it does a brute-force scan of buffer vectors and
|
||||||
|
* merges with the top-k from graph search.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Consume and return one byte of fuzz input; when none is left, return `def`
 * and leave the cursor untouched. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size != 0) {
    const uint8_t *at = *data;
    *data = at + 1;
    *size = *size - 1;
    return *at;
  }
  return def;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 16) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* buffer_threshold: small (3-8) to trigger frequent flushes */
|
||||||
|
int buf_threshold = 3 + (fuzz_byte(&data, &size, 0) % 6);
|
||||||
|
int dims = 8;
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] INDEXED BY diskann("
|
||||||
|
"neighbor_quantizer=binary, n_neighbors=8, "
|
||||||
|
"search_list_size=16, buffer_threshold=%d"
|
||||||
|
"))", dims, buf_threshold);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||||
|
|
||||||
|
float vec[8];
|
||||||
|
int next_rowid = 1;
|
||||||
|
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0) % 6;
|
||||||
|
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: { /* Insert: accumulates in buffer until threshold */
|
||||||
|
int64_t rowid = next_rowid++;
|
||||||
|
if (next_rowid > 64) next_rowid = 1; /* wrap around for reuse */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: { /* KNN query while buffer may have unflushed vectors */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
int k = (param % 10) + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* Delete a potentially-buffered vector */
|
||||||
|
int64_t rowid = (int64_t)(param % 64) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* Insert several at once to trigger flush mid-batch */
|
||||||
|
for (int i = 0; i < buf_threshold + 1 && size >= 2; i++) {
|
||||||
|
int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 64) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: { /* Insert then immediately delete (still in buffer) */
|
||||||
|
int64_t rowid = (int64_t)(param % 64) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) vec[j] = 0.1f * param;
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: { /* Query with k=0 and k=1 (boundary) */
|
||||||
|
for (int j = 0; j < dims; j++) vec[j] = 0.0f;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, param % 2); /* k=0 or k=1 */
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Final query to exercise post-operation state */
|
||||||
|
{
|
||||||
|
float qvec[8] = {1.0f, -1.0f, 0.5f, -0.5f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, 20);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
158
tests/fuzz/diskann-command-inject.c
Normal file
158
tests/fuzz/diskann-command-inject.c
Normal file
|
|
@ -0,0 +1,158 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN runtime command dispatch (diskann_handle_command).
|
||||||
|
*
|
||||||
|
* The command handler parses strings like "search_list_size_search=42" and
|
||||||
|
* modifies live DiskANN config. This fuzzer exercises:
|
||||||
|
*
|
||||||
|
* - atoi on fuzz-controlled strings (integer overflow, negative, non-numeric)
|
||||||
|
* - strncmp boundary with fuzz data (near-matches to valid commands)
|
||||||
|
* - Changing search_list_size mid-operation (affects subsequent queries)
|
||||||
|
* - Setting search_list_size to 1 (minimum - single-candidate beam search)
|
||||||
|
* - Setting search_list_size very large (memory pressure)
|
||||||
|
* - Interleaving command changes with inserts and queries
|
||||||
|
*
|
||||||
|
* Also tests the UPDATE v SET command = ? path through the vtable.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Read one byte from the fuzz input and move the cursor past it.
 * When no input is left, return `def` without touching the cursor. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t value = def;
  if (*size != 0) {
    value = (*data)[0];
    *data += 1;
    *size -= 1;
  }
  return value;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 20) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
/* Insert some vectors first */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
|
for (int i = 1; i <= 8; i++) {
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||||
|
sqlite3_reset(stmt);
|
||||||
|
sqlite3_bind_int64(stmt, 1, i);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtCmd = NULL;
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
|
||||||
|
/* Commands are dispatched via INSERT INTO t(rowid) VALUES ('cmd_string') */
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES (?)", -1, &stmtCmd, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtCmd || !stmtInsert || !stmtKnn) goto cleanup;
|
||||||
|
|
||||||
|
/* Fuzz-driven command + operation interleaving */
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0) % 5;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: { /* Send fuzz command string */
|
||||||
|
int cmd_len = fuzz_byte(&data, &size, 0) % 64;
|
||||||
|
char cmd[65];
|
||||||
|
for (int i = 0; i < cmd_len && size > 0; i++) {
|
||||||
|
cmd[i] = (char)fuzz_byte(&data, &size, 0);
|
||||||
|
}
|
||||||
|
cmd[cmd_len] = '\0';
|
||||||
|
sqlite3_reset(stmtCmd);
|
||||||
|
sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtCmd); /* May fail -- that's expected */
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: { /* Send valid-looking command with fuzz value */
|
||||||
|
const char *prefixes[] = {
|
||||||
|
"search_list_size=",
|
||||||
|
"search_list_size_search=",
|
||||||
|
"search_list_size_insert=",
|
||||||
|
};
|
||||||
|
int prefix_idx = fuzz_byte(&data, &size, 0) % 3;
|
||||||
|
int val = (int)(int8_t)fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
char cmd[128];
|
||||||
|
snprintf(cmd, sizeof(cmd), "%s%d", prefixes[prefix_idx], val);
|
||||||
|
sqlite3_reset(stmtCmd);
|
||||||
|
sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtCmd);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* KNN query (uses whatever search_list_size is set) */
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
qvec[0] = (float)((int8_t)fuzz_byte(&data, &size, 127)) / 10.0f;
|
||||||
|
int k = fuzz_byte(&data, &size, 3) % 10 + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* Insert (uses whatever search_list_size_insert is set) */
|
||||||
|
int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 32) + 1;
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: { /* Set search_list_size to extreme values */
|
||||||
|
const char *extreme_cmds[] = {
|
||||||
|
"search_list_size=1",
|
||||||
|
"search_list_size=2",
|
||||||
|
"search_list_size=1000",
|
||||||
|
"search_list_size_search=1",
|
||||||
|
"search_list_size_insert=1",
|
||||||
|
};
|
||||||
|
int idx = fuzz_byte(&data, &size, 0) % 5;
|
||||||
|
sqlite3_reset(stmtCmd);
|
||||||
|
sqlite3_bind_text(stmtCmd, 1, extreme_cmds[idx], -1, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmtCmd);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtCmd);
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
44
tests/fuzz/diskann-create.c
Normal file
44
tests/fuzz/diskann-create.c
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN CREATE TABLE config parsing.
|
||||||
|
* Feeds fuzz data as the INDEXED BY diskann(...) option string.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size > 4096) return 0; /* Limit input size */
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||||
|
assert(s);
|
||||||
|
sqlite3_str_appendall(s,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[64] INDEXED BY diskann(");
|
||||||
|
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||||
|
sqlite3_str_appendall(s, "))");
|
||||||
|
const char *zSql = sqlite3_str_finish(s);
|
||||||
|
assert(zSql);
|
||||||
|
|
||||||
|
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free((char *)zSql);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
187
tests/fuzz/diskann-deep-search.c
Normal file
187
tests/fuzz/diskann-deep-search.c
Normal file
|
|
@ -0,0 +1,187 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN greedy beam search deep paths.
|
||||||
|
*
|
||||||
|
* Builds a graph with enough nodes to force multi-hop traversal, then
|
||||||
|
* uses fuzz data to control: query vector values, k, search_list_size
|
||||||
|
* overrides, and interleaved insert/delete/query sequences that stress
|
||||||
|
* the candidate list growth, visited set hash collisions, and the
|
||||||
|
* re-ranking logic.
|
||||||
|
*
|
||||||
|
* Key code paths targeted:
|
||||||
|
* - diskann_candidate_list_insert (sorted insert, dedup, eviction)
|
||||||
|
* - diskann_visited_set (hash collisions, capacity)
|
||||||
|
* - diskann_search (full beam search loop, re-ranking with exact dist)
|
||||||
|
* - diskann_distance_quantized_precomputed (both binary and int8)
|
||||||
|
* - Buffer merge in vec0Filter_knn_diskann
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Take the next fuzz byte and step the cursor forward; if the input has run
 * dry, fall back to `def` and leave the cursor where it is. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size > 0) {
    const uint8_t *at = *data;
    *data = at + 1;
    *size = *size - 1;
    return at[0];
  }
  return def;
}
|
||||||
|
|
||||||
|
/* Build a 16-bit value from two fuzz bytes, low byte first. Bytes missing
 * from an exhausted input are read as 0. */
static uint16_t fuzz_u16(const uint8_t **data, size_t *size) {
  uint16_t value = fuzz_byte(data, size, 0);
  value |= (uint16_t)(fuzz_byte(data, size, 0) << 8);
  return value;
}
|
||||||
|
|
||||||
|
/* Turn one fuzz byte into a float: reinterpret it as signed (-128..127) and
 * divide by 10. An exhausted input yields 0. */
static float fuzz_float(const uint8_t **data, size_t *size) {
  const int8_t raw = (int8_t)fuzz_byte(data, size, 0);
  return (float)raw / 10.0f;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 32) return 0;
|
||||||
|
|
||||||
|
/* Use first bytes to pick quantizer type and dimensions */
|
||||||
|
uint8_t quantizer_choice = fuzz_byte(&data, &size, 0) % 2;
|
||||||
|
const char *quantizer = quantizer_choice ? "int8" : "binary";
|
||||||
|
|
||||||
|
/* Dimensions must be divisible by 8. Pick from {8, 16, 32} */
|
||||||
|
int dim_choices[] = {8, 16, 32};
|
||||||
|
int dims = dim_choices[fuzz_byte(&data, &size, 0) % 3];
|
||||||
|
|
||||||
|
/* n_neighbors: 8 or 16 -- small to force full-neighbor scenarios quickly */
|
||||||
|
int n_neighbors = (fuzz_byte(&data, &size, 0) % 2) ? 16 : 8;
|
||||||
|
|
||||||
|
/* search_list_size: small so beam search terminates quickly but still exercises loops */
|
||||||
|
int search_list_size = 8 + (fuzz_byte(&data, &size, 0) % 24);
|
||||||
|
|
||||||
|
/* alpha: vary to test RobustPrune pruning logic */
|
||||||
|
float alpha_choices[] = {1.0f, 1.2f, 1.5f, 2.0f};
|
||||||
|
float alpha = alpha_choices[fuzz_byte(&data, &size, 0) % 4];
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] INDEXED BY diskann("
|
||||||
|
"neighbor_quantizer=%s, n_neighbors=%d, "
|
||||||
|
"search_list_size=%d"
|
||||||
|
"))", dims, quantizer, n_neighbors, search_list_size);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
|
||||||
|
char knn_sql[256];
|
||||||
|
snprintf(knn_sql, sizeof(knn_sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?");
|
||||||
|
sqlite3_prepare_v2(db, knn_sql, -1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||||
|
|
||||||
|
/* Phase 1: Seed the graph with enough nodes to create multi-hop structure.
|
||||||
|
* Insert 2*n_neighbors nodes so the graph is dense enough for search
|
||||||
|
* to actually traverse multiple hops. */
|
||||||
|
int seed_count = n_neighbors * 2;
|
||||||
|
if (seed_count > 64) seed_count = 64; /* Bound for performance */
|
||||||
|
{
|
||||||
|
float *vec = malloc(dims * sizeof(float));
|
||||||
|
if (!vec) goto cleanup;
|
||||||
|
for (int i = 1; i <= seed_count; i++) {
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
}
|
||||||
|
free(vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Phase 2: Fuzz-driven operations on the seeded graph */
|
||||||
|
float *vec = malloc(dims * sizeof(float));
|
||||||
|
if (!vec) goto cleanup;
|
||||||
|
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0) % 5;
|
||||||
|
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: { /* INSERT with fuzz-controlled vector and rowid */
|
||||||
|
int64_t rowid = (int64_t)(param % 128) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: { /* DELETE */
|
||||||
|
int64_t rowid = (int64_t)(param % 128) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* KNN with fuzz query vector and variable k */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_float(&data, &size);
|
||||||
|
}
|
||||||
|
int k = (param % 20) + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* KNN with k > number of nodes (boundary) */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, 1000); /* k >> graph size */
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: { /* INSERT duplicate rowid (triggers OR REPLACE path) */
|
||||||
|
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = (float)(param + j) / 50.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
free(vec);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
175
tests/fuzz/diskann-delete-stress.c
Normal file
175
tests/fuzz/diskann-delete-stress.c
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN delete path and graph connectivity maintenance.
|
||||||
|
*
|
||||||
|
* The delete path is the most complex graph mutation:
|
||||||
|
* 1. Read deleted node's neighbor list
|
||||||
|
* 2. For each neighbor, remove deleted node from their list
|
||||||
|
* 3. Try to fill the gap with one of deleted node's other neighbors
|
||||||
|
* 4. Handle medoid deletion (pick new medoid)
|
||||||
|
*
|
||||||
|
* Edge cases this targets:
|
||||||
|
* - Delete the medoid (entry point) -- forces medoid reassignment
|
||||||
|
* - Delete all nodes except one -- graph degenerates
|
||||||
|
* - Delete nodes in a chain -- cascading dangling edges
|
||||||
|
* - Re-insert at deleted rowids -- stale graph edges to old data
|
||||||
|
* - Delete nonexistent rowids -- should be no-op
|
||||||
|
* - Insert-delete-insert same rowid rapidly
|
||||||
|
* - Delete when graph has exactly n_neighbors entries (full nodes)
|
||||||
|
*
|
||||||
|
* Key code paths:
|
||||||
|
* - diskann_delete -> diskann_repair_reverse_edges
|
||||||
|
* - diskann_medoid_handle_delete
|
||||||
|
* - diskann_node_clear_neighbor
|
||||||
|
* - Interaction between delete and concurrent search
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Return the next byte of fuzz input and consume it; return `def` instead
 * (consuming nothing) when the input is empty. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  const int exhausted = (*size == 0);
  if (exhausted) {
    return def;
  }
  *size -= 1;
  return *(*data)++;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 20) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* int8 quantizer to exercise that distance code path */
|
||||||
|
uint8_t quant = fuzz_byte(&data, &size, 0) % 2;
|
||||||
|
const char *qname = quant ? "int8" : "binary";
|
||||||
|
|
||||||
|
char sql[256];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] INDEXED BY diskann(neighbor_quantizer=%s, n_neighbors=8))",
|
||||||
|
qname);
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||||
|
|
||||||
|
/* Phase 1: Build a graph of exactly n_neighbors+2 = 10 nodes.
|
||||||
|
* This makes every node nearly full, maximizing the chance that
|
||||||
|
* inserts trigger the "full node" path in add_reverse_edge. */
|
||||||
|
for (int i = 1; i <= 10; i++) {
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(i*13+j*7))) / 20.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Phase 2: Fuzz-driven delete-heavy workload */
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0);
|
||||||
|
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
switch (op % 6) {
|
||||||
|
case 0: /* Delete existing node */
|
||||||
|
case 1: { /* (weighted toward deletes) */
|
||||||
|
int64_t rowid = (int64_t)(param % 16) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* Delete then immediately re-insert same rowid */
|
||||||
|
int64_t rowid = (int64_t)(param % 10) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(rowid+j))) / 15.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* KNN query on potentially sparse/empty graph */
|
||||||
|
float qvec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
qvec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
int k = (param % 15) + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: { /* Insert new node */
|
||||||
|
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: { /* Delete ALL remaining nodes, then insert fresh */
|
||||||
|
for (int i = 1; i <= 32; i++) {
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, i);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
}
|
||||||
|
/* Now insert one node into empty graph */
|
||||||
|
float vec[8] = {1.0f, 0, 0, 0, 0, 0, 0, 0};
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, 1);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Final KNN on whatever state the graph is in */
|
||||||
|
{
|
||||||
|
float qvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, 10);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
123
tests/fuzz/diskann-graph-corrupt.c
Normal file
123
tests/fuzz/diskann-graph-corrupt.c
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN shadow table corruption resilience.
|
||||||
|
* Creates and populates a DiskANN table, then corrupts shadow table blobs
|
||||||
|
* using fuzz data and runs queries.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 16) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
/* Insert a few vectors to create graph structure */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
|
for (int i = 1; i <= 10; i++) {
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmt);
|
||||||
|
sqlite3_bind_int64(stmt, 1, i);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Corrupt shadow table data using fuzz bytes */
|
||||||
|
size_t offset = 0;
|
||||||
|
|
||||||
|
/* Determine which row and column to corrupt */
|
||||||
|
int target_row = (data[offset++] % 10) + 1;
|
||||||
|
int corrupt_type = data[offset++] % 3; /* 0=validity, 1=neighbor_ids, 2=qvecs */
|
||||||
|
|
||||||
|
const char *column_name;
|
||||||
|
switch (corrupt_type) {
|
||||||
|
case 0: column_name = "neighbors_validity"; break;
|
||||||
|
case 1: column_name = "neighbor_ids"; break;
|
||||||
|
default: column_name = "neighbor_quantized_vectors"; break;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Read the blob, corrupt it, write it back */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *readStmt;
|
||||||
|
char sqlbuf[256];
|
||||||
|
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||||
|
"SELECT %s FROM v_diskann_nodes00 WHERE rowid = ?", column_name);
|
||||||
|
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &readStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_int64(readStmt, 1, target_row);
|
||||||
|
if (sqlite3_step(readStmt) == SQLITE_ROW) {
|
||||||
|
const void *blob = sqlite3_column_blob(readStmt, 0);
|
||||||
|
int blobSize = sqlite3_column_bytes(readStmt, 0);
|
||||||
|
if (blob && blobSize > 0) {
|
||||||
|
unsigned char *corrupt = sqlite3_malloc(blobSize);
|
||||||
|
if (corrupt) {
|
||||||
|
memcpy(corrupt, blob, blobSize);
|
||||||
|
/* Apply fuzz bytes as XOR corruption */
|
||||||
|
size_t remaining = size - offset;
|
||||||
|
for (size_t i = 0; i < remaining && i < (size_t)blobSize; i++) {
|
||||||
|
corrupt[i % blobSize] ^= data[offset + i];
|
||||||
|
}
|
||||||
|
/* Write back */
|
||||||
|
sqlite3_stmt *writeStmt;
|
||||||
|
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||||
|
"UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?", column_name);
|
||||||
|
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(writeStmt, 1, corrupt, blobSize, SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||||
|
sqlite3_step(writeStmt);
|
||||||
|
sqlite3_finalize(writeStmt);
|
||||||
|
}
|
||||||
|
sqlite3_free(corrupt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_finalize(readStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Run queries on corrupted graph -- should not crash */
|
||||||
|
{
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_stmt *knnStmt;
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5",
|
||||||
|
-1, &knnStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knnStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Full scan */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
164
tests/fuzz/diskann-int8-quant.c
Normal file
164
tests/fuzz/diskann-int8-quant.c
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN int8 quantizer edge cases.
|
||||||
|
*
|
||||||
|
* The binary quantizer is simple (sign bit), but the int8 quantizer has
|
||||||
|
* interesting arithmetic:
|
||||||
|
* i8_val = (i8)(((src - (-1.0f)) / step) - 128.0f)
|
||||||
|
* where step = 2.0f / 255.0f
|
||||||
|
*
|
||||||
|
* Edge cases in this formula:
|
||||||
|
* - src values outside [-1, 1] cause clamping issues (no explicit clamp!)
|
||||||
|
* - src = NaN, +Inf, -Inf (from corrupted vectors or div-by-zero)
|
||||||
|
* - src very close to boundaries (-1.0, 1.0) -- rounding
|
||||||
|
* - The cast to i8 can overflow for extreme src values
|
||||||
|
*
|
||||||
|
* Also exercises int8 distance functions:
|
||||||
|
* - distance_l2_sqr_int8: accumulates squared differences, possible overflow
|
||||||
|
* - distance_cosine_int8: dot product with normalization
|
||||||
|
* - distance_l1_int8: absolute differences
|
||||||
|
*
|
||||||
|
* This fuzzer also tests the cosine distance metric path which the
|
||||||
|
* other fuzzers (using L2 default) don't cover.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/**
 * Pop the next byte off the fuzz input stream.
 *
 * When at least one byte is left, returns it, advances *data by one and
 * decrements *size. When the stream is exhausted (*size == 0) nothing is
 * modified and `def` is returned instead.
 */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size > 0) {
    const uint8_t *cursor = *data;
    *data = cursor + 1;
    *size -= 1;
    return *cursor;
  }
  return def;
}
|
||||||
|
|
||||||
|
/**
 * Turn two fuzz bytes into a float chosen to stress the int8 quantizer.
 *
 * The first byte (mod 8) selects a generator, the second byte is its raw
 * payload. Both bytes are always consumed, even by the constant generators;
 * once the stream is empty fuzz_byte() supplies 0 for each.
 *
 *   0: signed payload / 10      1: signed payload * 100
 *   2: signed payload / 1000    3: exactly -1.0
 *   4: exactly  1.0             5: exactly  0.0
 *   6: payload / 255  (0..1)    7: -payload / 255 (-1..0)
 */
static float fuzz_extreme_float(const uint8_t **data, size_t *size) {
  const uint8_t selector = fuzz_byte(data, size, 0) % 8;
  const uint8_t payload = fuzz_byte(data, size, 0);

  if (selector == 0) {
    return (float)((int8_t)payload) / 10.0f; /* Normal range */
  }
  if (selector == 1) {
    return (float)((int8_t)payload) * 100.0f; /* Large values */
  }
  if (selector == 2) {
    return (float)((int8_t)payload) / 1000.0f; /* Tiny values near 0 */
  }
  if (selector == 3) {
    return -1.0f; /* Exact boundary */
  }
  if (selector == 4) {
    return 1.0f; /* Exact boundary */
  }
  if (selector == 5) {
    return 0.0f; /* Zero */
  }
  if (selector == 6) {
    return (float)payload / 255.0f; /* [0, 1] range */
  }
  if (selector == 7) {
    return -(float)payload / 255.0f; /* [-1, 0] range */
  }
  /* Not reachable: selector is always in 0..7. */
  return 0.0f;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 40) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* Test both distance metrics with int8 quantizer */
|
||||||
|
uint8_t metric_choice = fuzz_byte(&data, &size, 0) % 2;
|
||||||
|
const char *metric = metric_choice ? "cosine" : "L2";
|
||||||
|
|
||||||
|
int dims = 8 + (fuzz_byte(&data, &size, 0) % 3) * 8; /* 8, 16, or 24 */
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] distance_metric=%s "
|
||||||
|
"INDEXED BY diskann(neighbor_quantizer=int8, n_neighbors=8, search_list_size=16))",
|
||||||
|
dims, metric);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtInsert = NULL, *stmtKnn = NULL, *stmtDelete = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtKnn || !stmtDelete) goto cleanup;
|
||||||
|
|
||||||
|
/* Insert vectors with extreme float values to stress quantization */
|
||||||
|
float *vec = malloc(dims * sizeof(float));
|
||||||
|
if (!vec) goto cleanup;
|
||||||
|
|
||||||
|
for (int i = 1; i <= 16; i++) {
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_extreme_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Fuzz-driven operations */
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0) % 4;
|
||||||
|
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: { /* KNN with extreme query values */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_extreme_float(&data, &size);
|
||||||
|
}
|
||||||
|
int k = (param % 10) + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: { /* Insert with extreme values */
|
||||||
|
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_extreme_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* Delete */
|
||||||
|
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* KNN with all-zero or all-same-value query */
|
||||||
|
float val = (param % 3 == 0) ? 0.0f :
|
||||||
|
(param % 3 == 1) ? 1.0f : -1.0f;
|
||||||
|
for (int j = 0; j < dims; j++) vec[j] = val;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, 5);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
free(vec);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
100
tests/fuzz/diskann-operations.c
Normal file
100
tests/fuzz/diskann-operations.c
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN insert/delete/query operation sequences.
|
||||||
|
* Uses fuzz bytes to drive random operations on a DiskANN-indexed table.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 6) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_stmt *stmtDelete = NULL;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_stmt *stmtScan = NULL;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 3",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i + 2 <= size) {
|
||||||
|
uint8_t op = data[i++] % 4;
|
||||||
|
uint8_t rowid_byte = data[i++];
|
||||||
|
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: {
|
||||||
|
/* INSERT: consume 32 bytes for 8 floats, or use what's left */
|
||||||
|
float vec[8] = {0};
|
||||||
|
for (int j = 0; j < 8 && i < size; j++, i++) {
|
||||||
|
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
/* DELETE */
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
/* KNN query */
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
/* Full scan */
|
||||||
|
sqlite3_reset(stmtScan);
|
||||||
|
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Final operations -- must not crash regardless of prior state */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_finalize(stmtScan);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
131
tests/fuzz/diskann-prune-direct.c
Normal file
131
tests/fuzz/diskann-prune-direct.c
Normal file
|
|
@ -0,0 +1,131 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN RobustPrune algorithm (diskann_prune_select).
|
||||||
|
*
|
||||||
|
* diskann_prune_select is exposed for testing and takes:
|
||||||
|
* - inter_distances: flattened NxN matrix of inter-candidate distances
|
||||||
|
* - p_distances: N distances from node p to each candidate
|
||||||
|
* - num_candidates, alpha, max_neighbors
|
||||||
|
*
|
||||||
|
* This is a pure function that doesn't need a database, so we can
|
||||||
|
* call it directly with fuzz-controlled inputs. This gives the fuzzer
|
||||||
|
* maximum speed (no SQLite overhead) to explore:
|
||||||
|
*
|
||||||
|
* - alpha boundary: alpha=0 (prunes nothing), alpha=very large (prunes all)
|
||||||
|
* - max_neighbors = 0, 1, num_candidates, > num_candidates
|
||||||
|
* - num_candidates = 0, 1, large
|
||||||
|
* - Distance matrices with: all zeros, all same (negative values, NaN, Inf are not generated yet: each entry is byte/10, i.e. 0..25.5)
|
||||||
|
* - Non-symmetric distance matrices (should still work)
|
||||||
|
* - Memory: large num_candidates to stress malloc
|
||||||
|
*
|
||||||
|
* Key code paths:
|
||||||
|
* - diskann_prune_select alpha-pruning loop
|
||||||
|
* - Boundary: selectedCount reaches max_neighbors exactly
|
||||||
|
* - All candidates pruned before max_neighbors reached
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Declare the test-exposed function.
|
||||||
|
* diskann_prune_select is not static -- it's a public symbol. */
|
||||||
|
extern int diskann_prune_select(
|
||||||
|
const float *inter_distances, const float *p_distances,
|
||||||
|
int num_candidates, float alpha, int max_neighbors,
|
||||||
|
int *outSelected, int *outCount);
|
||||||
|
|
||||||
|
/**
 * Pop the next byte off the fuzz input stream.
 *
 * When at least one byte is left, returns it, advances *data by one and
 * decrements *size. When the stream is exhausted (*size == 0) nothing is
 * modified and `def` is returned instead.
 */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size > 0) {
    const uint8_t *cursor = *data;
    *data = cursor + 1;
    *size -= 1;
    return *cursor;
  }
  return def;
}
|
||||||
|
|
||||||
|
/**
 * libFuzzer entry point: call diskann_prune_select() directly with
 * fuzz-derived inputs and assert basic invariants on its output.
 *
 * Input layout:
 *   byte 0   candidate count (mod 33 -> 0..32)
 *   byte 1   max_neighbors   (mod 17 -> 0..16)
 *   byte 2   index into a fixed table of alpha values
 *   next n   distances from p to each candidate (byte / 10), sorted ascending
 *   next n*n inter-candidate distance matrix (byte / 10), diagonal forced to 0
 *
 * Inputs shorter than 8 bytes are ignored; bytes missing from the tail are
 * replaced by index-derived defaults. Always returns 0.
 */
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  /* Alpha: pick from interesting values */
  static const float alpha_values[] = {0.0f, 0.5f, 1.0f,  1.2f,
                                       1.5f, 2.0f, 10.0f, 100.0f};

  if (size < 8) {
    return 0;
  }

  /* Consume parameters from fuzz data */
  const int n = fuzz_byte(&data, &size, 0) % 33;             /* 0..32 */
  const int max_neighbors = fuzz_byte(&data, &size, 0) % 17; /* 0..16 */
  const float alpha = alpha_values[fuzz_byte(&data, &size, 0) % 8];
  int outCount = -1;
  int rc;

  if (n == 0) {
    /* Test empty case */
    rc = diskann_prune_select(NULL, NULL, 0, alpha, max_neighbors, NULL,
                              &outCount);
    assert(rc == 0 /* SQLITE_OK */);
    assert(outCount == 0);
    return 0;
  }

  /* Allocate arrays */
  float *inter_distances = malloc(n * n * sizeof(float));
  float *p_distances = malloc(n * sizeof(float));
  int *outSelected = malloc(n * sizeof(int));
  if (inter_distances == NULL || p_distances == NULL || outSelected == NULL) {
    free(inter_distances);
    free(p_distances);
    free(outSelected);
    return 0;
  }

  /* Fill p_distances from fuzz data */
  for (int c = 0; c < n; c++) {
    p_distances[c] = (float)fuzz_byte(&data, &size, (uint8_t)(c * 10)) / 10.0f;
  }

  /* Insertion sort, ascending (prune_select expects sorted input) */
  for (int c = 1; c < n; c++) {
    const float key = p_distances[c];
    int slot = c;
    for (; slot > 0 && p_distances[slot - 1] > key; slot--) {
      p_distances[slot] = p_distances[slot - 1];
    }
    p_distances[slot] = key;
  }

  /* Fill the inter-distance matrix row by row. A byte is consumed for every
   * cell, including the diagonal, whose value is then forced to zero. */
  for (int row = 0; row < n; row++) {
    for (int col = 0; col < n; col++) {
      const int cell = row * n + col;
      const uint8_t raw = fuzz_byte(&data, &size, (uint8_t)(cell % 256));
      inter_distances[cell] = (row == col) ? 0.0f : (float)raw / 10.0f;
    }
  }

  rc = diskann_prune_select(inter_distances, p_distances, n, alpha,
                            max_neighbors, outSelected, &outCount);

  /* Basic sanity: should not crash, count should be valid */
  assert(rc == 0);
  assert(outCount >= 0);
  assert(outCount <= max_neighbors || max_neighbors == 0);
  assert(outCount <= n);

  /* Verify outSelected flags are consistent with outCount */
  int flagged = 0;
  for (int c = 0; c < n; c++) {
    if (outSelected[c]) {
      flagged++;
    }
  }
  assert(flagged == outCount);

  free(inter_distances);
  free(p_distances);
  free(outSelected);
  return 0;
}
|
||||||
10
tests/fuzz/diskann.dict
Normal file
10
tests/fuzz/diskann.dict
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
"neighbor_quantizer"
|
||||||
|
"binary"
|
||||||
|
"int8"
|
||||||
|
"n_neighbors"
|
||||||
|
"search_list_size"
|
||||||
|
"search_list_size_search"
|
||||||
|
"search_list_size_insert"
|
||||||
|
"alpha"
|
||||||
|
"="
|
||||||
|
","
|
||||||
|
|
@ -73,6 +73,7 @@ enum Vec0IndexType {
|
||||||
VEC0_INDEX_TYPE_RESCORE = 2,
|
VEC0_INDEX_TYPE_RESCORE = 2,
|
||||||
#endif
|
#endif
|
||||||
VEC0_INDEX_TYPE_IVF = 3,
|
VEC0_INDEX_TYPE_IVF = 3,
|
||||||
|
VEC0_INDEX_TYPE_DISKANN = 4,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum Vec0RescoreQuantizerType {
|
enum Vec0RescoreQuantizerType {
|
||||||
|
|
@ -114,6 +115,20 @@ struct Vec0RescoreConfig {
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Quantizer selected by the `neighbor_quantizer=` option of
// `INDEXED BY diskann(...)`. Presumably it controls how neighbor vectors are
// compressed inside graph nodes -- TODO confirm against the implementation.
enum Vec0DiskannQuantizerType {
|
||||||
|
VEC0_DISKANN_QUANTIZER_BINARY = 1, // neighbor_quantizer=binary
|
||||||
|
VEC0_DISKANN_QUANTIZER_INT8 = 2, // neighbor_quantizer=int8
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parsed options of an `INDEXED BY diskann(...)` clause on a vector column.
// Filled in by vec0_parse_vector_column() into VectorColumnDefinition.diskann.
struct Vec0DiskannConfig {
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type; // from neighbor_quantizer= (the parser tests expect an error when it is missing)
|
||||||
|
int n_neighbors; // from n_neighbors=; the parser tests expect a default of 72 and reject 13 and 512
|
||||||
|
int search_list_size; // from search_list_size=; the parser tests expect a default of 128
|
||||||
|
int search_list_size_search; // from search_list_size_search=; the tests reject combining it with search_list_size
|
||||||
|
int search_list_size_insert; // from search_list_size_insert=; 0 when not given, per the parser tests
|
||||||
|
float alpha; // NOTE(review): presumably the pruning alpha handed to diskann_prune_select -- confirm
|
||||||
|
int buffer_threshold; // NOTE(review): meaning not visible in this chunk -- confirm against the implementation
|
||||||
|
};
|
||||||
|
|
||||||
struct VectorColumnDefinition {
|
struct VectorColumnDefinition {
|
||||||
char *name;
|
char *name;
|
||||||
|
|
@ -126,6 +141,7 @@ struct VectorColumnDefinition {
|
||||||
struct Vec0RescoreConfig rescore;
|
struct Vec0RescoreConfig rescore;
|
||||||
#endif
|
#endif
|
||||||
struct Vec0IvfConfig ivf;
|
struct Vec0IvfConfig ivf;
|
||||||
|
struct Vec0DiskannConfig diskann;
|
||||||
};
|
};
|
||||||
|
|
||||||
int vec0_parse_vector_column(const char *source, int source_length,
|
int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
|
|
@ -136,6 +152,48 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
|
||||||
int *out_column_name_length,
|
int *out_column_name_length,
|
||||||
int *out_column_type);
|
int *out_column_type);
|
||||||
|
|
||||||
|
size_t diskann_quantized_vector_byte_size(
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||||
|
|
||||||
|
int diskann_validity_byte_size(int n_neighbors);
|
||||||
|
size_t diskann_neighbor_ids_byte_size(int n_neighbors);
|
||||||
|
size_t diskann_neighbor_qvecs_byte_size(
|
||||||
|
int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
|
||||||
|
size_t dimensions);
|
||||||
|
int diskann_node_init(
|
||||||
|
int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
|
||||||
|
size_t dimensions,
|
||||||
|
unsigned char **outValidity, int *outValiditySize,
|
||||||
|
unsigned char **outNeighborIds, int *outNeighborIdsSize,
|
||||||
|
unsigned char **outNeighborQvecs, int *outNeighborQvecsSize);
|
||||||
|
int diskann_validity_get(const unsigned char *validity, int i);
|
||||||
|
void diskann_validity_set(unsigned char *validity, int i, int value);
|
||||||
|
int diskann_validity_count(const unsigned char *validity, int n_neighbors);
|
||||||
|
long long diskann_neighbor_id_get(const unsigned char *neighbor_ids, int i);
|
||||||
|
void diskann_neighbor_id_set(unsigned char *neighbor_ids, int i, long long rowid);
|
||||||
|
const unsigned char *diskann_neighbor_qvec_get(
|
||||||
|
const unsigned char *qvecs, int i,
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||||
|
void diskann_neighbor_qvec_set(
|
||||||
|
unsigned char *qvecs, int i, const unsigned char *src_qvec,
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||||
|
void diskann_node_set_neighbor(
|
||||||
|
unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i,
|
||||||
|
long long neighbor_rowid, const unsigned char *neighbor_qvec,
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||||
|
void diskann_node_clear_neighbor(
|
||||||
|
unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i,
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||||
|
int diskann_quantize_vector(
|
||||||
|
const float *src, size_t dimensions,
|
||||||
|
enum Vec0DiskannQuantizerType quantizer_type,
|
||||||
|
unsigned char *out);
|
||||||
|
|
||||||
|
int diskann_prune_select(
|
||||||
|
const float *inter_distances, const float *p_distances,
|
||||||
|
int num_candidates, float alpha, int max_neighbors,
|
||||||
|
int *outSelected, int *outCount);
|
||||||
|
|
||||||
#ifdef SQLITE_VEC_TEST
|
#ifdef SQLITE_VEC_TEST
|
||||||
float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims);
|
float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims);
|
||||||
float _test_distance_cosine_float(const float *a, const float *b, size_t dims);
|
float _test_distance_cosine_float(const float *a, const float *b, size_t dims);
|
||||||
|
|
@ -151,6 +209,33 @@ size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
|
||||||
void ivf_quantize_int8(const float *src, int8_t *dst, int D);
|
void ivf_quantize_int8(const float *src, int8_t *dst, int D);
|
||||||
void ivf_quantize_binary(const float *src, uint8_t *dst, int D);
|
void ivf_quantize_binary(const float *src, uint8_t *dst, int D);
|
||||||
#endif
|
#endif
|
||||||
|
// DiskANN candidate list (opaque struct, use accessors)
|
||||||
|
// Test-facing mirror of the candidate list; manipulate it only through the
// _test_diskann_candidate_list_* functions declared below.
struct DiskannCandidateList {
|
||||||
|
void *items; // opaque
|
||||||
|
int count; // presumably the number of entries currently held -- TODO confirm
|
||||||
|
int capacity; // presumably the capacity passed to _test_diskann_candidate_list_init -- TODO confirm
|
||||||
|
};
|
||||||
|
|
||||||
|
int _test_diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity);
|
||||||
|
void _test_diskann_candidate_list_free(struct DiskannCandidateList *list);
|
||||||
|
int _test_diskann_candidate_list_insert(struct DiskannCandidateList *list, long long rowid, float distance);
|
||||||
|
int _test_diskann_candidate_list_next_unvisited(const struct DiskannCandidateList *list);
|
||||||
|
int _test_diskann_candidate_list_count(const struct DiskannCandidateList *list);
|
||||||
|
long long _test_diskann_candidate_list_rowid(const struct DiskannCandidateList *list, int i);
|
||||||
|
float _test_diskann_candidate_list_distance(const struct DiskannCandidateList *list, int i);
|
||||||
|
void _test_diskann_candidate_list_set_visited(struct DiskannCandidateList *list, int i);
|
||||||
|
|
||||||
|
// DiskANN visited set (opaque struct, use accessors)
|
||||||
|
// Test-facing mirror of the visited set; manipulate it only through the
// _test_diskann_visited_set_* functions declared below.
struct DiskannVisitedSet {
|
||||||
|
void *slots; // opaque
|
||||||
|
int capacity; // presumably the capacity passed to _test_diskann_visited_set_init -- TODO confirm
|
||||||
|
int count; // presumably the number of rowids currently stored -- TODO confirm
|
||||||
|
};
|
||||||
|
|
||||||
|
int _test_diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity);
|
||||||
|
void _test_diskann_visited_set_free(struct DiskannVisitedSet *set);
|
||||||
|
int _test_diskann_visited_set_contains(const struct DiskannVisitedSet *set, long long rowid);
|
||||||
|
int _test_diskann_visited_set_insert(struct DiskannVisitedSet *set, long long rowid);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif /* SQLITE_VEC_INTERNAL_H */
|
#endif /* SQLITE_VEC_INTERNAL_H */
|
||||||
|
|
|
||||||
1160
tests/test-diskann.py
Normal file
1160
tests/test-diskann.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1187,6 +1187,7 @@ void test_ivf_quantize_binary() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_ivf_config_parsing() {
|
void test_ivf_config_parsing() {
|
||||||
|
void test_vec0_parse_vector_column_diskann() {
|
||||||
printf("Starting %s...\n", __func__);
|
printf("Starting %s...\n", __func__);
|
||||||
struct VectorColumnDefinition col;
|
struct VectorColumnDefinition col;
|
||||||
int rc;
|
int rc;
|
||||||
|
|
@ -1199,6 +1200,34 @@ void test_ivf_config_parsing() {
|
||||||
assert(col.index_type == VEC0_INDEX_TYPE_RESCORE);
|
assert(col.index_type == VEC0_INDEX_TYPE_RESCORE);
|
||||||
assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT);
|
assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT);
|
||||||
assert(col.rescore.oversample == 8); // default
|
assert(col.rescore.oversample == 8); // default
|
||||||
|
// Existing syntax (no INDEXED BY) should have diskann.enabled == 0
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// With distance_metric but no INDEXED BY
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] distance_metric=cosine";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic binary quantizer
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY);
|
||||||
|
assert(col.diskann.n_neighbors == 72); // default
|
||||||
|
assert(col.diskann.search_list_size == 128); // default
|
||||||
assert(col.dimensions == 128);
|
assert(col.dimensions == 128);
|
||||||
sqlite3_free(col.name);
|
sqlite3_free(col.name);
|
||||||
}
|
}
|
||||||
|
|
@ -1370,6 +1399,681 @@ void test_ivf_config_parsing() {
|
||||||
printf(" All ivf_config_parsing tests passed.\n");
|
printf(" All ivf_config_parsing tests passed.\n");
|
||||||
}
|
}
|
||||||
#endif /* SQLITE_VEC_ENABLE_IVF */
|
#endif /* SQLITE_VEC_ENABLE_IVF */
|
||||||
|
// INT8 quantizer
|
||||||
|
{
|
||||||
|
const char *input = "v float[64] INDEXED BY diskann(neighbor_quantizer=int8)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_INT8);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom n_neighbors
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=48)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
assert(col.diskann.n_neighbors == 48);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom search_list_size
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=256)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.diskann.search_list_size == 256);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combined with distance_metric (distance_metric first)
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] distance_metric=cosine INDEXED BY diskann(neighbor_quantizer=int8)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_INT8);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: missing neighbor_quantizer (required)
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(n_neighbors=72)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: empty parens
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann()";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: unknown quantizer
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=unknown)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: bad n_neighbors (not divisible by 8)
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=13)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: n_neighbors too large
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=512)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: missing BY
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED diskann(neighbor_quantizer=binary)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: unknown algorithm
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY hnsw(neighbor_quantizer=binary)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: unknown option key
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, foobar=baz)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case insensitivity for keywords
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] indexed by DISKANN(NEIGHBOR_QUANTIZER=BINARY)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split search_list_size: search and insert
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size_search=256, search_list_size_insert=64)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.diskann.search_list_size == 128); // default (unified)
|
||||||
|
assert(col.diskann.search_list_size_search == 256);
|
||||||
|
assert(col.diskann.search_list_size_insert == 64);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split search_list_size: only search
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size_search=200)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.diskann.search_list_size_search == 200);
|
||||||
|
assert(col.diskann.search_list_size_insert == 0);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: cannot mix search_list_size with search_list_size_search
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=128, search_list_size_search=256)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error: cannot mix search_list_size with search_list_size_insert
|
||||||
|
{
|
||||||
|
const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=128, search_list_size_insert=64)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf(" All vec0_parse_vector_column_diskann tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exercises the validity bitmap helpers (get / set / count) on a 24-bit
// bitmap, focusing on bits that sit on byte boundaries.
void test_diskann_validity_bitmap() {
  printf("Starting %s...\n", __func__);

  enum { TOTAL_BITS = 24 };
  unsigned char bitmap[TOTAL_BITS / 8] = {0}; // 3 bytes, every bit cleared

  // A zeroed bitmap must report every slot as invalid.
  int slot = 0;
  while (slot < TOTAL_BITS) {
    assert(diskann_validity_get(bitmap, slot) == 0);
    slot++;
  }
  assert(diskann_validity_count(bitmap, TOTAL_BITS) == 0);

  // Turn on the first bit, the last bit of byte 0, the first bit of byte 1
  // and the very last bit. The count must grow by one after each set.
  const int boundary_bits[] = {0, 7, 8, 23};
  const int n_boundary =
      (int)(sizeof(boundary_bits) / sizeof(boundary_bits[0]));
  for (int k = 0; k < n_boundary; k++) {
    const int bit = boundary_bits[k];
    diskann_validity_set(bitmap, bit, 1);
    assert(diskann_validity_get(bitmap, bit) == 1);
    assert(diskann_validity_count(bitmap, TOTAL_BITS) == k + 1);
  }

  // Clearing bit 0 lowers the count by one...
  diskann_validity_set(bitmap, 0, 0);
  assert(diskann_validity_get(bitmap, 0) == 0);
  assert(diskann_validity_count(bitmap, TOTAL_BITS) == 3);

  // ...and leaves the neighbouring bits untouched.
  assert(diskann_validity_get(bitmap, 7) == 1);
  assert(diskann_validity_get(bitmap, 8) == 1);

  printf(" All diskann_validity_bitmap tests passed.\n");
}
|
||||||
|
|
||||||
|
// Round-trips rowids through the neighbor-id accessors on an 8-slot buffer
// (8 bytes per slot) and checks that slots do not interfere with each other.
void test_diskann_neighbor_ids() {
  printf("Starting %s...\n", __func__);

  enum { SLOT_COUNT = 8, BYTES_PER_ID = 8 };
  unsigned char id_blob[SLOT_COUNT * BYTES_PER_ID] = {0};

  // Write then read back the first, a middle, and the last slot.
  const struct {
    int slot;
    int64_t rowid;
  } writes[] = {{0, 42}, {3, 12345}, {7, 99999}};
  for (size_t k = 0; k < sizeof(writes) / sizeof(writes[0]); k++) {
    diskann_neighbor_id_set(id_blob, writes[k].slot, writes[k].rowid);
    assert(diskann_neighbor_id_get(id_blob, writes[k].slot) ==
           writes[k].rowid);
  }

  // Writes to later slots must not have disturbed slot 0.
  assert(diskann_neighbor_id_get(id_blob, 0) == 42);

  // The largest signed 64-bit value survives the round trip.
  diskann_neighbor_id_set(id_blob, 1, INT64_MAX);
  assert(diskann_neighbor_id_get(id_blob, 1) == INT64_MAX);

  printf(" All diskann_neighbor_ids tests passed.\n");
}
|
||||||
|
|
||||||
|
void test_diskann_quantize_binary() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
// 8-dimensional vector: positive values -> 1, negative/zero -> 0
|
||||||
|
float src[8] = {1.0f, -1.0f, 0.5f, 0.0f, -0.5f, 0.1f, -0.1f, 100.0f};
|
||||||
|
unsigned char out[1]; // 8 bits = 1 byte
|
||||||
|
|
||||||
|
int rc = diskann_quantize_vector(src, 8, VEC0_DISKANN_QUANTIZER_BINARY, out);
|
||||||
|
assert(rc == 0);
|
||||||
|
|
||||||
|
// Expected bits (LSB first within each byte):
|
||||||
|
// bit 0: 1.0 > 0 -> 1
|
||||||
|
// bit 1: -1.0 > 0 -> 0
|
||||||
|
// bit 2: 0.5 > 0 -> 1
|
||||||
|
// bit 3: 0.0 > 0 -> 0 (not strictly greater)
|
||||||
|
// bit 4: -0.5 > 0 -> 0
|
||||||
|
// bit 5: 0.1 > 0 -> 1
|
||||||
|
// bit 6: -0.1 > 0 -> 0
|
||||||
|
// bit 7: 100.0 > 0 -> 1
|
||||||
|
// Expected byte: 1 + 0 + 4 + 0 + 0 + 32 + 0 + 128 = 0b10100101 = 0xA5
|
||||||
|
assert(out[0] == 0xA5);
|
||||||
|
|
||||||
|
printf(" All diskann_quantize_binary tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_diskann_node_init_sizes() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
unsigned char *validity, *ids, *qvecs;
|
||||||
|
int validitySize, idsSize, qvecsSize;
|
||||||
|
|
||||||
|
// 72 neighbors, binary quantizer, 1024 dims
|
||||||
|
int rc = diskann_node_init(72, VEC0_DISKANN_QUANTIZER_BINARY, 1024,
|
||||||
|
&validity, &validitySize, &ids, &idsSize, &qvecs, &qvecsSize);
|
||||||
|
assert(rc == 0);
|
||||||
|
assert(validitySize == 9); // 72/8
|
||||||
|
assert(idsSize == 576); // 72 * 8
|
||||||
|
assert(qvecsSize == 9216); // 72 * (1024/8)
|
||||||
|
|
||||||
|
// All validity bits should be 0
|
||||||
|
assert(diskann_validity_count(validity, 72) == 0);
|
||||||
|
|
||||||
|
sqlite3_free(validity);
|
||||||
|
sqlite3_free(ids);
|
||||||
|
sqlite3_free(qvecs);
|
||||||
|
|
||||||
|
// 8 neighbors, int8 quantizer, 32 dims
|
||||||
|
rc = diskann_node_init(8, VEC0_DISKANN_QUANTIZER_INT8, 32,
|
||||||
|
&validity, &validitySize, &ids, &idsSize, &qvecs, &qvecsSize);
|
||||||
|
assert(rc == 0);
|
||||||
|
assert(validitySize == 1); // 8/8
|
||||||
|
assert(idsSize == 64); // 8 * 8
|
||||||
|
assert(qvecsSize == 256); // 8 * 32
|
||||||
|
|
||||||
|
sqlite3_free(validity);
|
||||||
|
sqlite3_free(ids);
|
||||||
|
sqlite3_free(qvecs);
|
||||||
|
|
||||||
|
printf(" All diskann_node_init_sizes tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_diskann_node_set_clear_neighbor() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
unsigned char *validity, *ids, *qvecs;
|
||||||
|
int validitySize, idsSize, qvecsSize;
|
||||||
|
|
||||||
|
// 8 neighbors, binary quantizer, 16 dims (2 bytes per qvec)
|
||||||
|
int rc = diskann_node_init(8, VEC0_DISKANN_QUANTIZER_BINARY, 16,
|
||||||
|
&validity, &validitySize, &ids, &idsSize, &qvecs, &qvecsSize);
|
||||||
|
assert(rc == 0);
|
||||||
|
|
||||||
|
// Create a test quantized vector (2 bytes)
|
||||||
|
unsigned char test_qvec[2] = {0xAB, 0xCD};
|
||||||
|
|
||||||
|
// Set neighbor at slot 3
|
||||||
|
diskann_node_set_neighbor(validity, ids, qvecs, 3,
|
||||||
|
42, test_qvec, VEC0_DISKANN_QUANTIZER_BINARY, 16);
|
||||||
|
|
||||||
|
// Verify slot 3 is valid
|
||||||
|
assert(diskann_validity_get(validity, 3) == 1);
|
||||||
|
assert(diskann_validity_count(validity, 8) == 1);
|
||||||
|
|
||||||
|
// Verify rowid
|
||||||
|
assert(diskann_neighbor_id_get(ids, 3) == 42);
|
||||||
|
|
||||||
|
// Verify quantized vector
|
||||||
|
const unsigned char *read_qvec = diskann_neighbor_qvec_get(
|
||||||
|
qvecs, 3, VEC0_DISKANN_QUANTIZER_BINARY, 16);
|
||||||
|
assert(read_qvec[0] == 0xAB);
|
||||||
|
assert(read_qvec[1] == 0xCD);
|
||||||
|
|
||||||
|
// Clear slot 3
|
||||||
|
diskann_node_clear_neighbor(validity, ids, qvecs, 3,
|
||||||
|
VEC0_DISKANN_QUANTIZER_BINARY, 16);
|
||||||
|
assert(diskann_validity_get(validity, 3) == 0);
|
||||||
|
assert(diskann_neighbor_id_get(ids, 3) == 0);
|
||||||
|
assert(diskann_validity_count(validity, 8) == 0);
|
||||||
|
|
||||||
|
sqlite3_free(validity);
|
||||||
|
sqlite3_free(ids);
|
||||||
|
sqlite3_free(qvecs);
|
||||||
|
|
||||||
|
printf(" All diskann_node_set_clear_neighbor tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drives diskann_prune_select() over five candidates A..E (indices 0..4)
// that are already ordered by distance to the point p.
//
// The expectations below assume this rule: candidates are taken in order,
// and once a candidate is picked, a remaining candidate j is dropped when
//   alpha * dist(picked, j) <= dist(p, j).
// Selection stops after R picks. selected[i] == 1 marks a kept candidate.
void test_diskann_prune_select() {
  printf("Starting %s...\n", __func__);

  enum { N_CAND = 5 };

  // Distance from p to each candidate: A=1, B=2, C=3, D=4, E=5.
  float dist_to_p[N_CAND] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};

  // Symmetric candidate-to-candidate distances, row-major (row = from).
  float pairwise[N_CAND * N_CAND] = {
      /* A */ 0.0f, 1.5f, 3.0f, 4.0f, 5.0f,
      /* B */ 1.5f, 0.0f, 1.5f, 3.0f, 4.0f,
      /* C */ 3.0f, 1.5f, 0.0f, 1.5f, 3.0f,
      /* D */ 4.0f, 3.0f, 1.5f, 0.0f, 1.5f,
      /* E */ 5.0f, 4.0f, 3.0f, 1.5f, 0.0f,
  };
  int keep[N_CAND];
  int n_kept;
  int rc;

  // alpha = 1.0, R = 3.
  // After picking A every other candidate satisfies the drop rule
  // (1.5 <= 2, 3 <= 3, 4 <= 4, 5 <= 5), so only A survives.
  rc = diskann_prune_select(pairwise, dist_to_p, N_CAND, 1.0f, 3, keep,
                            &n_kept);
  assert(rc == 0);
  assert(n_kept == 1);
  assert(keep[0] == 1); // A

  // alpha = 1.5, R = 3.
  // Picking A drops nobody (2.25 > 2, 4.5 > 3, 6 > 4, 7.5 > 5).
  // Picking B drops C (2.25 <= 3) but keeps D (4.5 > 4) and E (6 > 5).
  // Picking D reaches R = 3, so E is never selected.
  rc = diskann_prune_select(pairwise, dist_to_p, N_CAND, 1.5f, 3, keep,
                            &n_kept);
  assert(rc == 0);
  assert(n_kept == 3);
  {
    const int want[N_CAND] = {1, 1, 0, 1, 0}; // A, B kept; C pruned; D kept; E not reached
    for (int i = 0; i < N_CAND; i++) {
      assert(keep[i] == want[i]);
    }
  }

  // R larger than the candidate count and an alpha so large that nothing
  // is pruned: every candidate is selected.
  rc = diskann_prune_select(pairwise, dist_to_p, N_CAND, 100.0f, 10, keep,
                            &n_kept);
  assert(rc == 0);
  assert(n_kept == N_CAND);

  // No candidates at all: succeeds and selects nothing.
  rc = diskann_prune_select(NULL, NULL, 0, 1.2f, 3, keep, &n_kept);
  assert(rc == 0);
  assert(n_kept == 0);

  printf(" All diskann_prune_select tests passed.\n");
}
|
||||||
|
|
||||||
|
void test_diskann_quantized_vector_byte_size() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
// Binary quantizer: 1 bit per dimension, so 128 dims = 16 bytes
|
||||||
|
assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_BINARY, 128) == 16);
|
||||||
|
assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_BINARY, 8) == 1);
|
||||||
|
assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_BINARY, 1024) == 128);
|
||||||
|
|
||||||
|
// INT8 quantizer: 1 byte per dimension
|
||||||
|
assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_INT8, 128) == 128);
|
||||||
|
assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_INT8, 1) == 1);
|
||||||
|
assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_INT8, 768) == 768);
|
||||||
|
|
||||||
|
printf(" All diskann_quantized_vector_byte_size tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_diskann_config_defaults() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
// A freshly zero-initialized VectorColumnDefinition should have diskann.enabled == 0
|
||||||
|
struct VectorColumnDefinition col;
|
||||||
|
memset(&col, 0, sizeof(col));
|
||||||
|
assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
assert(col.diskann.n_neighbors == 0);
|
||||||
|
assert(col.diskann.search_list_size == 0);
|
||||||
|
|
||||||
|
// Verify parsing a normal vector column still works and diskann is not enabled
|
||||||
|
{
|
||||||
|
const char *input = "embedding float[768]";
|
||||||
|
int rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == 0 /* SQLITE_OK */);
|
||||||
|
assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf(" All diskann_config_defaults tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================================================================
|
||||||
|
// Additional DiskANN unit tests
|
||||||
|
// ======================================================================
|
||||||
|
|
||||||
|
void test_diskann_quantize_int8() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
// INT8 quantization uses fixed range [-1, 1]:
|
||||||
|
// step = 2.0 / 255.0
|
||||||
|
// out[i] = (i8)((src[i] + 1.0) / step - 128.0)
|
||||||
|
float src[4] = {-1.0f, 0.0f, 0.5f, 1.0f};
|
||||||
|
unsigned char out[4];
|
||||||
|
|
||||||
|
int rc = diskann_quantize_vector(src, 4, VEC0_DISKANN_QUANTIZER_INT8, out);
|
||||||
|
assert(rc == 0);
|
||||||
|
|
||||||
|
int8_t *signed_out = (int8_t *)out;
|
||||||
|
// -1.0 -> (0/step) - 128 = -128
|
||||||
|
assert(signed_out[0] == -128);
|
||||||
|
// 0.0 -> (1.0/step) - 128 ~= 127.5 - 128 ~= -0.5 -> (i8)(-0.5) = 0
|
||||||
|
assert(signed_out[1] >= -2 && signed_out[1] <= 2);
|
||||||
|
// 0.5 -> (1.5/step) - 128 ~= 191.25 - 128 = 63.25 -> (i8) 63
|
||||||
|
assert(signed_out[2] >= 60 && signed_out[2] <= 66);
|
||||||
|
// 1.0 -> should be close to 127 (may have float precision issues)
|
||||||
|
assert(signed_out[3] >= 126 && signed_out[3] <= 127);
|
||||||
|
|
||||||
|
printf(" All diskann_quantize_int8 tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_diskann_quantize_binary_16d() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
// 16-dimensional vector (2 bytes output)
|
||||||
|
float src[16] = {
|
||||||
|
1.0f, -1.0f, 0.5f, -0.5f, // byte 0: bit0=1, bit1=0, bit2=1, bit3=0
|
||||||
|
0.1f, -0.1f, 0.0f, 100.0f, // byte 0: bit4=1, bit5=0, bit6=0, bit7=1
|
||||||
|
-1.0f, 1.0f, 1.0f, 1.0f, // byte 1: bit0=0, bit1=1, bit2=1, bit3=1
|
||||||
|
-1.0f, -1.0f, 1.0f, -1.0f // byte 1: bit4=0, bit5=0, bit6=1, bit7=0
|
||||||
|
};
|
||||||
|
unsigned char out[2];
|
||||||
|
|
||||||
|
int rc = diskann_quantize_vector(src, 16, VEC0_DISKANN_QUANTIZER_BINARY, out);
|
||||||
|
assert(rc == 0);
|
||||||
|
|
||||||
|
// byte 0: bits 0,2,4,7 set -> 0b10010101 = 0x95
|
||||||
|
assert(out[0] == 0x95);
|
||||||
|
// byte 1: bits 1,2,3,6 set -> 0b01001110 = 0x4E
|
||||||
|
assert(out[1] == 0x4E);
|
||||||
|
|
||||||
|
printf(" All diskann_quantize_binary_16d tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_diskann_quantize_binary_all_positive() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
float src[8] = {1.0f, 2.0f, 0.1f, 0.001f, 100.0f, 42.0f, 0.5f, 3.14f};
|
||||||
|
unsigned char out[1];
|
||||||
|
|
||||||
|
int rc = diskann_quantize_vector(src, 8, VEC0_DISKANN_QUANTIZER_BINARY, out);
|
||||||
|
assert(rc == 0);
|
||||||
|
assert(out[0] == 0xFF); // All bits set
|
||||||
|
|
||||||
|
printf(" All diskann_quantize_binary_all_positive tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_diskann_quantize_binary_all_negative() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
float src[8] = {-1.0f, -2.0f, -0.1f, -0.001f, -100.0f, -42.0f, -0.5f, 0.0f};
|
||||||
|
unsigned char out[1];
|
||||||
|
|
||||||
|
int rc = diskann_quantize_vector(src, 8, VEC0_DISKANN_QUANTIZER_BINARY, out);
|
||||||
|
assert(rc == 0);
|
||||||
|
assert(out[0] == 0x00); // No bits set (all <= 0)
|
||||||
|
|
||||||
|
printf(" All diskann_quantize_binary_all_negative tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_diskann_candidate_list_operations() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
struct DiskannCandidateList list;
|
||||||
|
int rc = _test_diskann_candidate_list_init(&list, 5);
|
||||||
|
assert(rc == 0);
|
||||||
|
|
||||||
|
// Insert candidates in non-sorted order
|
||||||
|
_test_diskann_candidate_list_insert(&list, 10, 3.0f);
|
||||||
|
_test_diskann_candidate_list_insert(&list, 20, 1.0f);
|
||||||
|
_test_diskann_candidate_list_insert(&list, 30, 2.0f);
|
||||||
|
|
||||||
|
assert(_test_diskann_candidate_list_count(&list) == 3);
|
||||||
|
// Should be sorted by distance
|
||||||
|
assert(_test_diskann_candidate_list_rowid(&list, 0) == 20); // dist 1.0
|
||||||
|
assert(_test_diskann_candidate_list_rowid(&list, 1) == 30); // dist 2.0
|
||||||
|
assert(_test_diskann_candidate_list_rowid(&list, 2) == 10); // dist 3.0
|
||||||
|
|
||||||
|
assert(_test_diskann_candidate_list_distance(&list, 0) == 1.0f);
|
||||||
|
assert(_test_diskann_candidate_list_distance(&list, 1) == 2.0f);
|
||||||
|
assert(_test_diskann_candidate_list_distance(&list, 2) == 3.0f);
|
||||||
|
|
||||||
|
// Deduplication: inserting same rowid with better distance should update
|
||||||
|
_test_diskann_candidate_list_insert(&list, 10, 0.5f);
|
||||||
|
assert(_test_diskann_candidate_list_count(&list) == 3); // Same count
|
||||||
|
assert(_test_diskann_candidate_list_rowid(&list, 0) == 10); // Now first
|
||||||
|
assert(_test_diskann_candidate_list_distance(&list, 0) == 0.5f);
|
||||||
|
|
||||||
|
// Next unvisited: should be index 0
|
||||||
|
int idx = _test_diskann_candidate_list_next_unvisited(&list);
|
||||||
|
assert(idx == 0);
|
||||||
|
|
||||||
|
// Mark visited
|
||||||
|
_test_diskann_candidate_list_set_visited(&list, 0);
|
||||||
|
idx = _test_diskann_candidate_list_next_unvisited(&list);
|
||||||
|
assert(idx == 1); // Skip visited
|
||||||
|
|
||||||
|
// Fill to capacity (5) and try inserting a worse candidate
|
||||||
|
_test_diskann_candidate_list_insert(&list, 40, 4.0f);
|
||||||
|
_test_diskann_candidate_list_insert(&list, 50, 5.0f);
|
||||||
|
assert(_test_diskann_candidate_list_count(&list) == 5);
|
||||||
|
|
||||||
|
// Insert worse than worst -> should be discarded
|
||||||
|
int inserted = _test_diskann_candidate_list_insert(&list, 60, 10.0f);
|
||||||
|
assert(inserted == 0);
|
||||||
|
assert(_test_diskann_candidate_list_count(&list) == 5);
|
||||||
|
|
||||||
|
// Insert better than worst -> should replace worst
|
||||||
|
inserted = _test_diskann_candidate_list_insert(&list, 60, 3.5f);
|
||||||
|
assert(inserted == 1);
|
||||||
|
assert(_test_diskann_candidate_list_count(&list) == 5);
|
||||||
|
|
||||||
|
_test_diskann_candidate_list_free(&list);
|
||||||
|
|
||||||
|
printf(" All diskann_candidate_list_operations tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_diskann_visited_set_operations() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
struct DiskannVisitedSet set;
|
||||||
|
int rc = _test_diskann_visited_set_init(&set, 32);
|
||||||
|
assert(rc == 0);
|
||||||
|
|
||||||
|
// Empty set
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 1) == 0);
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 100) == 0);
|
||||||
|
|
||||||
|
// Insert and check
|
||||||
|
int inserted = _test_diskann_visited_set_insert(&set, 42);
|
||||||
|
assert(inserted == 1);
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 42) == 1);
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 43) == 0);
|
||||||
|
|
||||||
|
// Double insert returns 0
|
||||||
|
inserted = _test_diskann_visited_set_insert(&set, 42);
|
||||||
|
assert(inserted == 0);
|
||||||
|
|
||||||
|
// Insert several
|
||||||
|
_test_diskann_visited_set_insert(&set, 1);
|
||||||
|
_test_diskann_visited_set_insert(&set, 2);
|
||||||
|
_test_diskann_visited_set_insert(&set, 100);
|
||||||
|
_test_diskann_visited_set_insert(&set, 999);
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 1) == 1);
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 2) == 1);
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 100) == 1);
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 999) == 1);
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 3) == 0);
|
||||||
|
|
||||||
|
// Sentinel value (rowid 0) should not be insertable
|
||||||
|
assert(_test_diskann_visited_set_contains(&set, 0) == 0);
|
||||||
|
inserted = _test_diskann_visited_set_insert(&set, 0);
|
||||||
|
assert(inserted == 0);
|
||||||
|
|
||||||
|
_test_diskann_visited_set_free(&set);
|
||||||
|
|
||||||
|
printf(" All diskann_visited_set_operations tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// With exactly one candidate, diskann_prune_select() must keep it even
// though the neighbor budget (3) is larger than the candidate count.
void test_diskann_prune_select_single_candidate() {
  printf("Starting %s...\n", __func__);

  float dist_to_p[1] = {5.0f};
  float pairwise[1] = {0.0f}; // 1x1 matrix: distance of the candidate to itself
  int keep[1];
  int n_kept;

  int rc = diskann_prune_select(pairwise, dist_to_p, 1, 1.0f, 3, keep, &n_kept);
  assert(rc == 0);
  assert(n_kept == 1);
  assert(keep[0] == 1);

  printf(" All diskann_prune_select_single_candidate tests passed.\n");
}
|
||||||
|
|
||||||
|
// Four candidates that are all equally far from p (2.0) and equally far
// from each other (1.0). The test only requires that the call succeeds and
// selects at least one candidate.
void test_diskann_prune_select_all_identical_distances() {
  printf("Starting %s...\n", __func__);

  enum { N_CAND = 4 };
  float dist_to_p[N_CAND] = {2.0f, 2.0f, 2.0f, 2.0f};
  // Uniform pairwise distances with a zero diagonal.
  float pairwise[N_CAND * N_CAND] = {
      0.0f, 1.0f, 1.0f, 1.0f,
      1.0f, 0.0f, 1.0f, 1.0f,
      1.0f, 1.0f, 0.0f, 1.0f,
      1.0f, 1.0f, 1.0f, 0.0f,
  };
  int keep[N_CAND];
  int n_kept;

  // With alpha = 1.0 the first pick is expected to prune the rest, since
  // alpha * pairwise (1.0) <= dist_to_p (2.0) for every other candidate.
  // Only the lower bound on the count is asserted here.
  int rc = diskann_prune_select(pairwise, dist_to_p, N_CAND, 1.0f, N_CAND, keep,
                                &n_kept);
  assert(rc == 0);
  assert(n_kept >= 1);

  printf(" All diskann_prune_select_all_identical_distances tests passed.\n");
}
|
||||||
|
|
||||||
|
// With a neighbor budget of 1, diskann_prune_select() must stop after the
// first (closest) candidate, even though the candidates are far enough
// apart that none of them would prune another.
void test_diskann_prune_select_max_neighbors_1() {
  printf("Starting %s...\n", __func__);

  enum { N_CAND = 3 };
  float dist_to_p[N_CAND] = {1.0f, 2.0f, 3.0f};
  // Every pair of distinct candidates is 5.0 apart.
  float pairwise[N_CAND * N_CAND] = {
      0.0f, 5.0f, 5.0f,
      5.0f, 0.0f, 5.0f,
      5.0f, 5.0f, 0.0f,
  };
  int keep[N_CAND];
  int n_kept;

  int rc =
      diskann_prune_select(pairwise, dist_to_p, N_CAND, 1.0f, 1, keep, &n_kept);
  assert(rc == 0);
  assert(n_kept == 1);
  assert(keep[0] == 1); // the closest candidate is the one kept

  printf(" All diskann_prune_select_max_neighbors_1 tests passed.\n");
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
printf("Starting unit tests...\n");
|
printf("Starting unit tests...\n");
|
||||||
|
|
@ -1402,5 +2106,23 @@ int main() {
|
||||||
test_ivf_quantize_binary();
|
test_ivf_quantize_binary();
|
||||||
test_ivf_config_parsing();
|
test_ivf_config_parsing();
|
||||||
#endif
|
#endif
|
||||||
|
test_vec0_parse_vector_column_diskann();
|
||||||
|
test_diskann_validity_bitmap();
|
||||||
|
test_diskann_neighbor_ids();
|
||||||
|
test_diskann_quantize_binary();
|
||||||
|
test_diskann_node_init_sizes();
|
||||||
|
test_diskann_node_set_clear_neighbor();
|
||||||
|
test_diskann_prune_select();
|
||||||
|
test_diskann_quantized_vector_byte_size();
|
||||||
|
test_diskann_config_defaults();
|
||||||
|
test_diskann_quantize_int8();
|
||||||
|
test_diskann_quantize_binary_16d();
|
||||||
|
test_diskann_quantize_binary_all_positive();
|
||||||
|
test_diskann_quantize_binary_all_negative();
|
||||||
|
test_diskann_candidate_list_operations();
|
||||||
|
test_diskann_visited_set_operations();
|
||||||
|
test_diskann_prune_select_single_candidate();
|
||||||
|
test_diskann_prune_select_all_identical_distances();
|
||||||
|
test_diskann_prune_select_max_neighbors_1();
|
||||||
printf("All unit tests passed.\n");
|
printf("All unit tests passed.\n");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue