mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 08:46:49 +02:00
Compare commits
24 commits
v0.1.10-alpha.1
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5778fecfeb | ||
|
|
4e2dfcb79d | ||
|
|
b95c05b3aa | ||
|
|
89f6203536 | ||
|
|
c5607100ab | ||
|
|
6e2c4c6bab | ||
|
|
b7fc459be4 | ||
|
|
01b4b2a965 | ||
|
|
c36a995f1e | ||
|
|
c4c23bd8ba | ||
|
|
5522e86cd2 | ||
|
|
f2c9fb8f08 | ||
|
|
d684178a12 | ||
|
|
d033bf5728 | ||
|
|
b00865429b | ||
|
|
2f4c2e4bdb | ||
|
|
7de925be70 | ||
|
|
4bee88384b | ||
|
|
5e4c557f93 | ||
|
|
82f4eb08bf | ||
|
|
9df59b4c03 | ||
|
|
07f56e3cbe | ||
|
|
3cfc2e0c1f | ||
|
|
85cf415397 |
37 changed files with 3193 additions and 273 deletions
4
Makefile
4
Makefile
|
|
@ -37,7 +37,7 @@ endif
|
||||||
|
|
||||||
ifndef OMIT_SIMD
|
ifndef OMIT_SIMD
|
||||||
ifeq ($(shell uname -sm),Darwin x86_64)
|
ifeq ($(shell uname -sm),Darwin x86_64)
|
||||||
CFLAGS += -mavx -DSQLITE_VEC_ENABLE_AVX
|
CFLAGS += -mavx -mavx2 -DSQLITE_VEC_ENABLE_AVX
|
||||||
endif
|
endif
|
||||||
ifeq ($(shell uname -sm),Darwin arm64)
|
ifeq ($(shell uname -sm),Darwin arm64)
|
||||||
CFLAGS += -mcpu=apple-m1 -DSQLITE_VEC_ENABLE_NEON
|
CFLAGS += -mcpu=apple-m1 -DSQLITE_VEC_ENABLE_NEON
|
||||||
|
|
@ -45,7 +45,7 @@ ifndef OMIT_SIMD
|
||||||
ifeq ($(shell uname -s),Linux)
|
ifeq ($(shell uname -s),Linux)
|
||||||
ifeq ($(findstring android,$(CC)),)
|
ifeq ($(findstring android,$(CC)),)
|
||||||
ifneq ($(filter avx,$(shell grep -o 'avx[^ ]*' /proc/cpuinfo 2>/dev/null | head -1)),)
|
ifneq ($(filter avx,$(shell grep -o 'avx[^ ]*' /proc/cpuinfo 2>/dev/null | head -1)),)
|
||||||
CFLAGS += -mavx -DSQLITE_VEC_ENABLE_AVX
|
CFLAGS += -mavx -mavx2 -DSQLITE_VEC_ENABLE_AVX
|
||||||
endif
|
endif
|
||||||
endif
|
endif
|
||||||
endif
|
endif
|
||||||
|
|
|
||||||
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
||||||
0.1.10-alpha.1
|
0.1.10-alpha.3
|
||||||
|
|
@ -4,9 +4,9 @@ EXT = ../dist/vec0
|
||||||
|
|
||||||
# --- Baseline (brute-force) configs ---
|
# --- Baseline (brute-force) configs ---
|
||||||
BASELINES = \
|
BASELINES = \
|
||||||
"brute-float:type=baseline,variant=float" \
|
"brute-float:type=vec0-flat,variant=float" \
|
||||||
"brute-int8:type=baseline,variant=int8" \
|
"brute-int8:type=vec0-flat,variant=int8" \
|
||||||
"brute-bit:type=baseline,variant=bit"
|
"brute-bit:type=vec0-flat,variant=bit"
|
||||||
|
|
||||||
# --- IVF configs ---
|
# --- IVF configs ---
|
||||||
IVF_CONFIGS = \
|
IVF_CONFIGS = \
|
||||||
|
|
@ -43,7 +43,7 @@ ground-truth: seed
|
||||||
# --- Quick smoke test ---
|
# --- Quick smoke test ---
|
||||||
bench-smoke: seed
|
bench-smoke: seed
|
||||||
$(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \
|
$(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \
|
||||||
"brute-float:type=baseline,variant=float" \
|
"brute-float:type=vec0-flat,variant=float" \
|
||||||
"ivf-quick:type=ivf,nlist=16,nprobe=4" \
|
"ivf-quick:type=ivf,nlist=16,nprobe=4" \
|
||||||
"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"
|
"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"
|
||||||
|
|
||||||
|
|
|
||||||
3
benchmarks-ann/bench-delete/.gitignore
vendored
Normal file
3
benchmarks-ann/bench-delete/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
runs/
|
||||||
|
*.db
|
||||||
|
__pycache__/
|
||||||
41
benchmarks-ann/bench-delete/Makefile
Normal file
41
benchmarks-ann/bench-delete/Makefile
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
BENCH = python bench_delete.py
|
||||||
|
EXT = ../../dist/vec0
|
||||||
|
|
||||||
|
# --- Configs to test ---
|
||||||
|
FLAT = "flat:type=vec0-flat,variant=float"
|
||||||
|
RESCORE_BIT = "rescore-bit:type=rescore,quantizer=bit,oversample=8"
|
||||||
|
RESCORE_INT8 = "rescore-int8:type=rescore,quantizer=int8,oversample=8"
|
||||||
|
DISKANN_R48 = "diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
|
||||||
|
DISKANN_R72 = "diskann-R72:type=diskann,R=72,L=128,quantizer=binary"
|
||||||
|
|
||||||
|
ALL_CONFIGS = $(FLAT) $(RESCORE_BIT) $(RESCORE_INT8) $(DISKANN_R48) $(DISKANN_R72)
|
||||||
|
|
||||||
|
DELETE_PCTS = 5,10,25,50,75,90
|
||||||
|
|
||||||
|
.PHONY: smoke bench-10k bench-50k bench-all report clean
|
||||||
|
|
||||||
|
# Quick smoke test (small dataset, few queries)
|
||||||
|
smoke:
|
||||||
|
$(BENCH) --subset-size 5000 --delete-pct 10,50 -k 10 -n 20 \
|
||||||
|
--dataset cohere1m --ext $(EXT) \
|
||||||
|
$(FLAT) $(DISKANN_R48)
|
||||||
|
|
||||||
|
# Standard benchmarks
|
||||||
|
bench-10k:
|
||||||
|
$(BENCH) --subset-size 10000 --delete-pct $(DELETE_PCTS) -k 10 -n 50 \
|
||||||
|
--dataset cohere1m --ext $(EXT) $(ALL_CONFIGS)
|
||||||
|
|
||||||
|
bench-50k:
|
||||||
|
$(BENCH) --subset-size 50000 --delete-pct $(DELETE_PCTS) -k 10 -n 50 \
|
||||||
|
--dataset cohere1m --ext $(EXT) $(ALL_CONFIGS)
|
||||||
|
|
||||||
|
bench-all: bench-10k bench-50k
|
||||||
|
|
||||||
|
# Query saved results
|
||||||
|
report:
|
||||||
|
@echo "Query results:"
|
||||||
|
@echo " sqlite3 runs/cohere1m/10000/delete_results.db \\"
|
||||||
|
@echo " \"SELECT config_name, delete_pct, recall, query_mean_ms, vacuum_size_mb FROM delete_runs ORDER BY config_name, delete_pct\""
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -rf runs/
|
||||||
69
benchmarks-ann/bench-delete/README.md
Normal file
69
benchmarks-ann/bench-delete/README.md
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
# bench-delete: Recall degradation after random deletion
|
||||||
|
|
||||||
|
Measures how KNN recall changes after deleting a random percentage of rows
|
||||||
|
from different index types (flat, rescore, DiskANN).
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure dataset exists
|
||||||
|
make -C ../datasets/cohere1m
|
||||||
|
|
||||||
|
# Ensure extension is built
|
||||||
|
make -C ../.. loadable
|
||||||
|
|
||||||
|
# Quick smoke test
|
||||||
|
make smoke
|
||||||
|
|
||||||
|
# Full benchmark at 10k vectors
|
||||||
|
make bench-10k
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python bench_delete.py --subset-size 10000 --delete-pct 10,25,50,75 \
|
||||||
|
"flat:type=vec0-flat,variant=float" \
|
||||||
|
"diskann-R72:type=diskann,R=72,L=128,quantizer=binary" \
|
||||||
|
"rescore-bit:type=rescore,quantizer=bit,oversample=8"
|
||||||
|
```
|
||||||
|
|
||||||
|
## What it measures
|
||||||
|
|
||||||
|
For each config and delete percentage:
|
||||||
|
|
||||||
|
| Metric | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| **recall** | KNN recall@k after deletion (ground truth recomputed over surviving rows) |
|
||||||
|
| **delta** | Recall change vs 0% baseline |
|
||||||
|
| **query latency** | Mean/median query time after deletion |
|
||||||
|
| **db_size_mb** | DB file size before VACUUM |
|
||||||
|
| **vacuum_size_mb** | DB file size after VACUUM (space reclaimed) |
|
||||||
|
| **delete_time_s** | Wall time for the DELETE operations |
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
1. Build index with N vectors (one copy per config)
|
||||||
|
2. Measure recall at k=10 (pre-delete baseline)
|
||||||
|
3. For each delete %:
|
||||||
|
- Copy the master DB
|
||||||
|
- Delete a random selection of rows (deterministic seed)
|
||||||
|
- Measure recall (ground truth recomputed over surviving rows only)
|
||||||
|
- VACUUM and measure size savings
|
||||||
|
4. Print comparison table
|
||||||
|
|
||||||
|
## Expected behavior
|
||||||
|
|
||||||
|
- **Flat index**: Recall should be 1.0 at all delete percentages (brute-force is always exact)
|
||||||
|
- **Rescore**: Recall should stay close to baseline (quantized scan + rescore is robust)
|
||||||
|
- **DiskANN**: Recall may degrade at high delete % due to graph fragmentation (dangling edges, broken connectivity)
|
||||||
|
|
||||||
|
## Results DB
|
||||||
|
|
||||||
|
Results are stored in `runs/<dataset>/<subset_size>/delete_results.db`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
SELECT config_name, delete_pct, recall, vacuum_size_mb
|
||||||
|
FROM delete_runs
|
||||||
|
ORDER BY config_name, delete_pct;
|
||||||
|
```
|
||||||
593
benchmarks-ann/bench-delete/bench_delete.py
Normal file
593
benchmarks-ann/bench-delete/bench_delete.py
Normal file
|
|
@ -0,0 +1,593 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Benchmark: measure recall degradation after random row deletion.
|
||||||
|
|
||||||
|
Given a dataset and index config, this script:
|
||||||
|
1. Builds the index (flat + ANN)
|
||||||
|
2. Measures recall at k=10 (pre-delete baseline)
|
||||||
|
3. Deletes a random % of rows
|
||||||
|
4. Measures recall again (post-delete)
|
||||||
|
5. Records DB size before/after deletion, recall delta, timings
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python bench_delete.py --subset-size 10000 --delete-pct 25 \
|
||||||
|
"diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
|
||||||
|
|
||||||
|
# Multiple delete percentages in one run:
|
||||||
|
python bench_delete.py --subset-size 10000 --delete-pct 10,25,50,75 \
|
||||||
|
"diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
import sqlite3
|
||||||
|
import statistics
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
|
||||||
|
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
_BENCH_DIR = os.path.join(_SCRIPT_DIR, "..")
|
||||||
|
_ROOT_DIR = os.path.join(_BENCH_DIR, "..")
|
||||||
|
|
||||||
|
EXT_PATH = os.path.join(_ROOT_DIR, "dist", "vec0")
|
||||||
|
DATASETS_DIR = os.path.join(_BENCH_DIR, "datasets")
|
||||||
|
|
||||||
|
DATASETS = {
|
||||||
|
"cohere1m": {"base_db": os.path.join(DATASETS_DIR, "cohere1m", "base.db"), "dimensions": 768},
|
||||||
|
"cohere10m": {"base_db": os.path.join(DATASETS_DIR, "cohere10m", "base.db"), "dimensions": 768},
|
||||||
|
"nyt": {"base_db": os.path.join(DATASETS_DIR, "nyt", "base.db"), "dimensions": 256},
|
||||||
|
"nyt-768": {"base_db": os.path.join(DATASETS_DIR, "nyt-768", "base.db"), "dimensions": 768},
|
||||||
|
"nyt-1024": {"base_db": os.path.join(DATASETS_DIR, "nyt-1024", "base.db"), "dimensions": 1024},
|
||||||
|
"nyt-384": {"base_db": os.path.join(DATASETS_DIR, "nyt-384", "base.db"), "dimensions": 384},
|
||||||
|
}
|
||||||
|
|
||||||
|
INSERT_BATCH_SIZE = 1000
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Timing helpers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def now_ns():
    """Current wall-clock timestamp as an integer number of nanoseconds."""
    return time.time_ns()
|
||||||
|
|
||||||
|
def ns_to_s(ns):
    """Convert a nanosecond count to seconds (float)."""
    return ns / 1_000_000_000
|
||||||
|
|
||||||
|
def ns_to_ms(ns):
    """Convert a nanosecond count to milliseconds (float)."""
    return ns / 1_000_000
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Index registry (subset of bench.py — only types relevant to deletion)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def _vec0_flat_create(p):
|
||||||
|
dims = p["dimensions"]
|
||||||
|
variant = p.get("variant", "float")
|
||||||
|
col = f"embedding float[{dims}]"
|
||||||
|
if variant == "int8":
|
||||||
|
col = f"embedding int8[{dims}]"
|
||||||
|
elif variant == "bit":
|
||||||
|
col = f"embedding bit[{dims}]"
|
||||||
|
return f"CREATE VIRTUAL TABLE vec_items USING vec0(id INTEGER PRIMARY KEY, {col})"
|
||||||
|
|
||||||
|
def _rescore_create(p):
|
||||||
|
dims = p["dimensions"]
|
||||||
|
q = p.get("quantizer", "bit")
|
||||||
|
os_val = p.get("oversample", 8)
|
||||||
|
return (
|
||||||
|
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||||
|
f"id INTEGER PRIMARY KEY, "
|
||||||
|
f"embedding float[{dims}] indexed by rescore(quantizer={q}, oversample={os_val}))"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _diskann_create(p):
|
||||||
|
dims = p["dimensions"]
|
||||||
|
R = p.get("R", 72)
|
||||||
|
L = p.get("L", 128)
|
||||||
|
q = p.get("quantizer", "binary")
|
||||||
|
bt = p.get("buffer_threshold", 0)
|
||||||
|
sl_insert = p.get("search_list_size_insert", 0)
|
||||||
|
sl_search = p.get("search_list_size_search", 0)
|
||||||
|
parts = [
|
||||||
|
f"neighbor_quantizer={q}",
|
||||||
|
f"n_neighbors={R}",
|
||||||
|
f"buffer_threshold={bt}",
|
||||||
|
]
|
||||||
|
if sl_insert or sl_search:
|
||||||
|
# Per-path overrides — don't also set search_list_size
|
||||||
|
if sl_insert:
|
||||||
|
parts.append(f"search_list_size_insert={sl_insert}")
|
||||||
|
if sl_search:
|
||||||
|
parts.append(f"search_list_size_search={sl_search}")
|
||||||
|
else:
|
||||||
|
parts.append(f"search_list_size={L}")
|
||||||
|
opts = ", ".join(parts)
|
||||||
|
return (
|
||||||
|
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||||
|
f"id INTEGER PRIMARY KEY, "
|
||||||
|
f"embedding float[{dims}] indexed by diskann({opts}))"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ivf_create(p):
|
||||||
|
dims = p["dimensions"]
|
||||||
|
nlist = p.get("nlist", 128)
|
||||||
|
nprobe = p.get("nprobe", 16)
|
||||||
|
q = p.get("quantizer", "none")
|
||||||
|
os_val = p.get("oversample", 1)
|
||||||
|
parts = [f"nlist={nlist}", f"nprobe={nprobe}"]
|
||||||
|
if q != "none":
|
||||||
|
parts.append(f"quantizer={q}")
|
||||||
|
if os_val > 1:
|
||||||
|
parts.append(f"oversample={os_val}")
|
||||||
|
opts = ", ".join(parts)
|
||||||
|
return (
|
||||||
|
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||||
|
f"id INTEGER PRIMARY KEY, "
|
||||||
|
f"embedding float[{dims}] indexed by ivf({opts}))"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping index-type name -> build recipe:
#   defaults         - default params; user-supplied config values override them
#   create_table_sql - callable(params) -> CREATE VIRTUAL TABLE statement
#   post_insert_hook - optional callable(conn, params) run after bulk insert
#                      (only IVF needs one, to trigger k-means training)
INDEX_REGISTRY = {
    "vec0-flat": {
        "defaults": {"variant": "float"},
        "create_table_sql": _vec0_flat_create,
        "post_insert_hook": None,
    },
    "rescore": {
        "defaults": {"quantizer": "bit", "oversample": 8},
        "create_table_sql": _rescore_create,
        "post_insert_hook": None,
    },
    "ivf": {
        "defaults": {"nlist": 128, "nprobe": 16, "quantizer": "none",
                     "oversample": 1},
        "create_table_sql": _ivf_create,
        # params is ignored by the hook; the lambda keeps a uniform
        # (conn, params) hook signature across index types.
        "post_insert_hook": lambda conn, params: _ivf_train(conn),
    },
    "diskann": {
        "defaults": {"R": 72, "L": 128, "quantizer": "binary",
                     "buffer_threshold": 0},
        "create_table_sql": _diskann_create,
        "post_insert_hook": None,
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
def _ivf_train(conn):
    """Trigger the IVF index's built-in k-means training.

    Returns the elapsed wall time in seconds.
    """
    started = now_ns()
    conn.execute("INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')")
    conn.commit()
    return ns_to_s(now_ns() - started)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Config parsing (same format as bench.py)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Config keys whose values are coerced to int during parsing.
INT_KEYS = {"R", "L", "oversample", "nlist", "nprobe", "buffer_threshold",
            "search_list_size_insert", "search_list_size_search"}


def parse_config(spec):
    """Parse a 'name:key=val,key=val,...' config spec.

    Returns (name, params): params carries the INDEX_REGISTRY defaults for
    the given type, overridden by the user-supplied values, plus an
    'index_type' entry. Raises ValueError on a malformed spec or an
    unknown/missing index type.
    """
    if ":" not in spec:
        raise ValueError(f"Config must be 'name:key=val,...': {spec}")
    name, rest = spec.split(":", 1)

    params = {}
    for pair in rest.split(","):
        key, value = pair.split("=", 1)
        key = key.strip()
        value = value.strip()
        params[key] = int(value) if key in INT_KEYS else value

    index_type = params.pop("type", None)
    if not index_type or index_type not in INDEX_REGISTRY:
        raise ValueError(f"Unknown index type: {index_type}")
    params["index_type"] = index_type

    merged = {**INDEX_REGISTRY[index_type]["defaults"], **params}
    return name, merged
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# DB helpers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def create_bench_db(db_path, ext_path, base_db, page_size=4096):
    """Create a fresh benchmark DB with the vec0 extension loaded.

    Removes any existing file at db_path, sets the page size, enables WAL,
    loads the loadable extension, and ATTACHes the dataset DB as schema
    'base'. Returns the open sqlite3.Connection (caller closes it).
    """
    if os.path.exists(db_path):
        os.remove(db_path)
    conn = sqlite3.connect(db_path)
    # page_size must be set before the first page is written to take effect.
    conn.execute(f"PRAGMA page_size={page_size}")
    conn.execute("PRAGMA journal_mode=WAL")
    conn.enable_load_extension(True)
    conn.load_extension(ext_path)
    # Bind the path as a parameter rather than f-string interpolation, so a
    # quote character in the dataset path cannot break (or inject into) the
    # ATTACH statement.
    conn.execute("ATTACH DATABASE ? AS base", (base_db,))
    return conn
|
||||||
|
|
||||||
|
|
||||||
|
def load_query_vectors(base_db, n):
    """Read up to n (id, vector) rows from the dataset's query_vectors table."""
    db = sqlite3.connect(base_db)
    try:
        return db.execute(
            "SELECT id, vector FROM query_vectors LIMIT ?", (n,)
        ).fetchall()
    finally:
        db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def insert_loop(conn, subset_size, label, start_from=0):
    """Copy rows [start_from, subset_size) from base.train into vec_items.

    Inserts in committed batches of INSERT_BATCH_SIZE rows, printing a
    progress line every 5000 rows and once at the end.
    """
    sql = (
        "INSERT INTO vec_items(id, embedding) "
        "SELECT id, vector FROM base.train "
        "WHERE id >= :lo AND id < :hi"
    )
    done = 0
    for lo in range(start_from, subset_size, INSERT_BATCH_SIZE):
        hi = min(lo + INSERT_BATCH_SIZE, subset_size)
        conn.execute(sql, {"lo": lo, "hi": hi})
        conn.commit()
        done += hi - lo
        if done % 5000 == 0 or done == subset_size - start_from:
            print(f" [{label}] inserted {done + start_from}/{subset_size}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Recall measurement
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def measure_recall(conn, base_db, query_vectors, subset_size, k, alive_ids=None):
    """Measure KNN recall. If alive_ids is provided, ground truth is computed
    only over those IDs (to match post-delete state).

    Returns a dict with keys "recall" (mean recall@k over all queries),
    "mean_ms" and "median_ms" (per-query ANN latency).
    NOTE(review): base_db is unused here — ground truth is read through the
    'base' schema already ATTACHed on conn.
    """
    recalls = []
    times_ms = []

    for qid, query in query_vectors:
        # Timed ANN query against the index under test.
        t0 = now_ns()
        results = conn.execute(
            "SELECT id, distance FROM vec_items "
            "WHERE embedding MATCH :query AND k = :k",
            {"query": query, "k": k},
        ).fetchall()
        t1 = now_ns()
        times_ms.append(ns_to_ms(t1 - t0))

        result_ids = set(r[0] for r in results)

        # Ground truth: brute-force vec_distance_l2 scan over base.train
        # (despite nearby comments elsewhere mentioning cosine, the code
        # uses L2 distance).
        if alive_ids is not None:
            # After deletion — compute GT only over alive IDs. Oversample
            # 5x, then filter to survivors client-side.
            # NOTE(review): at very high delete percentages the 5x
            # oversample may yield fewer than k alive rows, shrinking the
            # GT set — confirm this is acceptable.
            gt_rows = conn.execute(
                "SELECT id FROM ("
                " SELECT id, vec_distance_l2(vector, :query) as dist "
                " FROM base.train WHERE id < :n ORDER BY dist LIMIT :k2"
                ")",
                {"query": query, "k2": k * 5, "n": subset_size},
            ).fetchall()
            # Filter to only alive IDs, take top k
            gt_alive = [r[0] for r in gt_rows if r[0] in alive_ids][:k]
            gt_ids = set(gt_alive)
        else:
            # Pre-delete: straight top-k over the full subset.
            gt_rows = conn.execute(
                "SELECT id FROM ("
                " SELECT id, vec_distance_l2(vector, :query) as dist "
                " FROM base.train WHERE id < :n ORDER BY dist LIMIT :k"
                ")",
                {"query": query, "k": k, "n": subset_size},
            ).fetchall()
            gt_ids = set(r[0] for r in gt_rows)

        if gt_ids:
            recalls.append(len(result_ids & gt_ids) / len(gt_ids))
        else:
            # Empty GT (e.g. everything relevant deleted) counts as zero.
            recalls.append(0.0)

    return {
        "recall": round(statistics.mean(recalls), 4) if recalls else 0.0,
        "mean_ms": round(statistics.mean(times_ms), 2) if times_ms else 0.0,
        "median_ms": round(statistics.median(times_ms), 2) if times_ms else 0.0,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Delete benchmark core
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def run_delete_benchmark(name, params, base_db, ext_path, subset_size, dims,
                         delete_pcts, k, n_queries, out_dir, seed_val):
    """Build one index config, then measure recall/size after each delete pct.

    Builds the index once (the "master" DB), measures a 0%-delete baseline,
    then for each percentage copies the master DB, deletes a deterministic
    random sample of rows, re-measures recall (ground truth restricted to
    surviving IDs), and records size before/after VACUUM.

    Returns a list of result dicts, one per delete percentage, with the
    0% baseline row first.
    """
    # NOTE: mutates the caller's params dict by adding "dimensions".
    params["dimensions"] = dims
    reg = INDEX_REGISTRY[params["index_type"]]
    create_sql = reg["create_table_sql"](params)

    results = []

    # Build once, copy for each delete %
    print(f"\n{'='*60}")
    print(f"Config: {name} (type={params['index_type']})")
    print(f"{'='*60}")

    os.makedirs(out_dir, exist_ok=True)
    master_db_path = os.path.join(out_dir, f"{name}.{subset_size}.db")
    print(f" Building index ({subset_size} vectors)...")
    build_t0 = now_ns()
    conn = create_bench_db(master_db_path, ext_path, base_db)
    conn.execute(create_sql)
    insert_loop(conn, subset_size, name)
    # Run the optional post-insert hook (IVF k-means training).
    hook = reg.get("post_insert_hook")
    if hook:
        print(f" Training...")
        hook(conn, params)
    conn.close()
    build_time_s = ns_to_s(now_ns() - build_t0)
    master_size = os.path.getsize(master_db_path)
    print(f" Built in {build_time_s:.1f}s ({master_size / (1024*1024):.1f} MB)")

    # Load query vectors once
    query_vectors = load_query_vectors(base_db, n_queries)

    # Measure pre-delete baseline on the master copy
    print(f"\n --- 0% deleted (baseline) ---")
    conn = sqlite3.connect(master_db_path)
    conn.enable_load_extension(True)
    conn.load_extension(ext_path)
    conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
    baseline = measure_recall(conn, base_db, query_vectors, subset_size, k)
    conn.close()
    print(f" recall={baseline['recall']:.4f} "
          f"query={baseline['mean_ms']:.2f}ms")

    # Baseline row: delete_pct=0, delete_time=0, vacuum size == build size.
    results.append({
        "name": name,
        "index_type": params["index_type"],
        "subset_size": subset_size,
        "delete_pct": 0,
        "n_deleted": 0,
        "n_remaining": subset_size,
        "recall": baseline["recall"],
        "query_mean_ms": baseline["mean_ms"],
        "query_median_ms": baseline["median_ms"],
        "db_size_mb": round(master_size / (1024 * 1024), 2),
        "build_time_s": round(build_time_s, 1),
        "delete_time_s": 0.0,
        "vacuum_size_mb": round(master_size / (1024 * 1024), 2),
    })

    # All IDs in the dataset
    all_ids = list(range(subset_size))

    for pct in sorted(delete_pcts):
        n_delete = int(subset_size * pct / 100)
        print(f"\n --- {pct}% deleted ({n_delete} rows) ---")

        # Copy master DB and work on the copy
        copy_path = os.path.join(out_dir, f"{name}.{subset_size}.del{pct}.db")
        shutil.copy2(master_db_path, copy_path)
        # Also copy WAL/SHM if they exist (master is in WAL mode, so
        # un-checkpointed pages may live in the sidecar files).
        for suffix in ["-wal", "-shm"]:
            src = master_db_path + suffix
            if os.path.exists(src):
                shutil.copy2(src, copy_path + suffix)

        conn = sqlite3.connect(copy_path)
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)
        conn.execute(f"ATTACH DATABASE '{base_db}' AS base")

        # Pick random IDs to delete (deterministic per pct: seed + pct)
        rng = random.Random(seed_val + pct)
        to_delete = set(rng.sample(all_ids, n_delete))
        alive_ids = set(all_ids) - to_delete

        # Delete in IN-list batches of 500, committing per batch.
        delete_t0 = now_ns()
        batch = []
        for i, rid in enumerate(to_delete):
            batch.append(rid)
            if len(batch) >= 500 or i == len(to_delete) - 1:
                placeholders = ",".join("?" for _ in batch)
                conn.execute(
                    f"DELETE FROM vec_items WHERE id IN ({placeholders})",
                    batch,
                )
                conn.commit()
                batch = []
        delete_time_s = ns_to_s(now_ns() - delete_t0)

        remaining = conn.execute("SELECT count(*) FROM vec_items").fetchone()[0]
        # NOTE(review): size of the main file only — WAL sidecar not counted.
        pre_vacuum_size = os.path.getsize(copy_path)
        print(f" deleted {n_delete} rows in {delete_time_s:.2f}s "
              f"({remaining} remaining)")

        # Measure post-delete recall (GT restricted to surviving rows)
        post = measure_recall(conn, base_db, query_vectors, subset_size, k,
                              alive_ids=alive_ids)
        print(f" recall={post['recall']:.4f} "
              f"(delta={post['recall'] - baseline['recall']:+.4f}) "
              f"query={post['mean_ms']:.2f}ms")

        # VACUUM and measure size savings — close fully, reopen without base
        # (VACUUM cannot run while another database is ATTACHed in a txn).
        conn.close()
        vconn = sqlite3.connect(copy_path)
        vconn.execute("VACUUM")
        vconn.close()
        post_vacuum_size = os.path.getsize(copy_path)
        saved_mb = (pre_vacuum_size - post_vacuum_size) / (1024 * 1024)
        print(f" size: {pre_vacuum_size/(1024*1024):.1f} MB -> "
              f"{post_vacuum_size/(1024*1024):.1f} MB after VACUUM "
              f"(saved {saved_mb:.1f} MB)")

        results.append({
            "name": name,
            "index_type": params["index_type"],
            "subset_size": subset_size,
            "delete_pct": pct,
            "n_deleted": n_delete,
            "n_remaining": remaining,
            "recall": post["recall"],
            "query_mean_ms": post["mean_ms"],
            "query_median_ms": post["median_ms"],
            "db_size_mb": round(pre_vacuum_size / (1024 * 1024), 2),
            "build_time_s": round(build_time_s, 1),
            "delete_time_s": round(delete_time_s, 2),
            "vacuum_size_mb": round(post_vacuum_size / (1024 * 1024), 2),
        })

    return results
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Results DB
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
RESULTS_SCHEMA = """\
|
||||||
|
CREATE TABLE IF NOT EXISTS delete_runs (
|
||||||
|
run_id INTEGER PRIMARY KEY,
|
||||||
|
config_name TEXT NOT NULL,
|
||||||
|
index_type TEXT NOT NULL,
|
||||||
|
params TEXT,
|
||||||
|
dataset TEXT NOT NULL,
|
||||||
|
subset_size INTEGER NOT NULL,
|
||||||
|
delete_pct INTEGER NOT NULL,
|
||||||
|
n_deleted INTEGER NOT NULL,
|
||||||
|
n_remaining INTEGER NOT NULL,
|
||||||
|
k INTEGER NOT NULL,
|
||||||
|
n_queries INTEGER NOT NULL,
|
||||||
|
seed INTEGER NOT NULL,
|
||||||
|
recall REAL,
|
||||||
|
query_mean_ms REAL,
|
||||||
|
query_median_ms REAL,
|
||||||
|
db_size_mb REAL,
|
||||||
|
vacuum_size_mb REAL,
|
||||||
|
build_time_s REAL,
|
||||||
|
delete_time_s REAL,
|
||||||
|
created_at TEXT DEFAULT (datetime('now'))
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save_results(results, out_dir, dataset, subset_size, params_json, k, n_queries, seed_val):
    """Append one row per result dict to <out_dir>/delete_results.db.

    Creates the delete_runs table on first use. Returns the DB path.
    """
    db_path = os.path.join(out_dir, "delete_results.db")
    db = sqlite3.connect(db_path)
    db.execute("PRAGMA journal_mode=WAL")
    db.executescript(RESULTS_SCHEMA)
    insert_sql = (
        "INSERT INTO delete_runs "
        "(config_name, index_type, params, dataset, subset_size, "
        " delete_pct, n_deleted, n_remaining, k, n_queries, seed, "
        " recall, query_mean_ms, query_median_ms, "
        " db_size_mb, vacuum_size_mb, build_time_s, delete_time_s) "
        "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"
    )
    rows = [
        (
            r["name"], r["index_type"], params_json, dataset, r["subset_size"],
            r["delete_pct"], r["n_deleted"], r["n_remaining"], k, n_queries, seed_val,
            r["recall"], r["query_mean_ms"], r["query_median_ms"],
            r["db_size_mb"], r["vacuum_size_mb"], r["build_time_s"], r["delete_time_s"],
        )
        for r in results
    ]
    db.executemany(insert_sql, rows)
    db.commit()
    db.close()
    return db_path
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Reporting
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def print_report(all_results):
    """Print a per-config comparison table.

    'delta' is each row's recall minus that config's 0%-delete baseline
    (the baseline row is assumed to come first for each config).
    """
    print(f"\n{'name':>22} {'del%':>5} {'deleted':>8} {'remain':>8} "
          f"{'recall':>7} {'delta':>7} {'qry(ms)':>8} "
          f"{'size(MB)':>9} {'vacuumed':>9} {'del(s)':>7}")
    print("-" * 110)

    # Group rows by config name, preserving per-config order.
    by_config = {}
    for row in all_results:
        by_config.setdefault(row["name"], []).append(row)

    for rows in by_config.values():
        base_recall = rows[0]["recall"]  # 0% delete is always first
        for row in rows:
            if row["delete_pct"] > 0:
                delta_str = f"{row['recall'] - base_recall:+.4f}"
            else:
                delta_str = "-"
            print(
                f"{row['name']:>22} {row['delete_pct']:>4}% "
                f"{row['n_deleted']:>8} {row['n_remaining']:>8} "
                f"{row['recall']:>7.4f} {delta_str:>7} {row['query_mean_ms']:>8.2f} "
                f"{row['db_size_mb']:>9.1f} {row['vacuum_size_mb']:>9.1f} "
                f"{row['delete_time_s']:>7.2f}"
            )
        print()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Main
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def main():
    """CLI entry point. Returns a process exit code (0 on success, 1 on
    usage/setup errors)."""
    parser = argparse.ArgumentParser(
        description="Benchmark recall degradation after random row deletion",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("configs", nargs="+",
                        help="config specs (name:type=X,key=val,...)")
    parser.add_argument("--subset-size", type=int, default=10000,
                        help="number of vectors to build (default: 10000)")
    parser.add_argument("--delete-pct", type=str, default="10,25,50",
                        help="comma-separated delete percentages (default: 10,25,50)")
    parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
    parser.add_argument("-n", type=int, default=50,
                        help="number of queries (default 50)")
    parser.add_argument("--dataset", default="cohere1m",
                        choices=list(DATASETS.keys()))
    parser.add_argument("--ext", default=EXT_PATH)
    parser.add_argument("-o", "--out-dir",
                        default=os.path.join(_SCRIPT_DIR, "runs"))
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed for delete selection (default: 42)")
    args = parser.parse_args()

    ds = DATASETS[args.dataset]
    base_db = ds["base_db"]
    dims = ds["dimensions"]
    if not os.path.exists(base_db):
        print(f"Error: dataset not found at {base_db}")
        print(f"Run: make -C {os.path.dirname(base_db)}")
        return 1

    # Validate percentages up front so we fail before any expensive build.
    delete_pcts = [int(x.strip()) for x in args.delete_pct.split(",")]
    for p in delete_pcts:
        if not 0 < p < 100:
            print(f"Error: delete percentage must be 1-99, got {p}")
            return 1

    out_dir = os.path.join(args.out_dir, args.dataset, str(args.subset_size))
    os.makedirs(out_dir, exist_ok=True)

    all_results = []
    for spec in args.configs:
        name, params = parse_config(spec)
        params_json = json.dumps(params)
        results = run_delete_benchmark(
            name, params, base_db, args.ext, args.subset_size, dims,
            delete_pcts, args.k, args.n, out_dir, args.seed,
        )
        all_results.extend(results)
        # Persist each config's rows with its own params JSON inside the
        # loop: saving once after the loop would record only the last
        # config's results (and tag them with the last config's params).
        save_results(results, out_dir, args.dataset, args.subset_size,
                     params_json, args.k, args.n, args.seed)

    print_report(all_results)

    results_path = os.path.join(out_dir, "delete_results.db")
    print(f"\nResults saved to: {results_path}")
    print(f"Query: sqlite3 {results_path} "
          f"\"SELECT config_name, delete_pct, recall, vacuum_size_mb "
          f"FROM delete_runs ORDER BY config_name, delete_pct\"")
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # SystemExit propagates main()'s integer return value to the shell
    # as the process exit code.
    raise SystemExit(main())
|
||||||
124
benchmarks-ann/bench-delete/test_smoke.py
Normal file
124
benchmarks-ann/bench-delete/test_smoke.py
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Quick self-contained smoke test using a synthetic dataset.
|
||||||
|
Creates a tiny base.db in a temp dir, runs the delete benchmark, verifies output.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
_ROOT_DIR = os.path.join(_SCRIPT_DIR, "..", "..")
|
||||||
|
EXT_PATH = os.path.join(_ROOT_DIR, "dist", "vec0")
|
||||||
|
|
||||||
|
DIMS = 8
|
||||||
|
N_TRAIN = 200
|
||||||
|
N_QUERIES = 10
|
||||||
|
K_NEIGHBORS = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _f32(vals):
|
||||||
|
return struct.pack(f"{len(vals)}f", *vals)
|
||||||
|
|
||||||
|
|
||||||
|
def make_synthetic_base_db(path):
|
||||||
|
"""Create a minimal base.db with train vectors and query vectors."""
|
||||||
|
rng = random.Random(123)
|
||||||
|
db = sqlite3.connect(path)
|
||||||
|
db.execute("CREATE TABLE train(id INTEGER PRIMARY KEY, vector BLOB)")
|
||||||
|
db.execute("CREATE TABLE query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
|
||||||
|
|
||||||
|
for i in range(N_TRAIN):
|
||||||
|
vec = [rng.gauss(0, 1) for _ in range(DIMS)]
|
||||||
|
db.execute("INSERT INTO train VALUES (?, ?)", (i, _f32(vec)))
|
||||||
|
|
||||||
|
for i in range(N_QUERIES):
|
||||||
|
vec = [rng.gauss(0, 1) for _ in range(DIMS)]
|
||||||
|
db.execute("INSERT INTO query_vectors VALUES (?, ?)", (i, _f32(vec)))
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if not os.path.exists(EXT_PATH + ".dylib") and not os.path.exists(EXT_PATH + ".so"):
|
||||||
|
# Try bare path (sqlite handles extension)
|
||||||
|
pass
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
base_db = os.path.join(tmpdir, "base.db")
|
||||||
|
make_synthetic_base_db(base_db)
|
||||||
|
|
||||||
|
# Patch DATASETS to use our synthetic DB
|
||||||
|
import bench_delete
|
||||||
|
bench_delete.DATASETS["synthetic"] = {
|
||||||
|
"base_db": base_db,
|
||||||
|
"dimensions": DIMS,
|
||||||
|
}
|
||||||
|
|
||||||
|
out_dir = os.path.join(tmpdir, "runs")
|
||||||
|
|
||||||
|
# Test flat index
|
||||||
|
print("=== Testing flat index ===")
|
||||||
|
name, params = bench_delete.parse_config("flat:type=vec0-flat,variant=float")
|
||||||
|
params["dimensions"] = DIMS
|
||||||
|
results = bench_delete.run_delete_benchmark(
|
||||||
|
name, params, base_db, EXT_PATH,
|
||||||
|
subset_size=N_TRAIN, dims=DIMS,
|
||||||
|
delete_pcts=[25, 50], k=K_NEIGHBORS, n_queries=N_QUERIES,
|
||||||
|
out_dir=out_dir, seed_val=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
bench_delete.print_report(results)
|
||||||
|
|
||||||
|
# Flat recall should be 1.0 at all delete %
|
||||||
|
for r in results:
|
||||||
|
assert r["recall"] == 1.0, \
|
||||||
|
f"Flat recall should be 1.0, got {r['recall']} at {r['delete_pct']}%"
|
||||||
|
print("\n PASS: flat recall is 1.0 at all delete percentages\n")
|
||||||
|
|
||||||
|
# Test DiskANN
|
||||||
|
print("=== Testing DiskANN ===")
|
||||||
|
name2, params2 = bench_delete.parse_config(
|
||||||
|
"diskann:type=diskann,R=8,L=32,quantizer=binary"
|
||||||
|
)
|
||||||
|
params2["dimensions"] = DIMS
|
||||||
|
results2 = bench_delete.run_delete_benchmark(
|
||||||
|
name2, params2, base_db, EXT_PATH,
|
||||||
|
subset_size=N_TRAIN, dims=DIMS,
|
||||||
|
delete_pcts=[25, 50], k=K_NEIGHBORS, n_queries=N_QUERIES,
|
||||||
|
out_dir=out_dir, seed_val=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
bench_delete.print_report(results2)
|
||||||
|
|
||||||
|
# DiskANN baseline (0%) should have decent recall
|
||||||
|
baseline = results2[0]
|
||||||
|
assert baseline["recall"] > 0.0, \
|
||||||
|
f"DiskANN baseline recall is zero"
|
||||||
|
print(f" PASS: DiskANN baseline recall={baseline['recall']}")
|
||||||
|
|
||||||
|
# Test rescore
|
||||||
|
print("\n=== Testing rescore ===")
|
||||||
|
name3, params3 = bench_delete.parse_config(
|
||||||
|
"rescore:type=rescore,quantizer=bit,oversample=4"
|
||||||
|
)
|
||||||
|
params3["dimensions"] = DIMS
|
||||||
|
results3 = bench_delete.run_delete_benchmark(
|
||||||
|
name3, params3, base_db, EXT_PATH,
|
||||||
|
subset_size=N_TRAIN, dims=DIMS,
|
||||||
|
delete_pcts=[25, 50], k=K_NEIGHBORS, n_queries=N_QUERIES,
|
||||||
|
out_dir=out_dir, seed_val=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
bench_delete.print_report(results3)
|
||||||
|
print(f" PASS: rescore baseline recall={results3[0]['recall']}")
|
||||||
|
|
||||||
|
print("\n ALL SMOKE TESTS PASSED")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
|
|
@ -456,7 +456,7 @@ def _ivf_create_table_sql(params):
|
||||||
def _ivf_post_insert_hook(conn, params):
|
def _ivf_post_insert_hook(conn, params):
|
||||||
print(" Training k-means centroids (built-in)...", flush=True)
|
print(" Training k-means centroids (built-in)...", flush=True)
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')")
|
conn.execute("INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
elapsed = time.perf_counter() - t0
|
elapsed = time.perf_counter() - t0
|
||||||
print(f" Training done in {elapsed:.1f}s", flush=True)
|
print(f" Training done in {elapsed:.1f}s", flush=True)
|
||||||
|
|
@ -514,7 +514,7 @@ def _ivf_faiss_kmeans_hook(conn, params):
|
||||||
|
|
||||||
for cid, blob in centroids:
|
for cid, blob in centroids:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO vec_items(id, embedding) VALUES (?, ?)",
|
"INSERT INTO vec_items(vec_items, embedding) VALUES (?, ?)",
|
||||||
(f"set-centroid:{cid}", blob),
|
(f"set-centroid:{cid}", blob),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
@ -540,7 +540,7 @@ def _ivf_pre_query_hook(conn, params):
|
||||||
nprobe = params.get("nprobe")
|
nprobe = params.get("nprobe")
|
||||||
if nprobe:
|
if nprobe:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO vec_items(id) VALUES (?)",
|
"INSERT INTO vec_items(vec_items) VALUES (?)",
|
||||||
(f"nprobe={nprobe}",),
|
(f"nprobe={nprobe}",),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
@ -572,7 +572,7 @@ INDEX_REGISTRY["ivf"] = {
|
||||||
"insert_sql": None,
|
"insert_sql": None,
|
||||||
"post_insert_hook": _ivf_post_insert_hook,
|
"post_insert_hook": _ivf_post_insert_hook,
|
||||||
"pre_query_hook": _ivf_pre_query_hook,
|
"pre_query_hook": _ivf_pre_query_hook,
|
||||||
"train_sql": lambda _: "INSERT INTO vec_items(id) VALUES ('compute-centroids')",
|
"train_sql": lambda _: "INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')",
|
||||||
"run_query": None,
|
"run_query": None,
|
||||||
"query_sql": None,
|
"query_sql": None,
|
||||||
"describe": _ivf_describe,
|
"describe": _ivf_describe,
|
||||||
|
|
@ -616,7 +616,7 @@ def _diskann_pre_query_hook(conn, params):
|
||||||
L_search = params.get("L_search", 0)
|
L_search = params.get("L_search", 0)
|
||||||
if L_search:
|
if L_search:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO vec_items(id) VALUES (?)",
|
"INSERT INTO vec_items(vec_items) VALUES (?)",
|
||||||
(f"search_list_size_search={L_search}",),
|
(f"search_list_size_search={L_search}",),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
mkdir -p vendor
|
mkdir -p vendor
|
||||||
curl -o sqlite-amalgamation.zip https://www.sqlite.org/2024/sqlite-amalgamation-3450300.zip
|
curl -o sqlite-amalgamation.zip https://www.sqlite.org/2024/sqlite-amalgamation-3450300.zip
|
||||||
unzip -d
|
|
||||||
unzip sqlite-amalgamation.zip
|
unzip sqlite-amalgamation.zip
|
||||||
mv sqlite-amalgamation-3450300/* vendor/
|
mv sqlite-amalgamation-3450300/* vendor/
|
||||||
rmdir sqlite-amalgamation-3450300
|
rmdir sqlite-amalgamation-3450300
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "sqlite-vec"
|
name = "sqlite-vec"
|
||||||
license = "MIT OR Apache"
|
license = "MIT OR Apache-2.0"
|
||||||
homepage = "https://alexgarcia.xyz/sqlite-vec"
|
homepage = "https://alexgarcia.xyz/sqlite-vec"
|
||||||
repo = "https://github.com/asg017/sqlite-vec"
|
repo = "https://github.com/asg017/sqlite-vec"
|
||||||
description = "A vector search SQLite extension."
|
description = "A vector search SQLite extension."
|
||||||
|
|
|
||||||
|
|
@ -410,9 +410,18 @@ static int diskann_node_read(vec0_vtab *p, int vec_col_idx, i64 rowid,
|
||||||
return SQLITE_NOMEM;
|
return SQLITE_NOMEM;
|
||||||
}
|
}
|
||||||
|
|
||||||
memcpy(v, sqlite3_column_blob(stmt, 0), vs);
|
const void *blobV = sqlite3_column_blob(stmt, 0);
|
||||||
memcpy(ids, sqlite3_column_blob(stmt, 1), is);
|
const void *blobIds = sqlite3_column_blob(stmt, 1);
|
||||||
memcpy(qv, sqlite3_column_blob(stmt, 2), qs);
|
const void *blobQv = sqlite3_column_blob(stmt, 2);
|
||||||
|
if (!blobV || !blobIds || !blobQv) {
|
||||||
|
sqlite3_free(v);
|
||||||
|
sqlite3_free(ids);
|
||||||
|
sqlite3_free(qv);
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
memcpy(v, blobV, vs);
|
||||||
|
memcpy(ids, blobIds, is);
|
||||||
|
memcpy(qv, blobQv, qs);
|
||||||
|
|
||||||
*outValidity = v; *outValiditySize = vs;
|
*outValidity = v; *outValiditySize = vs;
|
||||||
*outNeighborIds = ids; *outNeighborIdsSize = is;
|
*outNeighborIds = ids; *outNeighborIdsSize = is;
|
||||||
|
|
@ -480,9 +489,11 @@ static int diskann_vector_read(vec0_vtab *p, int vec_col_idx, i64 rowid,
|
||||||
}
|
}
|
||||||
|
|
||||||
int sz = sqlite3_column_bytes(stmt, 0);
|
int sz = sqlite3_column_bytes(stmt, 0);
|
||||||
|
const void *blob = sqlite3_column_blob(stmt, 0);
|
||||||
|
if (!blob || sz == 0) return SQLITE_ERROR;
|
||||||
void *vec = sqlite3_malloc(sz);
|
void *vec = sqlite3_malloc(sz);
|
||||||
if (!vec) return SQLITE_NOMEM;
|
if (!vec) return SQLITE_NOMEM;
|
||||||
memcpy(vec, sqlite3_column_blob(stmt, 0), sz);
|
memcpy(vec, blob, sz);
|
||||||
|
|
||||||
*outVector = vec;
|
*outVector = vec;
|
||||||
*outVectorSize = sz;
|
*outVectorSize = sz;
|
||||||
|
|
@ -597,6 +608,7 @@ static int diskann_candidate_list_insert(
|
||||||
list->items[lo].rowid = rowid;
|
list->items[lo].rowid = rowid;
|
||||||
list->items[lo].distance = distance;
|
list->items[lo].distance = distance;
|
||||||
list->items[lo].visited = 0;
|
list->items[lo].visited = 0;
|
||||||
|
list->items[lo].confirmed = 0;
|
||||||
list->count++;
|
list->count++;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
@ -730,8 +742,9 @@ static int diskann_search(
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Seed with medoid
|
// Seed with medoid (confirmed — we already read its vector above)
|
||||||
diskann_candidate_list_insert(&candidates, medoid, medoidDist);
|
diskann_candidate_list_insert(&candidates, medoid, medoidDist);
|
||||||
|
candidates.items[0].confirmed = 1;
|
||||||
|
|
||||||
// Pre-quantize query vector once for all quantized distance comparisons
|
// Pre-quantize query vector once for all quantized distance comparisons
|
||||||
u8 *queryQuantized = NULL;
|
u8 *queryQuantized = NULL;
|
||||||
|
|
@ -804,16 +817,27 @@ static int diskann_search(
|
||||||
sqlite3_free(fullVec);
|
sqlite3_free(fullVec);
|
||||||
// Update distance in candidate list and re-sort
|
// Update distance in candidate list and re-sort
|
||||||
diskann_candidate_list_insert(&candidates, currentRowid, exactDist);
|
diskann_candidate_list_insert(&candidates, currentRowid, exactDist);
|
||||||
|
// Mark as confirmed (vector exists, distance is exact)
|
||||||
|
for (int ci = 0; ci < candidates.count; ci++) {
|
||||||
|
if (candidates.items[ci].rowid == currentRowid) {
|
||||||
|
candidates.items[ci].confirmed = 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
// If vector read failed, candidate stays unconfirmed (stale edge to deleted node)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Output results (candidates are already sorted by distance)
|
// 5. Output results — only include confirmed candidates (whose vectors exist)
|
||||||
int resultCount = (candidates.count < k) ? candidates.count : k;
|
int resultCount = 0;
|
||||||
*outCount = resultCount;
|
for (int i = 0; i < candidates.count && resultCount < k; i++) {
|
||||||
for (int i = 0; i < resultCount; i++) {
|
if (candidates.items[i].confirmed) {
|
||||||
outRowids[i] = candidates.items[i].rowid;
|
outRowids[resultCount] = candidates.items[i].rowid;
|
||||||
outDistances[i] = candidates.items[i].distance;
|
outDistances[resultCount] = candidates.items[i].distance;
|
||||||
|
resultCount++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
*outCount = resultCount;
|
||||||
|
|
||||||
sqlite3_free(queryQuantized);
|
sqlite3_free(queryQuantized);
|
||||||
diskann_candidate_list_free(&candidates);
|
diskann_candidate_list_free(&candidates);
|
||||||
|
|
@ -1325,6 +1349,7 @@ static int diskann_flush_buffer(vec0_vtab *p, int vec_col_idx) {
|
||||||
while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
|
while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
|
||||||
i64 rowid = sqlite3_column_int64(stmt, 0);
|
i64 rowid = sqlite3_column_int64(stmt, 0);
|
||||||
const void *vector = sqlite3_column_blob(stmt, 1);
|
const void *vector = sqlite3_column_blob(stmt, 1);
|
||||||
|
if (!vector) continue;
|
||||||
// Note: vector is already written to _vectors table, so
|
// Note: vector is already written to _vectors table, so
|
||||||
// diskann_insert_graph will skip re-writing it (vector already exists).
|
// diskann_insert_graph will skip re-writing it (vector already exists).
|
||||||
// We call the graph-only insert path.
|
// We call the graph-only insert path.
|
||||||
|
|
@ -1596,13 +1621,14 @@ static int diskann_repair_reverse_edges(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
diskann_node_write(p, vec_col_idx, nodeRowid,
|
rc = diskann_node_write(p, vec_col_idx, nodeRowid,
|
||||||
validity, vs, neighborIds, nis, qvecs, qs);
|
validity, vs, neighborIds, nis, qvecs, qs);
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlite3_free(validity);
|
sqlite3_free(validity);
|
||||||
sqlite3_free(neighborIds);
|
sqlite3_free(neighborIds);
|
||||||
sqlite3_free(qvecs);
|
sqlite3_free(qvecs);
|
||||||
|
if (rc != SQLITE_OK) return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
return SQLITE_OK;
|
return SQLITE_OK;
|
||||||
|
|
@ -1612,6 +1638,95 @@ static int diskann_repair_reverse_edges(
|
||||||
* Delete a vector from the DiskANN graph (Algorithm 3: LM-Delete).
|
* Delete a vector from the DiskANN graph (Algorithm 3: LM-Delete).
|
||||||
* If the vector is in the buffer (not yet flushed), just remove from buffer.
|
* If the vector is in the buffer (not yet flushed), just remove from buffer.
|
||||||
*/
|
*/
|
||||||
|
/**
|
||||||
|
* Scan all nodes and clear any neighbor slot referencing deleted_rowid.
|
||||||
|
* This removes stale reverse edges that the forward-edge repair misses,
|
||||||
|
* preventing data leaks (deleted rowid + quantized vector lingering in
|
||||||
|
* other nodes' blobs).
|
||||||
|
*/
|
||||||
|
static int diskann_scrub_deleted_rowid(
|
||||||
|
vec0_vtab *p, int vec_col_idx, i64 deleted_rowid) {
|
||||||
|
|
||||||
|
struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
|
||||||
|
struct Vec0DiskannConfig *cfg = &col->diskann;
|
||||||
|
int rc;
|
||||||
|
sqlite3_stmt *stmt = NULL;
|
||||||
|
|
||||||
|
// Lightweight scan: only read validity + neighbor_ids to find matches
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"SELECT rowid, neighbors_validity, neighbor_ids "
|
||||||
|
"FROM " VEC0_SHADOW_DISKANN_NODES_N_NAME,
|
||||||
|
p->schemaName, p->tableName, vec_col_idx);
|
||||||
|
if (!zSql) return SQLITE_NOMEM;
|
||||||
|
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if (rc != SQLITE_OK) return rc;
|
||||||
|
|
||||||
|
// Collect rowids that need updating (avoid modifying while iterating)
|
||||||
|
i64 *dirty = NULL;
|
||||||
|
int nDirty = 0, capDirty = 0;
|
||||||
|
|
||||||
|
while (sqlite3_step(stmt) == SQLITE_ROW) {
|
||||||
|
const u8 *validity = (const u8 *)sqlite3_column_blob(stmt, 1);
|
||||||
|
const u8 *ids = (const u8 *)sqlite3_column_blob(stmt, 2);
|
||||||
|
int idsBytes = sqlite3_column_bytes(stmt, 2);
|
||||||
|
if (!validity || !ids) continue;
|
||||||
|
|
||||||
|
int nSlots = idsBytes / (int)sizeof(i64);
|
||||||
|
if (nSlots > cfg->n_neighbors) nSlots = cfg->n_neighbors;
|
||||||
|
|
||||||
|
for (int i = 0; i < nSlots; i++) {
|
||||||
|
if (!diskann_validity_get(validity, i)) continue;
|
||||||
|
i64 nid = diskann_neighbor_id_get(ids, i);
|
||||||
|
if (nid == deleted_rowid) {
|
||||||
|
i64 nodeRowid = sqlite3_column_int64(stmt, 0);
|
||||||
|
// Add to dirty list
|
||||||
|
if (nDirty >= capDirty) {
|
||||||
|
capDirty = capDirty ? capDirty * 2 : 16;
|
||||||
|
i64 *tmp = sqlite3_realloc64(dirty, capDirty * sizeof(i64));
|
||||||
|
if (!tmp) { sqlite3_free(dirty); sqlite3_finalize(stmt); return SQLITE_NOMEM; }
|
||||||
|
dirty = tmp;
|
||||||
|
}
|
||||||
|
dirty[nDirty++] = nodeRowid;
|
||||||
|
break; // one match per node is enough
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
|
||||||
|
// Now do full read/clear/write for each dirty node
|
||||||
|
for (int d = 0; d < nDirty; d++) {
|
||||||
|
u8 *val = NULL, *nids = NULL, *qvecs = NULL;
|
||||||
|
int vs, nis, qs;
|
||||||
|
rc = diskann_node_read(p, vec_col_idx, dirty[d],
|
||||||
|
&val, &vs, &nids, &nis, &qvecs, &qs);
|
||||||
|
if (rc != SQLITE_OK) continue;
|
||||||
|
|
||||||
|
int modified = 0;
|
||||||
|
for (int i = 0; i < cfg->n_neighbors; i++) {
|
||||||
|
if (diskann_validity_get(val, i) &&
|
||||||
|
diskann_neighbor_id_get(nids, i) == deleted_rowid) {
|
||||||
|
diskann_node_clear_neighbor(val, nids, qvecs, i,
|
||||||
|
cfg->quantizer_type, col->dimensions);
|
||||||
|
modified = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (modified) {
|
||||||
|
rc = diskann_node_write(p, vec_col_idx, dirty[d],
|
||||||
|
val, vs, nids, nis, qvecs, qs);
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_free(val);
|
||||||
|
sqlite3_free(nids);
|
||||||
|
sqlite3_free(qvecs);
|
||||||
|
if (rc != SQLITE_OK) break;
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_free(dirty);
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
static int diskann_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) {
|
static int diskann_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) {
|
||||||
struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
|
struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
|
||||||
struct Vec0DiskannConfig *cfg = &col->diskann;
|
struct Vec0DiskannConfig *cfg = &col->diskann;
|
||||||
|
|
@ -1680,6 +1795,12 @@ static int diskann_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) {
|
||||||
rc = diskann_medoid_handle_delete(p, vec_col_idx, rowid);
|
rc = diskann_medoid_handle_delete(p, vec_col_idx, rowid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 5. Scrub stale reverse edges — removes deleted rowid + quantized vector
|
||||||
|
// from any node that still references it (data leak prevention)
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
rc = diskann_scrub_deleted_rowid(p, vec_col_idx, rowid);
|
||||||
|
}
|
||||||
|
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -351,7 +351,9 @@ static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur,
|
||||||
(void)pCur;
|
(void)pCur;
|
||||||
(void)aMetadataIn;
|
(void)aMetadataIn;
|
||||||
int rc = SQLITE_OK;
|
int rc = SQLITE_OK;
|
||||||
int oversample = vector_column->rescore.oversample;
|
int oversample = vector_column->rescore.oversample_search > 0
|
||||||
|
? vector_column->rescore.oversample_search
|
||||||
|
: vector_column->rescore.oversample;
|
||||||
i64 k_oversample = k * oversample;
|
i64 k_oversample = k * oversample;
|
||||||
if (k_oversample > 4096)
|
if (k_oversample > 4096)
|
||||||
k_oversample = 4096;
|
k_oversample = 4096;
|
||||||
|
|
@ -426,6 +428,18 @@ static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur,
|
||||||
unsigned char *chunkValidity =
|
unsigned char *chunkValidity =
|
||||||
(unsigned char *)sqlite3_column_blob(stmtChunks, 1);
|
(unsigned char *)sqlite3_column_blob(stmtChunks, 1);
|
||||||
i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2);
|
i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2);
|
||||||
|
int validityBytes = sqlite3_column_bytes(stmtChunks, 1);
|
||||||
|
int rowidsBytes = sqlite3_column_bytes(stmtChunks, 2);
|
||||||
|
if (!chunkValidity || !chunkRowids) {
|
||||||
|
rc = SQLITE_ERROR;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
// Validate blob sizes match chunk_size expectations
|
||||||
|
if (validityBytes < (p->chunk_size + 7) / 8 ||
|
||||||
|
rowidsBytes < p->chunk_size * (int)sizeof(i64)) {
|
||||||
|
rc = SQLITE_ERROR;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
memset(chunk_distances, 0, p->chunk_size * sizeof(f32));
|
memset(chunk_distances, 0, p->chunk_size * sizeof(f32));
|
||||||
memset(chunk_topk_idxs, 0, k_oversample * sizeof(i32));
|
memset(chunk_topk_idxs, 0, k_oversample * sizeof(i32));
|
||||||
|
|
@ -461,7 +475,7 @@ static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur,
|
||||||
for (int j = 0; j < p->chunk_size; j++) {
|
for (int j = 0; j < p->chunk_size; j++) {
|
||||||
if (!bitmap_get(b, j))
|
if (!bitmap_get(b, j))
|
||||||
continue;
|
continue;
|
||||||
f32 dist;
|
f32 dist = FLT_MAX;
|
||||||
switch (vector_column->rescore.quantizer_type) {
|
switch (vector_column->rescore.quantizer_type) {
|
||||||
case VEC0_RESCORE_QUANTIZER_BIT: {
|
case VEC0_RESCORE_QUANTIZER_BIT: {
|
||||||
const u8 *base_j = ((u8 *)baseVectors) + (j * (qdim / CHAR_BIT));
|
const u8 *base_j = ((u8 *)baseVectors) + (j * (qdim / CHAR_BIT));
|
||||||
|
|
@ -628,6 +642,27 @@ cleanup:
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle FTS5-style command dispatch for rescore parameters.
|
||||||
|
* Returns SQLITE_OK if handled, SQLITE_EMPTY if not a rescore command.
|
||||||
|
*/
|
||||||
|
static int rescore_handle_command(vec0_vtab *p, const char *command) {
|
||||||
|
if (strncmp(command, "oversample=", 11) == 0) {
|
||||||
|
int val = atoi(command + 11);
|
||||||
|
if (val < 1) {
|
||||||
|
vtab_set_error(&p->base, "oversample must be >= 1");
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) {
|
||||||
|
p->vector_columns[i].rescore.oversample_search = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
return SQLITE_EMPTY;
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef SQLITE_VEC_TEST
|
#ifdef SQLITE_VEC_TEST
|
||||||
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) {
|
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) {
|
||||||
rescore_quantize_float_to_bit(src, dst, dim);
|
rescore_quantize_float_to_bit(src, dst, dim);
|
||||||
|
|
|
||||||
581
sqlite-vec.c
581
sqlite-vec.c
|
|
@ -22,61 +22,10 @@ SQLITE_EXTENSION_INIT1
|
||||||
#include "sqlite3.h"
|
#include "sqlite3.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef UINT32_TYPE
|
|
||||||
#ifdef HAVE_UINT32_T
|
|
||||||
#define UINT32_TYPE uint32_t
|
|
||||||
#else
|
|
||||||
#define UINT32_TYPE unsigned int
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#ifndef UINT16_TYPE
|
|
||||||
#ifdef HAVE_UINT16_T
|
|
||||||
#define UINT16_TYPE uint16_t
|
|
||||||
#else
|
|
||||||
#define UINT16_TYPE unsigned short int
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#ifndef INT16_TYPE
|
|
||||||
#ifdef HAVE_INT16_T
|
|
||||||
#define INT16_TYPE int16_t
|
|
||||||
#else
|
|
||||||
#define INT16_TYPE short int
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#ifndef UINT8_TYPE
|
|
||||||
#ifdef HAVE_UINT8_T
|
|
||||||
#define UINT8_TYPE uint8_t
|
|
||||||
#else
|
|
||||||
#define UINT8_TYPE unsigned char
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#ifndef INT8_TYPE
|
|
||||||
#ifdef HAVE_INT8_T
|
|
||||||
#define INT8_TYPE int8_t
|
|
||||||
#else
|
|
||||||
#define INT8_TYPE signed char
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#ifndef LONGDOUBLE_TYPE
|
|
||||||
#define LONGDOUBLE_TYPE long double
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef SQLITE_VEC_ENABLE_DISKANN
|
#ifndef SQLITE_VEC_ENABLE_DISKANN
|
||||||
#define SQLITE_VEC_ENABLE_DISKANN 1
|
#define SQLITE_VEC_ENABLE_DISKANN 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef _WIN32
|
|
||||||
#ifndef __EMSCRIPTEN__
|
|
||||||
#ifndef __COSMOPOLITAN__
|
|
||||||
#ifndef __wasi__
|
|
||||||
typedef u_int8_t uint8_t;
|
|
||||||
typedef u_int16_t uint16_t;
|
|
||||||
typedef u_int64_t uint64_t;
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
typedef int8_t i8;
|
typedef int8_t i8;
|
||||||
typedef uint8_t u8;
|
typedef uint8_t u8;
|
||||||
typedef int16_t i16;
|
typedef int16_t i16;
|
||||||
|
|
@ -309,13 +258,16 @@ static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v,
|
||||||
pVect1 += 8;
|
pVect1 += 8;
|
||||||
pVect2 += 8;
|
pVect2 += 8;
|
||||||
|
|
||||||
// widen to protect against overflow
|
// widen i8 to i16 for subtraction
|
||||||
int16x8_t v1_wide = vmovl_s8(v1);
|
int16x8_t v1_wide = vmovl_s8(v1);
|
||||||
int16x8_t v2_wide = vmovl_s8(v2);
|
int16x8_t v2_wide = vmovl_s8(v2);
|
||||||
|
|
||||||
int16x8_t diff = vsubq_s16(v1_wide, v2_wide);
|
int16x8_t diff = vsubq_s16(v1_wide, v2_wide);
|
||||||
int16x8_t squared_diff = vmulq_s16(diff, diff);
|
|
||||||
int32x4_t sum = vpaddlq_s16(squared_diff);
|
// widening multiply: i16*i16 -> i32 to avoid i16 overflow
|
||||||
|
// (diff can be up to 255, so diff*diff can be up to 65025 > INT16_MAX)
|
||||||
|
int32x4_t sq_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
|
||||||
|
int32x4_t sq_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
|
||||||
|
int32x4_t sum = vaddq_s32(sq_lo, sq_hi);
|
||||||
|
|
||||||
sum_scalar += vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) +
|
sum_scalar += vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) +
|
||||||
vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3);
|
vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3);
|
||||||
|
|
@ -756,6 +708,58 @@ static f32 distance_hamming_neon(const u8 *a, const u8 *b, size_t n_bytes) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef SQLITE_VEC_ENABLE_AVX
|
||||||
|
/**
|
||||||
|
* AVX2 Hamming distance using VPSHUFB-based popcount.
|
||||||
|
* Processes 32 bytes (256 bits) per iteration.
|
||||||
|
*/
|
||||||
|
static f32 distance_hamming_avx2(const u8 *a, const u8 *b, size_t n_bytes) {
|
||||||
|
const u8 *pEnd = a + n_bytes;
|
||||||
|
|
||||||
|
// VPSHUFB lookup table: popcount of low nibble
|
||||||
|
const __m256i lookup = _mm256_setr_epi8(
|
||||||
|
0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4,
|
||||||
|
0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4);
|
||||||
|
const __m256i low_mask = _mm256_set1_epi8(0x0f);
|
||||||
|
|
||||||
|
__m256i acc = _mm256_setzero_si256();
|
||||||
|
|
||||||
|
while (a <= pEnd - 32) {
|
||||||
|
__m256i va = _mm256_loadu_si256((const __m256i *)a);
|
||||||
|
__m256i vb = _mm256_loadu_si256((const __m256i *)b);
|
||||||
|
__m256i xored = _mm256_xor_si256(va, vb);
|
||||||
|
|
||||||
|
// VPSHUFB popcount: split into nibbles, lookup each
|
||||||
|
__m256i lo = _mm256_and_si256(xored, low_mask);
|
||||||
|
__m256i hi = _mm256_and_si256(_mm256_srli_epi16(xored, 4), low_mask);
|
||||||
|
__m256i popcnt = _mm256_add_epi8(_mm256_shuffle_epi8(lookup, lo),
|
||||||
|
_mm256_shuffle_epi8(lookup, hi));
|
||||||
|
|
||||||
|
// Horizontal sum: u8 -> u64 via sad against zero
|
||||||
|
acc = _mm256_add_epi64(acc, _mm256_sad_epu8(popcnt, _mm256_setzero_si256()));
|
||||||
|
a += 32;
|
||||||
|
b += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Horizontal sum of 4 x u64 lanes
|
||||||
|
u64 tmp[4];
|
||||||
|
_mm256_storeu_si256((__m256i *)tmp, acc);
|
||||||
|
u32 sum = (u32)(tmp[0] + tmp[1] + tmp[2] + tmp[3]);
|
||||||
|
|
||||||
|
// Scalar tail
|
||||||
|
while (a < pEnd) {
|
||||||
|
u8 x = *a ^ *b;
|
||||||
|
x = x - ((x >> 1) & 0x55);
|
||||||
|
x = (x & 0x33) + ((x >> 2) & 0x33);
|
||||||
|
sum += (x + (x >> 4)) & 0x0F;
|
||||||
|
a++;
|
||||||
|
b++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (f32)sum;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) {
|
static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) {
|
||||||
int same = 0;
|
int same = 0;
|
||||||
for (unsigned long i = 0; i < n; i++) {
|
for (unsigned long i = 0; i < n; i++) {
|
||||||
|
|
@ -782,10 +786,13 @@ static unsigned int __builtin_popcountl(unsigned int x) {
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static f32 distance_hamming_u64(u64 *a, u64 *b, size_t n) {
|
static f32 distance_hamming_u64(const u8 *a, const u8 *b, size_t n) {
|
||||||
int same = 0;
|
int same = 0;
|
||||||
for (unsigned long i = 0; i < n; i++) {
|
for (unsigned long i = 0; i < n; i++) {
|
||||||
same += __builtin_popcountl(a[i] ^ b[i]);
|
u64 va, vb;
|
||||||
|
memcpy(&va, a + i * sizeof(u64), sizeof(u64));
|
||||||
|
memcpy(&vb, b + i * sizeof(u64), sizeof(u64));
|
||||||
|
same += __builtin_popcountl(va ^ vb);
|
||||||
}
|
}
|
||||||
return (f32)same;
|
return (f32)same;
|
||||||
}
|
}
|
||||||
|
|
@ -807,9 +814,14 @@ static f32 distance_hamming(const void *a, const void *b, const void *d) {
|
||||||
return distance_hamming_neon((const u8 *)a, (const u8 *)b, n_bytes);
|
return distance_hamming_neon((const u8 *)a, (const u8 *)b, n_bytes);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef SQLITE_VEC_ENABLE_AVX
|
||||||
|
if (n_bytes >= 32) {
|
||||||
|
return distance_hamming_avx2((const u8 *)a, (const u8 *)b, n_bytes);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if ((dimensions % 64) == 0) {
|
if ((dimensions % 64) == 0) {
|
||||||
return distance_hamming_u64((u64 *)a, (u64 *)b, n_bytes / sizeof(u64));
|
return distance_hamming_u64((const u8 *)a, (const u8 *)b, n_bytes / sizeof(u64));
|
||||||
}
|
}
|
||||||
return distance_hamming_u8((u8 *)a, (u8 *)b, n_bytes);
|
return distance_hamming_u8((u8 *)a, (u8 *)b, n_bytes);
|
||||||
}
|
}
|
||||||
|
|
@ -972,8 +984,18 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector,
|
||||||
return SQLITE_NOMEM;
|
return SQLITE_NOMEM;
|
||||||
}
|
}
|
||||||
memcpy(buf, blob, bytes);
|
memcpy(buf, blob, bytes);
|
||||||
|
size_t n = bytes / sizeof(f32);
|
||||||
|
for (size_t i = 0; i < n; i++) {
|
||||||
|
if (isnan(buf[i]) || isinf(buf[i])) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
"invalid float32 vector: element %d is %s",
|
||||||
|
(int)i, isnan(buf[i]) ? "NaN" : "Inf");
|
||||||
|
sqlite3_free(buf);
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
*vector = buf;
|
*vector = buf;
|
||||||
*dimensions = bytes / sizeof(f32);
|
*dimensions = n;
|
||||||
*cleanup = sqlite3_free;
|
*cleanup = sqlite3_free;
|
||||||
return SQLITE_OK;
|
return SQLITE_OK;
|
||||||
}
|
}
|
||||||
|
|
@ -1041,6 +1063,13 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector,
|
||||||
}
|
}
|
||||||
|
|
||||||
f32 res = (f32)result;
|
f32 res = (f32)result;
|
||||||
|
if (isnan(res) || isinf(res)) {
|
||||||
|
sqlite3_free(x.z);
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
"invalid float32 vector: element %d is %s",
|
||||||
|
(int)x.length, isnan(res) ? "NaN" : "Inf");
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
array_append(&x, (const void *)&res);
|
array_append(&x, (const void *)&res);
|
||||||
|
|
||||||
offset += (endptr - ptr);
|
offset += (endptr - ptr);
|
||||||
|
|
@ -2559,7 +2588,8 @@ enum Vec0RescoreQuantizerType {
|
||||||
|
|
||||||
struct Vec0RescoreConfig {
|
struct Vec0RescoreConfig {
|
||||||
enum Vec0RescoreQuantizerType quantizer_type;
|
enum Vec0RescoreQuantizerType quantizer_type;
|
||||||
int oversample;
|
int oversample; // CREATE-time default
|
||||||
|
int oversample_search; // runtime override (0 = use default)
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
@ -2631,7 +2661,8 @@ struct Vec0DiskannConfig {
|
||||||
struct Vec0DiskannCandidate {
|
struct Vec0DiskannCandidate {
|
||||||
i64 rowid;
|
i64 rowid;
|
||||||
f32 distance;
|
f32 distance;
|
||||||
int visited; // 1 if this candidate's neighbors have been explored
|
int visited; // 1 if this candidate's neighbors have been explored
|
||||||
|
int confirmed; // 1 if full-precision vector was successfully read (node exists)
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -3125,6 +3156,9 @@ int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
return SQLITE_ERROR;
|
return SQLITE_ERROR;
|
||||||
}
|
}
|
||||||
|
if (ivfConfig.quantizer == VEC0_IVF_QUANTIZER_BINARY && (dimensions % 8) != 0) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
return SQLITE_ERROR; // IVF not compiled in
|
return SQLITE_ERROR; // IVF not compiled in
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -3366,8 +3400,9 @@ static sqlite3_module vec_eachModule = {
|
||||||
|
|
||||||
#define VEC0_COLUMN_ID 0
|
#define VEC0_COLUMN_ID 0
|
||||||
#define VEC0_COLUMN_USERN_START 1
|
#define VEC0_COLUMN_USERN_START 1
|
||||||
#define VEC0_COLUMN_OFFSET_DISTANCE 1
|
#define VEC0_COLUMN_OFFSET_COMMAND 1
|
||||||
#define VEC0_COLUMN_OFFSET_K 2
|
#define VEC0_COLUMN_OFFSET_DISTANCE 2
|
||||||
|
#define VEC0_COLUMN_OFFSET_K 3
|
||||||
|
|
||||||
#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
|
#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
|
||||||
|
|
||||||
|
|
@ -3465,6 +3500,10 @@ struct vec0_vtab {
|
||||||
// Will change the schema of the _rowids table, and insert/query logic.
|
// Will change the schema of the _rowids table, and insert/query logic.
|
||||||
int pkIsText;
|
int pkIsText;
|
||||||
|
|
||||||
|
// True if the hidden command column (named after the table) exists.
|
||||||
|
// Tables created before v0.1.10 or without _info table don't have it.
|
||||||
|
int hasCommandColumn;
|
||||||
|
|
||||||
// number of defined vector columns.
|
// number of defined vector columns.
|
||||||
int numVectorColumns;
|
int numVectorColumns;
|
||||||
|
|
||||||
|
|
@ -3744,20 +3783,19 @@ int vec0_num_defined_user_columns(vec0_vtab *p) {
|
||||||
* @param p vec0 table
|
* @param p vec0 table
|
||||||
* @return int
|
* @return int
|
||||||
*/
|
*/
|
||||||
int vec0_column_distance_idx(vec0_vtab *p) {
|
int vec0_column_command_idx(vec0_vtab *p) {
|
||||||
return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) +
|
// Command column is the first hidden column (right after user columns)
|
||||||
VEC0_COLUMN_OFFSET_DISTANCE;
|
return VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
int vec0_column_distance_idx(vec0_vtab *p) {
|
||||||
|
int base = VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p);
|
||||||
|
return base + (p->hasCommandColumn ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Returns the index of the k hidden column for the given vec0 table.
|
|
||||||
*
|
|
||||||
* @param p vec0 table
|
|
||||||
* @return int k column index
|
|
||||||
*/
|
|
||||||
int vec0_column_k_idx(vec0_vtab *p) {
|
int vec0_column_k_idx(vec0_vtab *p) {
|
||||||
return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) +
|
int base = VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p);
|
||||||
VEC0_COLUMN_OFFSET_K;
|
return base + (p->hasCommandColumn ? 2 : 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -4676,16 +4714,10 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_value ** partitionKeyValues, i64 *chunk
|
||||||
}
|
}
|
||||||
int vector_column_idx = p->user_column_idxs[i];
|
int vector_column_idx = p->user_column_idxs[i];
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
// Non-FLAT columns (rescore, IVF, DiskANN) don't use _vector_chunks
|
||||||
// Rescore and IVF columns don't use _vector_chunks for float storage
|
if (p->vector_columns[vector_column_idx].index_type != VEC0_INDEX_TYPE_FLAT) {
|
||||||
if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE
|
|
||||||
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
|
||||||
|| p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_IVF
|
|
||||||
#endif
|
|
||||||
) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
i64 vectorsSize =
|
i64 vectorsSize =
|
||||||
p->chunk_size * vector_column_byte_size(p->vector_columns[vector_column_idx]);
|
p->chunk_size * vector_column_byte_size(p->vector_columns[vector_column_idx]);
|
||||||
|
|
@ -5122,11 +5154,6 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (hasRescore) {
|
if (hasRescore) {
|
||||||
if (numAuxiliaryColumns > 0) {
|
|
||||||
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
|
||||||
"Auxiliary columns are not supported with rescore indexes");
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
if (numMetadataColumns > 0) {
|
if (numMetadataColumns > 0) {
|
||||||
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
||||||
"Metadata columns are not supported with rescore indexes");
|
"Metadata columns are not supported with rescore indexes");
|
||||||
|
|
@ -5156,11 +5183,6 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
"partition key columns are not supported with IVF indexes");
|
"partition key columns are not supported with IVF indexes");
|
||||||
goto error;
|
goto error;
|
||||||
}
|
}
|
||||||
if (numAuxiliaryColumns > 0) {
|
|
||||||
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
|
||||||
"auxiliary columns are not supported with IVF indexes");
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
if (numMetadataColumns > 0) {
|
if (numMetadataColumns > 0) {
|
||||||
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
||||||
"metadata columns are not supported with IVF indexes");
|
"metadata columns are not supported with IVF indexes");
|
||||||
|
|
@ -5172,12 +5194,6 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
// DiskANN columns cannot coexist with aux/metadata/partition columns
|
// DiskANN columns cannot coexist with aux/metadata/partition columns
|
||||||
for (int i = 0; i < numVectorColumns; i++) {
|
for (int i = 0; i < numVectorColumns; i++) {
|
||||||
if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
|
if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
|
||||||
if (numAuxiliaryColumns > 0) {
|
|
||||||
*pzErr = sqlite3_mprintf(
|
|
||||||
VEC_CONSTRUCTOR_ERROR
|
|
||||||
"Auxiliary columns are not supported with DiskANN-indexed vector columns");
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
if (numMetadataColumns > 0) {
|
if (numMetadataColumns > 0) {
|
||||||
*pzErr = sqlite3_mprintf(
|
*pzErr = sqlite3_mprintf(
|
||||||
VEC_CONSTRUCTOR_ERROR
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
|
@ -5194,6 +5210,74 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine whether to add the FTS5-style hidden command column.
|
||||||
|
// New tables (isCreate) always get it; existing tables only if created
|
||||||
|
// with v0.1.10+ (which validated no column name == table name).
|
||||||
|
int hasCommandColumn = 0;
|
||||||
|
if (isCreate) {
|
||||||
|
// Validate no user column name conflicts with the table name
|
||||||
|
const char *tblName = argv[2];
|
||||||
|
int tblNameLen = (int)strlen(tblName);
|
||||||
|
for (int i = 0; i < numVectorColumns; i++) {
|
||||||
|
if (pNew->vector_columns[i].name_length == tblNameLen &&
|
||||||
|
sqlite3_strnicmp(pNew->vector_columns[i].name, tblName, tblNameLen) == 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"column name '%s' conflicts with table name (reserved for command column)",
|
||||||
|
tblName);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numPartitionColumns; i++) {
|
||||||
|
if (pNew->paritition_columns[i].name_length == tblNameLen &&
|
||||||
|
sqlite3_strnicmp(pNew->paritition_columns[i].name, tblName, tblNameLen) == 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"column name '%s' conflicts with table name (reserved for command column)",
|
||||||
|
tblName);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numAuxiliaryColumns; i++) {
|
||||||
|
if (pNew->auxiliary_columns[i].name_length == tblNameLen &&
|
||||||
|
sqlite3_strnicmp(pNew->auxiliary_columns[i].name, tblName, tblNameLen) == 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"column name '%s' conflicts with table name (reserved for command column)",
|
||||||
|
tblName);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numMetadataColumns; i++) {
|
||||||
|
if (pNew->metadata_columns[i].name_length == tblNameLen &&
|
||||||
|
sqlite3_strnicmp(pNew->metadata_columns[i].name, tblName, tblNameLen) == 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"column name '%s' conflicts with table name (reserved for command column)",
|
||||||
|
tblName);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hasCommandColumn = 1;
|
||||||
|
} else {
|
||||||
|
// xConnect: check _info shadow table for version
|
||||||
|
sqlite3_stmt *stmtInfo = NULL;
|
||||||
|
char *zInfoSql = sqlite3_mprintf(
|
||||||
|
"SELECT value FROM " VEC0_SHADOW_INFO_NAME " WHERE key = 'CREATE_VERSION_PATCH'",
|
||||||
|
argv[1], argv[2]);
|
||||||
|
if (zInfoSql) {
|
||||||
|
int infoRc = sqlite3_prepare_v2(db, zInfoSql, -1, &stmtInfo, NULL);
|
||||||
|
sqlite3_free(zInfoSql);
|
||||||
|
if (infoRc == SQLITE_OK && sqlite3_step(stmtInfo) == SQLITE_ROW) {
|
||||||
|
int patch = sqlite3_column_int(stmtInfo, 0);
|
||||||
|
hasCommandColumn = (patch >= 10); // v0.1.10+
|
||||||
|
}
|
||||||
|
// If _info doesn't exist or has no version, assume old table
|
||||||
|
sqlite3_finalize(stmtInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pNew->hasCommandColumn = hasCommandColumn;
|
||||||
|
|
||||||
sqlite3_str *createStr = sqlite3_str_new(NULL);
|
sqlite3_str *createStr = sqlite3_str_new(NULL);
|
||||||
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
|
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
|
||||||
if (pkColumnName) {
|
if (pkColumnName) {
|
||||||
|
|
@ -5235,7 +5319,11 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
sqlite3_str_appendall(createStr, " distance hidden, k hidden) ");
|
if (hasCommandColumn) {
|
||||||
|
sqlite3_str_appendf(createStr, " \"%w\" hidden, distance hidden, k hidden) ", argv[2]);
|
||||||
|
} else {
|
||||||
|
sqlite3_str_appendall(createStr, " distance hidden, k hidden) ");
|
||||||
|
}
|
||||||
if (pkColumnName) {
|
if (pkColumnName) {
|
||||||
sqlite3_str_appendall(createStr, "without rowid ");
|
sqlite3_str_appendall(createStr, "without rowid ");
|
||||||
}
|
}
|
||||||
|
|
@ -5469,11 +5557,9 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
sqlite3_finalize(stmt);
|
sqlite3_finalize(stmt);
|
||||||
|
|
||||||
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
// Non-FLAT columns (rescore, IVF, DiskANN) don't use _vector_chunks
|
||||||
// Non-FLAT columns don't use _vector_chunks
|
|
||||||
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
|
||||||
char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE,
|
char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE,
|
||||||
pNew->schemaName, pNew->tableName, i);
|
pNew->schemaName, pNew->tableName, i);
|
||||||
if (!zSql) {
|
if (!zSql) {
|
||||||
|
|
@ -5762,10 +5848,9 @@ static int vec0Destroy(sqlite3_vtab *pVtab) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
// Non-FLAT columns (rescore, IVF, DiskANN) don't use _vector_chunks
|
||||||
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
|
||||||
zSql = sqlite3_mprintf("DROP TABLE \"%w\".\"%w\"", p->schemaName,
|
zSql = sqlite3_mprintf("DROP TABLE \"%w\".\"%w\"", p->schemaName,
|
||||||
p->shadowVectorChunksNames[i]);
|
p->shadowVectorChunksNames[i]);
|
||||||
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
|
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
|
||||||
|
|
@ -8815,15 +8900,9 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid,
|
||||||
|
|
||||||
// Go insert the vector data into the vector chunk shadow tables
|
// Go insert the vector data into the vector chunk shadow tables
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
// Non-FLAT columns (rescore, IVF, DiskANN) don't use _vector_chunks
|
||||||
// Rescore and IVF columns don't use _vector_chunks
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE
|
|
||||||
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
|
||||||
|| p->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF
|
|
||||||
#endif
|
|
||||||
)
|
|
||||||
continue;
|
continue;
|
||||||
#endif
|
|
||||||
|
|
||||||
sqlite3_blob *blobVectors;
|
sqlite3_blob *blobVectors;
|
||||||
rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i],
|
rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i],
|
||||||
|
|
@ -9082,6 +9161,9 @@ int vec0_write_metadata_value(vec0_vtab *p, int metadata_column_idx, i64 rowid,
|
||||||
*
|
*
|
||||||
* @return int SQLITE_OK on success, otherwise error code on failure
|
* @return int SQLITE_OK on success, otherwise error code on failure
|
||||||
*/
|
*/
|
||||||
|
// Forward declaration: needed for INSERT OR REPLACE handling in vec0Update_Insert
|
||||||
|
int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue);
|
||||||
|
|
||||||
int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
||||||
sqlite_int64 *pRowid) {
|
sqlite_int64 *pRowid) {
|
||||||
UNUSED_PARAMETER(argc);
|
UNUSED_PARAMETER(argc);
|
||||||
|
|
@ -9202,6 +9284,44 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
||||||
goto cleanup;
|
goto cleanup;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle INSERT OR REPLACE: if the conflict resolution is REPLACE and the
|
||||||
|
// row already exists, delete the existing row first before inserting.
|
||||||
|
if (sqlite3_vtab_on_conflict(p->db) == SQLITE_REPLACE) {
|
||||||
|
sqlite3_value *idValue = argv[2 + VEC0_COLUMN_ID];
|
||||||
|
int idType = sqlite3_value_type(idValue);
|
||||||
|
int existingRowExists = 0;
|
||||||
|
|
||||||
|
if (p->pkIsText && idType == SQLITE_TEXT) {
|
||||||
|
i64 existingRowid;
|
||||||
|
rc = vec0_rowid_from_id(p, idValue, &existingRowid);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
existingRowExists = 1;
|
||||||
|
} else if (rc == SQLITE_EMPTY) {
|
||||||
|
rc = SQLITE_OK; // row doesn't exist, proceed with normal insert
|
||||||
|
} else {
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
} else if (!p->pkIsText && idType == SQLITE_INTEGER) {
|
||||||
|
i64 existingRowid = sqlite3_value_int64(idValue);
|
||||||
|
i64 chunk_id_tmp, chunk_offset_tmp;
|
||||||
|
rc = vec0_get_chunk_position(p, existingRowid, NULL, &chunk_id_tmp, &chunk_offset_tmp);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
existingRowExists = 1;
|
||||||
|
} else if (rc == SQLITE_EMPTY) {
|
||||||
|
rc = SQLITE_OK; // row doesn't exist, proceed with normal insert
|
||||||
|
} else {
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (existingRowExists) {
|
||||||
|
rc = vec0Update_Delete(pVTab, idValue);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Step #1: Insert/get a rowid for this row, from the _rowids table.
|
// Step #1: Insert/get a rowid for this row, from the _rowids table.
|
||||||
rc = vec0Update_InsertRowidStep(p, argv[2 + VEC0_COLUMN_ID], &rowid);
|
rc = vec0Update_InsertRowidStep(p, argv[2 + VEC0_COLUMN_ID], &rowid);
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
|
|
@ -9449,11 +9569,9 @@ int vec0Update_Delete_ClearVectors(vec0_vtab *p, i64 chunk_id,
|
||||||
u64 chunk_offset) {
|
u64 chunk_offset) {
|
||||||
int rc, brc;
|
int rc, brc;
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
// Non-FLAT columns (rescore, IVF, DiskANN) don't use _vector_chunks
|
||||||
// Non-FLAT columns don't use _vector_chunks
|
|
||||||
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
|
||||||
sqlite3_blob *blobVectors = NULL;
|
sqlite3_blob *blobVectors = NULL;
|
||||||
size_t n = vector_column_byte_size(p->vector_columns[i]);
|
size_t n = vector_column_byte_size(p->vector_columns[i]);
|
||||||
|
|
||||||
|
|
@ -9565,10 +9683,9 @@ int vec0Update_Delete_DeleteChunkIfEmpty(vec0_vtab *p, i64 chunk_id,
|
||||||
|
|
||||||
// Delete from each _vector_chunksNN
|
// Delete from each _vector_chunksNN
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
// Non-FLAT columns (rescore, IVF, DiskANN) don't use _vector_chunks
|
||||||
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
|
||||||
zSql = sqlite3_mprintf(
|
zSql = sqlite3_mprintf(
|
||||||
"DELETE FROM " VEC0_SHADOW_VECTOR_N_NAME " WHERE rowid = ?",
|
"DELETE FROM " VEC0_SHADOW_VECTOR_N_NAME " WHERE rowid = ?",
|
||||||
p->schemaName, p->tableName, i);
|
p->schemaName, p->tableName, i);
|
||||||
|
|
@ -9762,8 +9879,8 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) {
|
||||||
vec0_vtab *p = (vec0_vtab *)pVTab;
|
vec0_vtab *p = (vec0_vtab *)pVTab;
|
||||||
int rc;
|
int rc;
|
||||||
i64 rowid;
|
i64 rowid;
|
||||||
i64 chunk_id;
|
i64 chunk_id = 0;
|
||||||
i64 chunk_offset;
|
i64 chunk_offset = 0;
|
||||||
|
|
||||||
if (p->pkIsText) {
|
if (p->pkIsText) {
|
||||||
rc = vec0_rowid_from_id(p, idValue, &rowid);
|
rc = vec0_rowid_from_id(p, idValue, &rowid);
|
||||||
|
|
@ -9815,16 +9932,15 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) {
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
// 4b. zero out quantized data in rescore chunk tables, delete from rescore vectors
|
// 4b. zero out quantized data in rescore chunk tables, delete from rescore vectors
|
||||||
rc = rescore_on_delete(p, chunk_id, chunk_offset, rowid);
|
rc = rescore_on_delete(p, chunk_id, chunk_offset, rowid);
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// 5. delete from _rowids table
|
// 5. delete from _rowids table
|
||||||
rc = vec0Update_Delete_DeleteRowids(p, rowid);
|
rc = vec0Update_Delete_DeleteRowids(p, rowid);
|
||||||
|
|
@ -10125,6 +10241,26 @@ int vec0Update_Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Block vector UPDATE for index types that don't implement it —
|
||||||
|
// the DiskANN graph / IVF lists would become stale.
|
||||||
|
{
|
||||||
|
enum Vec0IndexType idx_type = p->vector_columns[vector_idx].index_type;
|
||||||
|
const char *idx_name = NULL;
|
||||||
|
if (idx_type == VEC0_INDEX_TYPE_DISKANN) idx_name = "DiskANN";
|
||||||
|
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
||||||
|
else if (idx_type == VEC0_INDEX_TYPE_IVF) idx_name = "IVF";
|
||||||
|
#endif
|
||||||
|
if (idx_name) {
|
||||||
|
vtab_set_error(
|
||||||
|
&p->base,
|
||||||
|
"UPDATE on vector column \"%.*s\" is not supported for %s indexes.",
|
||||||
|
p->vector_columns[vector_idx].name_length,
|
||||||
|
p->vector_columns[vector_idx].name,
|
||||||
|
idx_name);
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
rc = vec0Update_UpdateVectorColumn(p, chunk_id, chunk_offset, vector_idx,
|
rc = vec0Update_UpdateVectorColumn(p, chunk_id, chunk_offset, vector_idx,
|
||||||
valueVector, rowid);
|
valueVector, rowid);
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
|
|
@ -10143,25 +10279,31 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
||||||
}
|
}
|
||||||
// INSERT operation
|
// INSERT operation
|
||||||
else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) {
|
else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) {
|
||||||
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE || SQLITE_VEC_ENABLE_DISKANN
|
vec0_vtab *p = (vec0_vtab *)pVTab;
|
||||||
// Check for command inserts: INSERT INTO t(rowid) VALUES ('command-string')
|
// FTS5-style command dispatch via hidden column named after table
|
||||||
// The id column holds the command string.
|
if (p->hasCommandColumn) {
|
||||||
sqlite3_value *idVal = argv[2 + VEC0_COLUMN_ID];
|
sqlite3_value *cmdVal = argv[2 + vec0_column_command_idx(p)];
|
||||||
if (sqlite3_value_type(idVal) == SQLITE_TEXT) {
|
if (sqlite3_value_type(cmdVal) == SQLITE_TEXT) {
|
||||||
const char *cmd = (const char *)sqlite3_value_text(idVal);
|
const char *cmd = (const char *)sqlite3_value_text(cmdVal);
|
||||||
vec0_vtab *p = (vec0_vtab *)pVTab;
|
int cmdRc = SQLITE_EMPTY;
|
||||||
int cmdRc = SQLITE_EMPTY;
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
|
cmdRc = rescore_handle_command(p, cmd);
|
||||||
|
#endif
|
||||||
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
||||||
cmdRc = ivf_handle_command(p, cmd, argc, argv);
|
if (cmdRc == SQLITE_EMPTY)
|
||||||
|
cmdRc = ivf_handle_command(p, cmd, argc, argv);
|
||||||
#endif
|
#endif
|
||||||
#if SQLITE_VEC_ENABLE_DISKANN
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
if (cmdRc == SQLITE_EMPTY)
|
if (cmdRc == SQLITE_EMPTY)
|
||||||
cmdRc = diskann_handle_command(p, cmd);
|
cmdRc = diskann_handle_command(p, cmd);
|
||||||
#endif
|
#endif
|
||||||
if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error)
|
if (cmdRc == SQLITE_EMPTY) {
|
||||||
// SQLITE_EMPTY means not a recognized command — fall through to normal insert
|
vtab_set_error(pVTab, "unknown vec0 command: '%s'", cmd);
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
return cmdRc;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
return vec0Update_Insert(pVTab, argc, argv, pRowid);
|
return vec0Update_Insert(pVTab, argc, argv, pRowid);
|
||||||
}
|
}
|
||||||
// UPDATE operation
|
// UPDATE operation
|
||||||
|
|
@ -10261,6 +10403,163 @@ static int vec0Rollback(sqlite3_vtab *pVTab) {
|
||||||
return SQLITE_OK;
|
return SQLITE_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* xRename implementation for vec0.
|
||||||
|
* Renames all shadow tables to match the new virtual table name,
|
||||||
|
* then updates cached table names and finalizes stale prepared statements.
|
||||||
|
*/
|
||||||
|
static int vec0Rename(sqlite3_vtab *pVtab, const char *zNew) {
|
||||||
|
vec0_vtab *p = (vec0_vtab *)pVtab;
|
||||||
|
int rc = SQLITE_OK;
|
||||||
|
|
||||||
|
// Build a single SQL string with ALTER TABLE RENAME for every shadow table.
|
||||||
|
sqlite3_str *s = sqlite3_str_new(p->db);
|
||||||
|
|
||||||
|
// Core shadow tables (always present)
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_info\" RENAME TO \"%w_info\";",
|
||||||
|
p->schemaName, p->tableName, zNew);
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_rowids\" RENAME TO \"%w_rowids\";",
|
||||||
|
p->schemaName, p->tableName, zNew);
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_chunks\" RENAME TO \"%w_chunks\";",
|
||||||
|
p->schemaName, p->tableName, zNew);
|
||||||
|
|
||||||
|
// Auxiliary shadow table (only if auxiliary columns exist)
|
||||||
|
if (p->numAuxiliaryColumns > 0) {
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_auxiliary\" RENAME TO \"%w_auxiliary\";",
|
||||||
|
p->schemaName, p->tableName, zNew);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-vector-column shadow tables
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_vector_chunks%02d\" RENAME TO \"%w_vector_chunks%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
|
if (p->shadowRescoreChunksNames[i]) {
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_rescore_chunks%02d\" RENAME TO \"%w_rescore_chunks%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_rescore_vectors%02d\" RENAME TO \"%w_rescore_vectors%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
if (p->shadowVectorsNames[i]) {
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_vectors%02d\" RENAME TO \"%w_vectors%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_diskann_nodes%02d\" RENAME TO \"%w_diskann_nodes%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_diskann_buffer%02d\" RENAME TO \"%w_diskann_buffer%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->shadowIvfCellsNames[i]) {
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_ivf_cells%02d\" RENAME TO \"%w_ivf_cells%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Per-metadata-column shadow tables
|
||||||
|
for (int i = 0; i < p->numMetadataColumns; i++) {
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_metadatachunks%02d\" RENAME TO \"%w_metadatachunks%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
if (p->metadata_columns[i].kind == VEC0_METADATA_COLUMN_KIND_TEXT) {
|
||||||
|
sqlite3_str_appendf(s,
|
||||||
|
"ALTER TABLE \"%w\".\"%w_metadatatext%02d\" RENAME TO \"%w_metadatatext%02d\";",
|
||||||
|
p->schemaName, p->tableName, i, zNew, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
char *zSql = sqlite3_str_finish(s);
|
||||||
|
if (!zSql) {
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
}
|
||||||
|
|
||||||
|
rc = sqlite3_exec(p->db, zSql, 0, 0, 0);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize all prepared statements — they reference old table names.
|
||||||
|
vec0_free_resources(p);
|
||||||
|
|
||||||
|
// Update cached table name
|
||||||
|
sqlite3_free(p->tableName);
|
||||||
|
p->tableName = sqlite3_mprintf("%s", zNew);
|
||||||
|
if (!p->tableName) return SQLITE_NOMEM;
|
||||||
|
|
||||||
|
// Update cached shadow table names
|
||||||
|
sqlite3_free(p->shadowRowidsName);
|
||||||
|
p->shadowRowidsName = sqlite3_mprintf("%s_rowids", zNew);
|
||||||
|
|
||||||
|
sqlite3_free(p->shadowChunksName);
|
||||||
|
p->shadowChunksName = sqlite3_mprintf("%s_chunks", zNew);
|
||||||
|
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
sqlite3_free(p->shadowVectorChunksNames[i]);
|
||||||
|
p->shadowVectorChunksNames[i] =
|
||||||
|
sqlite3_mprintf("%s_vector_chunks%02d", zNew, i);
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
|
if (p->shadowRescoreChunksNames[i]) {
|
||||||
|
sqlite3_free(p->shadowRescoreChunksNames[i]);
|
||||||
|
p->shadowRescoreChunksNames[i] =
|
||||||
|
sqlite3_mprintf("%s_rescore_chunks%02d", zNew, i);
|
||||||
|
sqlite3_free(p->shadowRescoreVectorsNames[i]);
|
||||||
|
p->shadowRescoreVectorsNames[i] =
|
||||||
|
sqlite3_mprintf("%s_rescore_vectors%02d", zNew, i);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
|
if (p->shadowVectorsNames[i]) {
|
||||||
|
sqlite3_free(p->shadowVectorsNames[i]);
|
||||||
|
p->shadowVectorsNames[i] =
|
||||||
|
sqlite3_mprintf("%s_vectors%02d", zNew, i);
|
||||||
|
sqlite3_free(p->shadowDiskannNodesNames[i]);
|
||||||
|
p->shadowDiskannNodesNames[i] =
|
||||||
|
sqlite3_mprintf("%s_diskann_nodes%02d", zNew, i);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->shadowIvfCellsNames[i]) {
|
||||||
|
sqlite3_free(p->shadowIvfCellsNames[i]);
|
||||||
|
p->shadowIvfCellsNames[i] =
|
||||||
|
sqlite3_mprintf("%s_ivf_cells%02d", zNew, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
for (int i = 0; i < p->numMetadataColumns; i++) {
|
||||||
|
sqlite3_free(p->shadowMetadataChunksNames[i]);
|
||||||
|
p->shadowMetadataChunksNames[i] =
|
||||||
|
sqlite3_mprintf("%s_metadatachunks%02d", zNew, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
|
||||||
static sqlite3_module vec0Module = {
|
static sqlite3_module vec0Module = {
|
||||||
/* iVersion */ 3,
|
/* iVersion */ 3,
|
||||||
/* xCreate */ vec0Create,
|
/* xCreate */ vec0Create,
|
||||||
|
|
@ -10281,7 +10580,7 @@ static sqlite3_module vec0Module = {
|
||||||
/* xCommit */ vec0Commit,
|
/* xCommit */ vec0Commit,
|
||||||
/* xRollback */ vec0Rollback,
|
/* xRollback */ vec0Rollback,
|
||||||
/* xFindFunction */ 0,
|
/* xFindFunction */ 0,
|
||||||
/* xRename */ 0, // https://github.com/asg017/sqlite-vec/issues/43
|
/* xRename */ vec0Rename,
|
||||||
/* xSavepoint */ 0,
|
/* xSavepoint */ 0,
|
||||||
/* xRelease */ 0,
|
/* xRelease */ 0,
|
||||||
/* xRollbackTo */ 0,
|
/* xRollbackTo */ 0,
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
BIN
tests/fixtures/legacy-v0.1.6.db
vendored
Normal file
BIN
tests/fixtures/legacy-v0.1.6.db
vendored
Normal file
Binary file not shown.
|
|
@ -50,7 +50,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
{
|
{
|
||||||
sqlite3_stmt *stmt;
|
sqlite3_stmt *stmt;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
for (int i = 1; i <= 8; i++) {
|
for (int i = 1; i <= 8; i++) {
|
||||||
float vec[8];
|
float vec[8];
|
||||||
for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||||
|
|
@ -66,11 +66,11 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_stmt *stmtKnn = NULL;
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
|
||||||
/* Commands are dispatched via INSERT INTO t(rowid) VALUES ('cmd_string') */
|
/* Commands are dispatched via INSERT INTO t(t) VALUES ('cmd_string') */
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid) VALUES (?)", -1, &stmtCmd, NULL);
|
"INSERT INTO v(v) VALUES (?)", -1, &stmtCmd, NULL);
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
-1, &stmtKnn, NULL);
|
-1, &stmtKnn, NULL);
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert enough vectors to overflow at least one cell
|
// Insert enough vectors to overflow at least one cell
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -81,7 +81,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Train to assign vectors to centroids (triggers cell building)
|
// Train to assign vectors to centroids (triggers cell building)
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Delete vectors at boundary positions based on fuzz data
|
// Delete vectors at boundary positions based on fuzz data
|
||||||
|
|
@ -102,7 +102,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
{
|
{
|
||||||
sqlite3_stmt *si = NULL;
|
sqlite3_stmt *si = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
if (si) {
|
if (si) {
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
|
@ -140,7 +140,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Test assign-vectors with multi-cell state
|
// Test assign-vectors with multi-cell state
|
||||||
// First clear centroids
|
// First clear centroids
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Set centroids manually, then assign
|
// Set centroids manually, then assign
|
||||||
|
|
@ -151,7 +151,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
char cmd[128];
|
char cmd[128];
|
||||||
snprintf(cmd, sizeof(cmd),
|
snprintf(cmd, sizeof(cmd),
|
||||||
"INSERT INTO v(rowid, emb) VALUES ('set-centroid:%d', ?)", c);
|
"INSERT INTO v(v, emb) VALUES ('set-centroid:%d', ?)", c);
|
||||||
sqlite3_stmt *sc = NULL;
|
sqlite3_stmt *sc = NULL;
|
||||||
sqlite3_prepare_v2(db, cmd, -1, &sc, NULL);
|
sqlite3_prepare_v2(db, cmd, -1, &sc, NULL);
|
||||||
if (sc) {
|
if (sc) {
|
||||||
|
|
@ -163,7 +163,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('assign-vectors')",
|
"INSERT INTO v(v) VALUES ('assign-vectors')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Final query after assign-vectors
|
// Final query after assign-vectors
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert vectors
|
// Insert vectors
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -125,14 +125,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Clear centroids and re-compute to test round-trip
|
// Clear centroids and re-compute to test round-trip
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Insert a few more vectors in untrained state
|
// Insert a few more vectors in untrained state
|
||||||
{
|
{
|
||||||
sqlite3_stmt *si = NULL;
|
sqlite3_stmt *si = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
if (si) {
|
if (si) {
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
|
@ -150,7 +150,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Re-train
|
// Re-train
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Delete some rows after training, then query
|
// Delete some rows after training, then query
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert vectors
|
// Insert vectors
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -134,14 +134,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Train
|
// Train
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Change nprobe at runtime (can exceed nlist -- tests clamping in query)
|
// Change nprobe at runtime (can exceed nlist -- tests clamping in query)
|
||||||
{
|
{
|
||||||
char cmd[64];
|
char cmd[64];
|
||||||
snprintf(cmd, sizeof(cmd),
|
snprintf(cmd, sizeof(cmd),
|
||||||
"INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe_initial);
|
"INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe_initial);
|
||||||
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
|
|
@ -82,14 +82,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
case 4: {
|
case 4: {
|
||||||
// compute-centroids command
|
// compute-centroids command
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case 5: {
|
case 5: {
|
||||||
// clear-centroids command
|
// clear-centroids command
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -100,7 +100,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
int nprobe = (n % 4) + 1;
|
int nprobe = (n % 4) + 1;
|
||||||
char buf[64];
|
char buf[64];
|
||||||
snprintf(buf, sizeof(buf),
|
snprintf(buf, sizeof(buf),
|
||||||
"INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe);
|
"INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe);
|
||||||
sqlite3_exec(db, buf, NULL, NULL, NULL);
|
sqlite3_exec(db, buf, NULL, NULL, NULL);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert vectors with fuzz-controlled float values
|
// Insert vectors with fuzz-controlled float values
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -93,7 +93,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Trigger compute-centroids to exercise kmeans + quantization together
|
// Trigger compute-centroids to exercise kmeans + quantization together
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// KNN query with fuzz-derived query vector
|
// KNN query with fuzz-derived query vector
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert vectors with diverse values
|
// Insert vectors with diverse values
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -103,7 +103,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Train
|
// Train
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Multiple KNN queries to exercise rescore path
|
// Multiple KNN queries to exercise rescore path
|
||||||
|
|
@ -156,7 +156,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Retrain after deletions
|
// Retrain after deletions
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Query after retrain
|
// Query after retrain
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
{
|
{
|
||||||
sqlite3_stmt *si = NULL;
|
sqlite3_stmt *si = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
if (!si) { sqlite3_close(db); return 0; }
|
if (!si) { sqlite3_close(db); return 0; }
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
float vec[8];
|
float vec[8];
|
||||||
|
|
@ -63,7 +63,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Train
|
// Train
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Now corrupt shadow tables based on fuzz input
|
// Now corrupt shadow tables based on fuzz input
|
||||||
|
|
@ -204,7 +204,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
||||||
sqlite3_stmt *si = NULL;
|
sqlite3_stmt *si = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
if (si) {
|
if (si) {
|
||||||
sqlite3_bind_int64(si, 1, 100);
|
sqlite3_bind_int64(si, 1, 100);
|
||||||
sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC);
|
sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC);
|
||||||
|
|
@ -215,12 +215,12 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// compute-centroids over corrupted state
|
// compute-centroids over corrupted state
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// clear-centroids
|
// clear-centroids
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
sqlite3_close(db);
|
sqlite3_close(db);
|
||||||
|
|
|
||||||
81
tests/generate_legacy_db.py
Normal file
81
tests/generate_legacy_db.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.10"
|
||||||
|
# dependencies = ["sqlite-vec==0.1.6"]
|
||||||
|
# ///
|
||||||
|
"""Generate a legacy sqlite-vec database for backwards-compat testing.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run --script generate_legacy_db.py
|
||||||
|
|
||||||
|
Creates tests/fixtures/legacy-v0.1.6.db with a vec0 table containing
|
||||||
|
test data that can be read by the current version of sqlite-vec.
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
import sqlite_vec
|
||||||
|
import struct
|
||||||
|
import os
|
||||||
|
|
||||||
|
FIXTURE_DIR = os.path.join(os.path.dirname(__file__), "fixtures")
|
||||||
|
DB_PATH = os.path.join(FIXTURE_DIR, "legacy-v0.1.6.db")
|
||||||
|
|
||||||
|
DIMS = 4
|
||||||
|
N_ROWS = 50
|
||||||
|
|
||||||
|
|
||||||
|
def _f32(vals):
|
||||||
|
return struct.pack(f"{len(vals)}f", *vals)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
os.makedirs(FIXTURE_DIR, exist_ok=True)
|
||||||
|
if os.path.exists(DB_PATH):
|
||||||
|
os.remove(DB_PATH)
|
||||||
|
|
||||||
|
db = sqlite3.connect(DB_PATH)
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
sqlite_vec.load(db)
|
||||||
|
|
||||||
|
# Print version for verification
|
||||||
|
version = db.execute("SELECT vec_version()").fetchone()[0]
|
||||||
|
print(f"sqlite-vec version: {version}")
|
||||||
|
|
||||||
|
# Create a basic vec0 table — flat index, no fancy features
|
||||||
|
db.execute(f"CREATE VIRTUAL TABLE legacy_vectors USING vec0(emb float[{DIMS}])")
|
||||||
|
|
||||||
|
# Insert test data: vectors where element[0] == rowid for easy verification
|
||||||
|
for i in range(1, N_ROWS + 1):
|
||||||
|
vec = [float(i), 0.0, 0.0, 0.0]
|
||||||
|
db.execute("INSERT INTO legacy_vectors(rowid, emb) VALUES (?, ?)", [i, _f32(vec)])
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
count = db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
print(f"Inserted {count} rows")
|
||||||
|
|
||||||
|
# Test KNN works
|
||||||
|
query = _f32([1.0, 0.0, 0.0, 0.0])
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid, distance FROM legacy_vectors WHERE emb MATCH ? AND k = 5",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
print(f"KNN top 5: {[(r[0], round(r[1], 4)) for r in rows]}")
|
||||||
|
assert rows[0][0] == 1 # closest to [1,0,0,0]
|
||||||
|
assert len(rows) == 5
|
||||||
|
|
||||||
|
# Also create a table with name == column name (the conflict case)
|
||||||
|
# This was allowed in old versions — new code must not break on reconnect
|
||||||
|
db.execute("CREATE VIRTUAL TABLE emb USING vec0(emb float[4])")
|
||||||
|
for i in range(1, 11):
|
||||||
|
db.execute("INSERT INTO emb(rowid, emb) VALUES (?, ?)", [i, _f32([float(i), 0, 0, 0])])
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
count2 = db.execute("SELECT count(*) FROM emb").fetchone()[0]
|
||||||
|
print(f"Table 'emb' with column 'emb': {count2} rows (name conflict case)")
|
||||||
|
|
||||||
|
db.close()
|
||||||
|
print(f"\nGenerated: {DB_PATH}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from helpers import exec, vec0_shadow_table_contents
|
import struct
|
||||||
|
import pytest
|
||||||
|
from helpers import exec, vec0_shadow_table_contents, _f32
|
||||||
|
|
||||||
|
|
||||||
def test_constructor_limit(db, snapshot):
|
def test_constructor_limit(db, snapshot):
|
||||||
|
|
@ -126,3 +128,198 @@ def test_knn(db, snapshot):
|
||||||
) == snapshot(name="illegal KNN w/ aux")
|
) == snapshot(name="illegal KNN w/ aux")
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# Auxiliary columns with non-flat indexes
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_aux_shadow_tables(db, snapshot):
|
||||||
|
"""Rescore + aux column: verify shadow tables are created correctly."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" emb float[128] indexed by rescore(quantizer=bit),"
|
||||||
|
" +label text,"
|
||||||
|
" +score float"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
assert exec(db, "SELECT name, sql FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name") == snapshot(
|
||||||
|
name="rescore aux shadow tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_aux_insert_knn(db, snapshot):
|
||||||
|
"""Insert with aux data, KNN should return aux column values."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" emb float[128] indexed by rescore(quantizer=bit),"
|
||||||
|
" +label text"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
import random
|
||||||
|
random.seed(77)
|
||||||
|
data = [
|
||||||
|
("alpha", [random.gauss(0, 1) for _ in range(128)]),
|
||||||
|
("beta", [random.gauss(0, 1) for _ in range(128)]),
|
||||||
|
("gamma", [random.gauss(0, 1) for _ in range(128)]),
|
||||||
|
]
|
||||||
|
for label, vec in data:
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(emb, label) VALUES (?, ?)",
|
||||||
|
[_f32(vec), label],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exec(db, "SELECT rowid, label FROM t ORDER BY rowid") == snapshot(
|
||||||
|
name="rescore aux select all"
|
||||||
|
)
|
||||||
|
assert vec0_shadow_table_contents(db, "t", skip_info=True) == snapshot(
|
||||||
|
name="rescore aux shadow contents"
|
||||||
|
)
|
||||||
|
|
||||||
|
# KNN should include aux column, "alpha" closest to its own vector
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT label, distance FROM t WHERE emb MATCH ? ORDER BY distance LIMIT 3",
|
||||||
|
[_f32(data[0][1])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 3
|
||||||
|
assert rows[0][0] == "alpha"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_aux_update(db):
|
||||||
|
"""UPDATE aux column on rescore table should work without affecting vectors."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" emb float[128] indexed by rescore(quantizer=bit),"
|
||||||
|
" +label text"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
import random
|
||||||
|
random.seed(88)
|
||||||
|
vec = [random.gauss(0, 1) for _ in range(128)]
|
||||||
|
db.execute("INSERT INTO t(rowid, emb, label) VALUES (1, ?, 'original')", [_f32(vec)])
|
||||||
|
db.execute("UPDATE t SET label = 'updated' WHERE rowid = 1")
|
||||||
|
|
||||||
|
assert db.execute("SELECT label FROM t WHERE rowid = 1").fetchone()[0] == "updated"
|
||||||
|
|
||||||
|
# KNN still works with updated aux
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid, label FROM t WHERE emb MATCH ? ORDER BY distance LIMIT 1",
|
||||||
|
[_f32(vec)],
|
||||||
|
).fetchall()
|
||||||
|
assert rows[0][0] == 1
|
||||||
|
assert rows[0][1] == "updated"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_aux_delete(db, snapshot):
|
||||||
|
"""DELETE should remove aux data from shadow table."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" emb float[128] indexed by rescore(quantizer=bit),"
|
||||||
|
" +label text"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
import random
|
||||||
|
random.seed(99)
|
||||||
|
for i in range(5):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, emb, label) VALUES (?, ?, ?)",
|
||||||
|
[i + 1, _f32([random.gauss(0, 1) for _ in range(128)]), f"item-{i+1}"],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 3")
|
||||||
|
|
||||||
|
assert exec(db, "SELECT rowid, label FROM t ORDER BY rowid") == snapshot(
|
||||||
|
name="rescore aux after delete"
|
||||||
|
)
|
||||||
|
assert exec(db, "SELECT rowid, value00 FROM t_auxiliary ORDER BY rowid") == snapshot(
|
||||||
|
name="rescore aux shadow after delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_aux_shadow_tables(db, snapshot):
|
||||||
|
"""DiskANN + aux column: verify shadow tables are created correctly."""
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
|
||||||
|
+label text,
|
||||||
|
+score float
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
assert exec(db, "SELECT name, sql FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name") == snapshot(
|
||||||
|
name="diskann aux shadow tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_aux_insert_knn(db, snapshot):
|
||||||
|
"""DiskANN + aux: insert, KNN, verify aux values returned."""
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
|
||||||
|
+label text
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
data = [
|
||||||
|
("red", [1, 0, 0, 0, 0, 0, 0, 0]),
|
||||||
|
("green", [0, 1, 0, 0, 0, 0, 0, 0]),
|
||||||
|
("blue", [0, 0, 1, 0, 0, 0, 0, 0]),
|
||||||
|
]
|
||||||
|
for label, vec in data:
|
||||||
|
db.execute("INSERT INTO t(emb, label) VALUES (?, ?)", [_f32(vec), label])
|
||||||
|
|
||||||
|
assert exec(db, "SELECT rowid, label FROM t ORDER BY rowid") == snapshot(
|
||||||
|
name="diskann aux select all"
|
||||||
|
)
|
||||||
|
assert vec0_shadow_table_contents(db, "t", skip_info=True) == snapshot(
|
||||||
|
name="diskann aux shadow contents"
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT label, distance FROM t WHERE emb MATCH ? AND k = 3",
|
||||||
|
[_f32([1, 0, 0, 0, 0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) >= 1
|
||||||
|
assert rows[0][0] == "red"
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_aux_update_and_delete(db, snapshot):
|
||||||
|
"""DiskANN + aux: update aux column, delete row, verify cleanup."""
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
|
||||||
|
+label text
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
for i in range(5):
|
||||||
|
vec = [0.0] * 8
|
||||||
|
vec[i % 8] = 1.0
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, emb, label) VALUES (?, ?, ?)",
|
||||||
|
[i + 1, _f32(vec), f"item-{i+1}"],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("UPDATE t SET label = 'UPDATED' WHERE rowid = 2")
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 3")
|
||||||
|
|
||||||
|
assert exec(db, "SELECT rowid, label FROM t ORDER BY rowid") == snapshot(
|
||||||
|
name="diskann aux after update+delete"
|
||||||
|
)
|
||||||
|
assert exec(db, "SELECT rowid, value00 FROM t_auxiliary ORDER BY rowid") == snapshot(
|
||||||
|
name="diskann aux shadow after update+delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_aux_drop_cleans_all(db):
|
||||||
|
"""DROP TABLE should remove aux shadow table too."""
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
|
||||||
|
+label text
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
db.execute("INSERT INTO t(emb, label) VALUES (?, 'test')", [_f32([1]*8)])
|
||||||
|
db.execute("DROP TABLE t")
|
||||||
|
|
||||||
|
tables = [r[0] for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE name LIKE 't_%'"
|
||||||
|
).fetchall()]
|
||||||
|
assert "t_auxiliary" not in tables
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -589,7 +589,7 @@ def test_diskann_command_search_list_size(db):
|
||||||
assert len(results_before) == 5
|
assert len(results_before) == 5
|
||||||
|
|
||||||
# Override search_list_size_search at runtime
|
# Override search_list_size_search at runtime
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_search=256')")
|
db.execute("INSERT INTO t(t) VALUES ('search_list_size_search=256')")
|
||||||
|
|
||||||
# Query should still work
|
# Query should still work
|
||||||
results_after = db.execute(
|
results_after = db.execute(
|
||||||
|
|
@ -598,14 +598,14 @@ def test_diskann_command_search_list_size(db):
|
||||||
assert len(results_after) == 5
|
assert len(results_after) == 5
|
||||||
|
|
||||||
# Override search_list_size_insert at runtime
|
# Override search_list_size_insert at runtime
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_insert=32')")
|
db.execute("INSERT INTO t(t) VALUES ('search_list_size_insert=32')")
|
||||||
|
|
||||||
# Inserts should still work
|
# Inserts should still work
|
||||||
vec = struct.pack("64f", *[random.random() for _ in range(64)])
|
vec = struct.pack("64f", *[random.random() for _ in range(64)])
|
||||||
db.execute("INSERT INTO t(emb) VALUES (?)", [vec])
|
db.execute("INSERT INTO t(emb) VALUES (?)", [vec])
|
||||||
|
|
||||||
# Override unified search_list_size
|
# Override unified search_list_size
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('search_list_size=64')")
|
db.execute("INSERT INTO t(t) VALUES ('search_list_size=64')")
|
||||||
|
|
||||||
results_final = db.execute(
|
results_final = db.execute(
|
||||||
"SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query]
|
"SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query]
|
||||||
|
|
@ -620,9 +620,9 @@ def test_diskann_command_search_list_size_error(db):
|
||||||
emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
|
emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=0')")
|
result = exec(db, "INSERT INTO t(t) VALUES ('search_list_size=0')")
|
||||||
assert "error" in result
|
assert "error" in result
|
||||||
result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=-1')")
|
result = exec(db, "INSERT INTO t(t) VALUES ('search_list_size=-1')")
|
||||||
assert "error" in result
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -630,16 +630,19 @@ def test_diskann_command_search_list_size_error(db):
|
||||||
# Error cases: DiskANN + auxiliary/metadata/partition columns
|
# Error cases: DiskANN + auxiliary/metadata/partition columns
|
||||||
# ======================================================================
|
# ======================================================================
|
||||||
|
|
||||||
def test_diskann_create_error_with_auxiliary_column(db):
|
def test_diskann_create_with_auxiliary_column(db):
|
||||||
"""DiskANN tables should not support auxiliary columns."""
|
"""DiskANN tables should support auxiliary columns."""
|
||||||
result = exec(db, """
|
db.execute("""
|
||||||
CREATE VIRTUAL TABLE t USING vec0(
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
emb float[64] INDEXED BY diskann(neighbor_quantizer=binary),
|
emb float[64] INDEXED BY diskann(neighbor_quantizer=binary),
|
||||||
+extra text
|
+extra text
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
assert "error" in result
|
# Auxiliary shadow table should exist
|
||||||
assert "auxiliary" in result["message"].lower() or "Auxiliary" in result["message"]
|
tables = [r[0] for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE name LIKE 't_%' ORDER BY 1"
|
||||||
|
).fetchall()]
|
||||||
|
assert "t_auxiliary" in tables
|
||||||
|
|
||||||
|
|
||||||
def test_diskann_create_error_with_metadata_column(db):
|
def test_diskann_create_error_with_metadata_column(db):
|
||||||
|
|
@ -891,7 +894,7 @@ def test_diskann_delete_preserves_graph_connectivity(db):
|
||||||
# ======================================================================
|
# ======================================================================
|
||||||
|
|
||||||
def test_diskann_update_vector(db):
|
def test_diskann_update_vector(db):
|
||||||
"""UPDATE a vector on DiskANN table may not be supported; verify it either works or errors cleanly."""
|
"""UPDATE a vector on DiskANN table should error (will be implemented soon)."""
|
||||||
db.execute("""
|
db.execute("""
|
||||||
CREATE VIRTUAL TABLE t USING vec0(
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
|
||||||
|
|
@ -901,17 +904,8 @@ def test_diskann_update_vector(db):
|
||||||
db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])])
|
db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])])
|
||||||
db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0, 0, 1, 0, 0, 0, 0, 0])])
|
db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0, 0, 1, 0, 0, 0, 0, 0])])
|
||||||
|
|
||||||
# UPDATE may not be fully supported for DiskANN yet; verify no crash
|
with pytest.raises(sqlite3.OperationalError, match="UPDATE on vector column.*not supported for DiskANN"):
|
||||||
result = exec(db, "UPDATE t SET emb = ? WHERE rowid = 1", [_f32([0, 0.9, 0.1, 0, 0, 0, 0, 0])])
|
db.execute("UPDATE t SET emb = ? WHERE rowid = 1", [_f32([0, 0.9, 0.1, 0, 0, 0, 0, 0])])
|
||||||
if "error" not in result:
|
|
||||||
# If UPDATE succeeded, verify KNN reflects the new value
|
|
||||||
rows = db.execute(
|
|
||||||
"SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=3",
|
|
||||||
[_f32([0, 1, 0, 0, 0, 0, 0, 0])],
|
|
||||||
).fetchall()
|
|
||||||
assert len(rows) == 3
|
|
||||||
# rowid 2 should still be closest (exact match)
|
|
||||||
assert rows[0][0] == 2
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
# ======================================================================
|
||||||
|
|
@ -1158,3 +1152,180 @@ def test_diskann_large_batch_insert_500(db):
|
||||||
distances = [r[1] for r in rows]
|
distances = [r[1] for r in rows]
|
||||||
for i in range(len(distances) - 1):
|
for i in range(len(distances) - 1):
|
||||||
assert distances[i] <= distances[i + 1]
|
assert distances[i] <= distances[i + 1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_corrupt_truncated_node_blob(db):
|
||||||
|
"""KNN should error (not crash) when DiskANN node blob is truncated."""
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
for i in range(5):
|
||||||
|
vec = [0.0] * 8
|
||||||
|
vec[i % 8] = 1.0
|
||||||
|
db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i + 1, _f32(vec)])
|
||||||
|
|
||||||
|
# Corrupt a DiskANN node: truncate neighbor_ids to 1 byte (wrong size)
|
||||||
|
db.execute(
|
||||||
|
"UPDATE t_diskann_nodes00 SET neighbor_ids = x'00' WHERE rowid = 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not crash — may return wrong results or error
|
||||||
|
try:
|
||||||
|
db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE emb MATCH ? AND k=3",
|
||||||
|
[_f32([1, 0, 0, 0, 0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
pass # Error is acceptable — crash is not
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_delete_reinsert_cycle_knn(db):
|
||||||
|
"""Repeatedly delete and reinsert rows, verify KNN stays correct."""
|
||||||
|
import random
|
||||||
|
random.seed(101)
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
N = 30
|
||||||
|
vecs = {}
|
||||||
|
for i in range(1, N + 1):
|
||||||
|
v = [random.gauss(0, 1) for _ in range(8)]
|
||||||
|
vecs[i] = v
|
||||||
|
db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(v)])
|
||||||
|
|
||||||
|
# 3 cycles: delete half, reinsert with new vectors, verify KNN
|
||||||
|
for cycle in range(3):
|
||||||
|
to_delete = random.sample(sorted(vecs.keys()), len(vecs) // 2)
|
||||||
|
for r in to_delete:
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [r])
|
||||||
|
del vecs[r]
|
||||||
|
|
||||||
|
# Reinsert with new rowids
|
||||||
|
new_start = 100 + cycle * 50
|
||||||
|
for i in range(len(to_delete)):
|
||||||
|
rid = new_start + i
|
||||||
|
v = [random.gauss(0, 1) for _ in range(8)]
|
||||||
|
vecs[rid] = v
|
||||||
|
db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rid, _f32(v)])
|
||||||
|
|
||||||
|
# KNN should return only alive rows
|
||||||
|
query = [0.0] * 8
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE emb MATCH ? AND k=10",
|
||||||
|
[_f32(query)],
|
||||||
|
).fetchall()
|
||||||
|
returned = {r["rowid"] for r in rows}
|
||||||
|
assert returned.issubset(set(vecs.keys())), \
|
||||||
|
f"Cycle {cycle}: deleted rowid in KNN results"
|
||||||
|
assert len(rows) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_delete_interleaved_with_knn(db):
|
||||||
|
"""Delete one row at a time, querying KNN after each delete."""
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
N = 20
|
||||||
|
for i in range(1, N + 1):
|
||||||
|
vec = [0.0] * 8
|
||||||
|
vec[i % 8] = float(i)
|
||||||
|
db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)])
|
||||||
|
|
||||||
|
alive = set(range(1, N + 1))
|
||||||
|
for to_del in [1, 5, 10, 15, 20]:
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [to_del])
|
||||||
|
alive.discard(to_del)
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE emb MATCH ? AND k=5",
|
||||||
|
[_f32([1, 0, 0, 0, 0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
returned = {r["rowid"] for r in rows}
|
||||||
|
assert returned.issubset(alive), \
|
||||||
|
f"Deleted rowid {to_del} found in KNN results"
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# Text primary key + DiskANN
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_text_pk_insert_knn_delete(db):
|
||||||
|
"""DiskANN with text primary key: insert, KNN, delete, KNN again."""
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
id text primary key,
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
vecs = {
|
||||||
|
"alpha": [1, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
"beta": [0, 1, 0, 0, 0, 0, 0, 0],
|
||||||
|
"gamma": [0, 0, 1, 0, 0, 0, 0, 0],
|
||||||
|
"delta": [0, 0, 0, 1, 0, 0, 0, 0],
|
||||||
|
"epsilon": [0, 0, 0, 0, 1, 0, 0, 0],
|
||||||
|
}
|
||||||
|
for name, vec in vecs.items():
|
||||||
|
db.execute("INSERT INTO t(id, emb) VALUES (?, ?)", [name, _f32(vec)])
|
||||||
|
|
||||||
|
# KNN should return text IDs
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT id, distance FROM t WHERE emb MATCH ? AND k=3",
|
||||||
|
[_f32([1, 0, 0, 0, 0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) >= 1
|
||||||
|
ids = [r["id"] for r in rows]
|
||||||
|
assert "alpha" in ids # closest to query
|
||||||
|
|
||||||
|
# Delete and verify
|
||||||
|
db.execute("DELETE FROM t WHERE id = 'alpha'")
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT id FROM t WHERE emb MATCH ? AND k=3",
|
||||||
|
[_f32([1, 0, 0, 0, 0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
ids = [r["id"] for r in rows]
|
||||||
|
assert "alpha" not in ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_delete_scrubs_all_references(db):
|
||||||
|
"""After DELETE, no shadow table should contain the deleted rowid or its data."""
|
||||||
|
import struct
|
||||||
|
db.execute("""
|
||||||
|
CREATE VIRTUAL TABLE t USING vec0(
|
||||||
|
emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
for i in range(20):
|
||||||
|
vec = struct.pack("8f", *[float(i + d) for d in range(8)])
|
||||||
|
db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, vec])
|
||||||
|
|
||||||
|
target = 5
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [target])
|
||||||
|
|
||||||
|
# Node row itself should be gone
|
||||||
|
assert db.execute(
|
||||||
|
"SELECT count(*) FROM t_diskann_nodes00 WHERE rowid=?", [target]
|
||||||
|
).fetchone()[0] == 0
|
||||||
|
|
||||||
|
# Vector should be gone
|
||||||
|
assert db.execute(
|
||||||
|
"SELECT count(*) FROM t_vectors00 WHERE rowid=?", [target]
|
||||||
|
).fetchone()[0] == 0
|
||||||
|
|
||||||
|
# No other node should reference the deleted rowid in neighbor_ids
|
||||||
|
for row in db.execute("SELECT rowid, neighbor_ids FROM t_diskann_nodes00"):
|
||||||
|
node_rowid = row[0]
|
||||||
|
ids_blob = row[1]
|
||||||
|
for j in range(0, len(ids_blob), 8):
|
||||||
|
nid = struct.unpack("<q", ids_blob[j : j + 8])[0]
|
||||||
|
assert nid != target, (
|
||||||
|
f"Node {node_rowid} slot {j // 8} still references "
|
||||||
|
f"deleted rowid {target}"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -27,3 +27,15 @@ def test_info(db, snapshot):
|
||||||
assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot()
|
assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot()
|
||||||
|
|
||||||
|
|
||||||
|
def test_command_column_name_conflict(db):
|
||||||
|
"""Table name matching a column name should error (command column conflict)."""
|
||||||
|
# This would conflict: hidden command column 'embeddings' vs vector column 'embeddings'
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="conflicts with table name"):
|
||||||
|
db.execute(
|
||||||
|
"create virtual table embeddings using vec0(embeddings float[4])"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Different names should work fine
|
||||||
|
db.execute("create virtual table t using vec0(embeddings float[4])")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -483,3 +483,171 @@ def test_delete_one_chunk_of_two_shrinks_pages(tmp_path):
|
||||||
row = db.execute("select emb from v where rowid = ?", [i]).fetchone()
|
row = db.execute("select emb from v where rowid = ?", [i]).fetchone()
|
||||||
assert row[0] == _f32([float(i)] * dims)
|
assert row[0] == _f32([float(i)] * dims)
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wal_concurrent_reader_during_write(tmp_path):
|
||||||
|
"""In WAL mode, a reader should see a consistent snapshot while a writer inserts."""
|
||||||
|
dims = 4
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
|
||||||
|
# Writer: create table, insert initial rows, enable WAL
|
||||||
|
writer = sqlite3.connect(db_path)
|
||||||
|
writer.enable_load_extension(True)
|
||||||
|
writer.load_extension("dist/vec0")
|
||||||
|
writer.execute("PRAGMA journal_mode=WAL")
|
||||||
|
writer.execute(
|
||||||
|
f"CREATE VIRTUAL TABLE v USING vec0(emb float[{dims}])"
|
||||||
|
)
|
||||||
|
for i in range(1, 11):
|
||||||
|
writer.execute("INSERT INTO v(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * dims)])
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
# Reader: open separate connection, start read
|
||||||
|
reader = sqlite3.connect(db_path)
|
||||||
|
reader.enable_load_extension(True)
|
||||||
|
reader.load_extension("dist/vec0")
|
||||||
|
|
||||||
|
# Reader sees 10 rows
|
||||||
|
count_before = reader.execute("SELECT count(*) FROM v").fetchone()[0]
|
||||||
|
assert count_before == 10
|
||||||
|
|
||||||
|
# Writer inserts more rows (not yet committed)
|
||||||
|
writer.execute("BEGIN")
|
||||||
|
for i in range(11, 21):
|
||||||
|
writer.execute("INSERT INTO v(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * dims)])
|
||||||
|
|
||||||
|
# Reader still sees 10 (WAL snapshot isolation)
|
||||||
|
count_during = reader.execute("SELECT count(*) FROM v").fetchone()[0]
|
||||||
|
assert count_during == 10
|
||||||
|
|
||||||
|
# KNN during writer's transaction should work on reader's snapshot
|
||||||
|
rows = reader.execute(
|
||||||
|
"SELECT rowid FROM v WHERE emb MATCH ? AND k = 5",
|
||||||
|
[_f32([1.0] * dims)],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
assert all(r[0] <= 10 for r in rows) # only original rows
|
||||||
|
|
||||||
|
# Writer commits
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
# Reader sees new rows after re-query (new snapshot)
|
||||||
|
count_after = reader.execute("SELECT count(*) FROM v").fetchone()[0]
|
||||||
|
assert count_after == 20
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
reader.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_integer_pk(db):
|
||||||
|
"""INSERT OR REPLACE should update vector when rowid already exists."""
|
||||||
|
db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"insert into v(rowid, emb) values (1, ?)", [_f32([1.0, 2.0, 3.0, 4.0])]
|
||||||
|
)
|
||||||
|
# Replace with new vector
|
||||||
|
db.execute(
|
||||||
|
"insert or replace into v(rowid, emb) values (1, ?)",
|
||||||
|
[_f32([10.0, 20.0, 30.0, 40.0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still have exactly 1 row
|
||||||
|
count = db.execute("select count(*) from v").fetchone()[0]
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
# Vector should be the replaced value
|
||||||
|
row = db.execute("select emb from v where rowid = 1").fetchone()
|
||||||
|
assert row[0] == _f32([10.0, 20.0, 30.0, 40.0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_new_row(db):
|
||||||
|
"""INSERT OR REPLACE with a new rowid should just insert normally."""
|
||||||
|
db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"insert or replace into v(rowid, emb) values (1, ?)",
|
||||||
|
[_f32([1.0, 2.0, 3.0, 4.0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
count = db.execute("select count(*) from v").fetchone()[0]
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
row = db.execute("select emb from v where rowid = 1").fetchone()
|
||||||
|
assert row[0] == _f32([1.0, 2.0, 3.0, 4.0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_text_pk(db):
|
||||||
|
"""INSERT OR REPLACE should work with text primary keys."""
|
||||||
|
db.execute(
|
||||||
|
"create virtual table v using vec0("
|
||||||
|
"id text primary key, emb float[4], chunk_size=8"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"insert into v(id, emb) values ('doc_a', ?)",
|
||||||
|
[_f32([1.0, 2.0, 3.0, 4.0])],
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"insert or replace into v(id, emb) values ('doc_a', ?)",
|
||||||
|
[_f32([10.0, 20.0, 30.0, 40.0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
count = db.execute("select count(*) from v").fetchone()[0]
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
row = db.execute("select emb from v where id = 'doc_a'").fetchone()
|
||||||
|
assert row[0] == _f32([10.0, 20.0, 30.0, 40.0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_with_auxiliary(db):
|
||||||
|
"""INSERT OR REPLACE should also replace auxiliary column values."""
|
||||||
|
db.execute(
|
||||||
|
"create virtual table v using vec0("
|
||||||
|
"emb float[4], +label text, chunk_size=8"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"insert into v(rowid, emb, label) values (1, ?, 'old')",
|
||||||
|
[_f32([1.0, 2.0, 3.0, 4.0])],
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"insert or replace into v(rowid, emb, label) values (1, ?, 'new')",
|
||||||
|
[_f32([10.0, 20.0, 30.0, 40.0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
count = db.execute("select count(*) from v").fetchone()[0]
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
row = db.execute("select emb, label from v where rowid = 1").fetchone()
|
||||||
|
assert row[0] == _f32([10.0, 20.0, 30.0, 40.0])
|
||||||
|
assert row[1] == "new"
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_knn_uses_new_vector(db):
|
||||||
|
"""After INSERT OR REPLACE, KNN should find the new vector, not the old one."""
|
||||||
|
db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"insert into v(rowid, emb) values (1, ?)", [_f32([1.0, 0.0, 0.0, 0.0])]
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"insert into v(rowid, emb) values (2, ?)", [_f32([0.0, 1.0, 0.0, 0.0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace row 1's vector to be very close to row 2
|
||||||
|
db.execute(
|
||||||
|
"insert or replace into v(rowid, emb) values (1, ?)",
|
||||||
|
[_f32([0.0, 0.9, 0.0, 0.0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# KNN for [0, 1, 0, 0] should return row 2 first (exact), then row 1 (close)
|
||||||
|
rows = db.execute(
|
||||||
|
"select rowid, distance from v where emb match ? and k = 2",
|
||||||
|
[_f32([0.0, 1.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert rows[0][0] == 2
|
||||||
|
assert rows[1][0] == 1
|
||||||
|
assert rows[1][1] < 0.11 # should be close (L2 distance ≈ 0.1)
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ def test_batch_insert_knn_recall(db):
|
||||||
)
|
)
|
||||||
assert ivf_total_vectors(db) == 200
|
assert ivf_total_vectors(db) == 200
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert ivf_assigned_count(db) == 200
|
assert ivf_assigned_count(db) == 200
|
||||||
|
|
||||||
# Query near 100 -- closest should be rowid 100
|
# Query near 100 -- closest should be rowid 100
|
||||||
|
|
@ -107,7 +107,7 @@ def test_delete_rows_gone_from_knn(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Delete rowid 10
|
# Delete rowid 10
|
||||||
db.execute("DELETE FROM t WHERE rowid = 10")
|
db.execute("DELETE FROM t WHERE rowid = 10")
|
||||||
|
|
@ -127,7 +127,7 @@ def test_delete_all_rows_empty_results(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
||||||
|
|
@ -152,7 +152,7 @@ def test_insert_after_delete_reuse_rowid(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Delete rowid 5
|
# Delete rowid 5
|
||||||
db.execute("DELETE FROM t WHERE rowid = 5")
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
|
@ -184,7 +184,7 @@ def test_update_vector_via_delete_insert(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# "Update" rowid 3: delete and re-insert with new vector
|
# "Update" rowid 3: delete and re-insert with new vector
|
||||||
db.execute("DELETE FROM t WHERE rowid = 3")
|
db.execute("DELETE FROM t WHERE rowid = 3")
|
||||||
|
|
@ -203,13 +203,15 @@ def test_update_vector_via_delete_insert(db):
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def test_error_ivf_with_auxiliary_column(db):
|
def test_ivf_with_auxiliary_column(db):
|
||||||
result = exec(
|
"""IVF should support auxiliary columns."""
|
||||||
db,
|
db.execute(
|
||||||
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), +extra text)",
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), +extra text)"
|
||||||
)
|
)
|
||||||
assert "error" in result
|
tables = [r[0] for r in db.execute(
|
||||||
assert "auxiliary" in result.get("message", "").lower()
|
"SELECT name FROM sqlite_master WHERE name LIKE 't_%' ORDER BY 1"
|
||||||
|
).fetchall()]
|
||||||
|
assert "t_auxiliary" in tables
|
||||||
|
|
||||||
|
|
||||||
def test_error_ivf_with_metadata_column(db):
|
def test_error_ivf_with_metadata_column(db):
|
||||||
|
|
@ -314,7 +316,7 @@ def test_single_row_compute_centroids(db):
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])]
|
"INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])]
|
||||||
)
|
)
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert ivf_assigned_count(db) == 1
|
assert ivf_assigned_count(db) == 1
|
||||||
|
|
||||||
results = knn(db, [1, 2, 3, 4], 1)
|
results = knn(db, [1, 2, 3, 4], 1)
|
||||||
|
|
@ -341,10 +343,10 @@ def test_cell_overflow_many_vectors(db):
|
||||||
|
|
||||||
# Set a single centroid so all vectors go there
|
# Set a single centroid so all vectors go there
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
|
"INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)",
|
||||||
[_f32([1.0, 0, 0, 0])],
|
[_f32([1.0, 0, 0, 0])],
|
||||||
)
|
)
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
|
db.execute("INSERT INTO t(t) VALUES ('assign-vectors')")
|
||||||
|
|
||||||
assert ivf_assigned_count(db) == 100
|
assert ivf_assigned_count(db) == 100
|
||||||
|
|
||||||
|
|
@ -375,7 +377,7 @@ def test_large_batch_with_training(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
for i in range(500, 1000):
|
for i in range(500, 1000):
|
||||||
db.execute(
|
db.execute(
|
||||||
|
|
@ -407,7 +409,7 @@ def test_knn_after_interleaved_insert_delete(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Delete rowids 0-9 (closest to query at 5.0)
|
# Delete rowids 0-9 (closest to query at 5.0)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
|
|
@ -432,7 +434,7 @@ def test_knn_empty_centroids_after_deletes(db):
|
||||||
[i, _f32([float(i % 10) * 10, 0, 0, 0])],
|
[i, _f32([float(i % 10) * 10, 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Delete a bunch, potentially emptying some centroids
|
# Delete a bunch, potentially emptying some centroids
|
||||||
for i in range(30):
|
for i in range(30):
|
||||||
|
|
@ -456,7 +458,7 @@ def test_knn_correct_distances(db):
|
||||||
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([3, 0, 0, 0])])
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([3, 0, 0, 0])])
|
||||||
db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([0, 4, 0, 0])])
|
db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([0, 4, 0, 0])])
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
results = knn(db, [0, 0, 0, 0], 3)
|
results = knn(db, [0, 0, 0, 0], 3)
|
||||||
result_map = {r[0]: r[1] for r in results}
|
result_map = {r[0]: r[1] for r in results}
|
||||||
|
|
@ -545,7 +547,7 @@ def test_interleaved_ops_correctness(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Phase 2: Delete even-numbered rowids
|
# Phase 2: Delete even-numbered rowids
|
||||||
for i in range(0, 50, 2):
|
for i in range(0, 50, 2):
|
||||||
|
|
@ -573,3 +575,15 @@ def test_interleaved_ops_correctness(db):
|
||||||
# Verify we get the right count (25 odd + 15 new - 10 deleted new = 30)
|
# Verify we get the right count (25 odd + 15 new - 10 deleted new = 30)
|
||||||
expected_alive = set(range(1, 50, 2)) | set(range(50, 60)) | set(range(70, 75))
|
expected_alive = set(range(1, 50, 2)) | set(range(50, 60)) | set(range(70, 75))
|
||||||
assert rowids.issubset(expected_alive)
|
assert rowids.issubset(expected_alive)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_update_vector_blocked(db):
|
||||||
|
"""UPDATE on a vector column with IVF index should error (index would become stale)."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(emb float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
|
||||||
|
db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0])])
|
||||||
|
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="UPDATE on vector column.*not supported for IVF"):
|
||||||
|
db.execute("UPDATE t SET emb = ? WHERE rowid = 1", [_f32([0, 0, 1, 0])])
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,7 @@ def test_ivf_int8_insert_and_query(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Should be able to query
|
# Should be able to query
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
|
|
@ -151,7 +151,7 @@ def test_ivf_binary_insert_and_query(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
|
@ -221,10 +221,10 @@ def test_ivf_int8_oversample_improves_recall(db):
|
||||||
db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v])
|
db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v])
|
||||||
db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v])
|
db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v])
|
||||||
|
|
||||||
db.execute("INSERT INTO t1(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t1(t1) VALUES ('compute-centroids')")
|
||||||
db.execute("INSERT INTO t2(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t2(t2) VALUES ('compute-centroids')")
|
||||||
db.execute("INSERT INTO t1(rowid) VALUES ('nprobe=4')")
|
db.execute("INSERT INTO t1(t1) VALUES ('nprobe=4')")
|
||||||
db.execute("INSERT INTO t2(rowid) VALUES ('nprobe=4')")
|
db.execute("INSERT INTO t2(t2) VALUES ('nprobe=4')")
|
||||||
|
|
||||||
query = _f32([5.0, 1.5, 2.5, 0.5])
|
query = _f32([5.0, 1.5, 2.5, 0.5])
|
||||||
r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall()
|
r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall()
|
||||||
|
|
@ -247,9 +247,26 @@ def test_ivf_quantized_delete(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10
|
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10
|
||||||
|
|
||||||
db.execute("DELETE FROM t WHERE rowid = 5")
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
# _ivf_vectors should have 9 rows
|
# _ivf_vectors should have 9 rows
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 9
|
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 9
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_binary_rejects_non_multiple_of_8_dims(db):
|
||||||
|
"""Binary quantizer requires dimensions divisible by 8."""
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" v float[12] indexed by ivf(quantizer=binary)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dimensions divisible by 8 should work
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t2 USING vec0("
|
||||||
|
" v float[16] indexed by ivf(quantizer=binary)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -217,7 +217,7 @@ def test_compute_centroids(db):
|
||||||
|
|
||||||
assert ivf_unassigned_count(db) == 40
|
assert ivf_unassigned_count(db) == 40
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# After training: unassigned cell should be gone (or empty), vectors in trained cells
|
# After training: unassigned cell should be gone (or empty), vectors in trained cells
|
||||||
assert ivf_unassigned_count(db) == 0
|
assert ivf_unassigned_count(db) == 0
|
||||||
|
|
@ -238,10 +238,10 @@ def test_compute_centroids_recompute(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
assert ivf_assigned_count(db) == 20
|
assert ivf_assigned_count(db) == 20
|
||||||
|
|
||||||
|
|
@ -260,7 +260,7 @@ def test_ivf_insert_after_training(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])]
|
||||||
|
|
@ -290,7 +290,7 @@ def test_ivf_knn_after_training(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
|
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
|
@ -310,7 +310,7 @@ def test_ivf_knn_k_larger_than_n(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
"SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
|
||||||
|
|
@ -334,17 +334,17 @@ def test_set_centroid_and_assign(db):
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
|
"INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)",
|
||||||
[_f32([5, 0, 0, 0])],
|
[_f32([5, 0, 0, 0])],
|
||||||
)
|
)
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES ('set-centroid:1', ?)",
|
"INSERT INTO t(t, v) VALUES ('set-centroid:1', ?)",
|
||||||
[_f32([15, 0, 0, 0])],
|
[_f32([15, 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
|
db.execute("INSERT INTO t(t) VALUES ('assign-vectors')")
|
||||||
|
|
||||||
assert ivf_unassigned_count(db) == 0
|
assert ivf_unassigned_count(db) == 0
|
||||||
assert ivf_assigned_count(db) == 20
|
assert ivf_assigned_count(db) == 20
|
||||||
|
|
@ -364,10 +364,10 @@ def test_clear_centroids(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('clear-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('clear-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0
|
||||||
assert ivf_unassigned_count(db) == 20
|
assert ivf_unassigned_count(db) == 20
|
||||||
trained = db.execute(
|
trained = db.execute(
|
||||||
|
|
@ -390,7 +390,7 @@ def test_ivf_delete_after_training(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert ivf_assigned_count(db) == 10
|
assert ivf_assigned_count(db) == 10
|
||||||
|
|
||||||
db.execute("DELETE FROM t WHERE rowid = 5")
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
|
@ -412,7 +412,7 @@ def test_ivf_recall_nprobe_equals_nlist(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
"SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
|
||||||
|
|
|
||||||
138
tests/test-legacy-compat.py
Normal file
138
tests/test-legacy-compat.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
"""Backwards compatibility tests: current sqlite-vec reading legacy databases.
|
||||||
|
|
||||||
|
The fixture file tests/fixtures/legacy-v0.1.6.db was generated by
|
||||||
|
tests/generate_legacy_db.py using sqlite-vec v0.1.6. These tests verify
|
||||||
|
that the current version can fully read, query, insert into, and delete
|
||||||
|
from tables created by older versions.
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures", "legacy-v0.1.6.db")
|
||||||
|
|
||||||
|
|
||||||
|
def _f32(vals):
|
||||||
|
return struct.pack(f"{len(vals)}f", *vals)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def legacy_db(tmp_path):
|
||||||
|
"""Copy the legacy fixture to a temp dir so tests can modify it."""
|
||||||
|
if not os.path.exists(FIXTURE_PATH):
|
||||||
|
pytest.skip("Legacy fixture not found — run: uv run --script tests/generate_legacy_db.py")
|
||||||
|
db_path = str(tmp_path / "legacy.db")
|
||||||
|
shutil.copy2(FIXTURE_PATH, db_path)
|
||||||
|
db = sqlite3.connect(db_path)
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension("dist/vec0")
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_select_count(legacy_db):
|
||||||
|
"""Basic SELECT count should return all rows."""
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_point_query(legacy_db):
|
||||||
|
"""Point query by rowid should return correct vector."""
|
||||||
|
row = legacy_db.execute(
|
||||||
|
"SELECT rowid, emb FROM legacy_vectors WHERE rowid = 1"
|
||||||
|
).fetchone()
|
||||||
|
assert row["rowid"] == 1
|
||||||
|
vec = struct.unpack("4f", row["emb"])
|
||||||
|
assert vec[0] == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_knn(legacy_db):
|
||||||
|
"""KNN query on legacy table should return correct results."""
|
||||||
|
query = _f32([1.0, 0.0, 0.0, 0.0])
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid, distance FROM legacy_vectors "
|
||||||
|
"WHERE emb MATCH ? AND k = 5",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
assert rows[0]["rowid"] == 1
|
||||||
|
assert rows[0]["distance"] == pytest.approx(0.0)
|
||||||
|
for i in range(len(rows) - 1):
|
||||||
|
assert rows[i]["distance"] <= rows[i + 1]["distance"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_insert(legacy_db):
|
||||||
|
"""INSERT into legacy table should work."""
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO legacy_vectors(rowid, emb) VALUES (100, ?)",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
)
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 51
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 1",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert rows[0]["rowid"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_delete(legacy_db):
|
||||||
|
"""DELETE from legacy table should work."""
|
||||||
|
legacy_db.execute("DELETE FROM legacy_vectors WHERE rowid = 1")
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 49
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 5",
|
||||||
|
[_f32([1.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert 1 not in [r["rowid"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_fullscan(legacy_db):
|
||||||
|
"""Full scan should work."""
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors ORDER BY rowid LIMIT 5"
|
||||||
|
).fetchall()
|
||||||
|
assert [r["rowid"] for r in rows] == [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_name_conflict_table(legacy_db):
|
||||||
|
"""Legacy table where column name == table name should work.
|
||||||
|
|
||||||
|
The v0.1.6 DB has: CREATE VIRTUAL TABLE emb USING vec0(emb float[4])
|
||||||
|
Current code should NOT add the command column for this table
|
||||||
|
(detected via _info version check), avoiding the name conflict.
|
||||||
|
"""
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0]
|
||||||
|
assert count == 10
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid, distance FROM emb WHERE emb MATCH ? AND k = 3",
|
||||||
|
[_f32([1.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 3
|
||||||
|
assert rows[0]["rowid"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_name_conflict_insert_delete(legacy_db):
|
||||||
|
"""INSERT and DELETE on legacy name-conflict table."""
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO emb(rowid, emb) VALUES (100, ?)",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
)
|
||||||
|
assert legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] == 11
|
||||||
|
|
||||||
|
legacy_db.execute("DELETE FROM emb WHERE rowid = 5")
|
||||||
|
assert legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_no_command_column(legacy_db):
|
||||||
|
"""Legacy tables should NOT have the command column."""
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO legacy_vectors(legacy_vectors) VALUES ('some_command')"
|
||||||
|
)
|
||||||
|
|
@ -365,6 +365,34 @@ def test_vec_distance_l1():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vec_reject_nan_inf():
|
||||||
|
"""NaN and Inf in float32 vectors should be rejected."""
|
||||||
|
import struct, math
|
||||||
|
|
||||||
|
# NaN via blob
|
||||||
|
nan_blob = struct.pack("4f", 1.0, float("nan"), 3.0, 4.0)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="NaN"):
|
||||||
|
db.execute("SELECT vec_length(?)", [nan_blob])
|
||||||
|
|
||||||
|
# Inf via blob
|
||||||
|
inf_blob = struct.pack("4f", 1.0, float("inf"), 3.0, 4.0)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="Inf"):
|
||||||
|
db.execute("SELECT vec_length(?)", [inf_blob])
|
||||||
|
|
||||||
|
# -Inf via blob
|
||||||
|
ninf_blob = struct.pack("4f", 1.0, float("-inf"), 3.0, 4.0)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="Inf"):
|
||||||
|
db.execute("SELECT vec_length(?)", [ninf_blob])
|
||||||
|
|
||||||
|
# NaN via JSON
|
||||||
|
# Note: JSON doesn't have NaN literal, but strtod may parse "NaN"
|
||||||
|
# This tests the blob path which is the primary input method
|
||||||
|
|
||||||
|
# Valid vectors still work
|
||||||
|
ok_blob = struct.pack("4f", 1.0, 2.0, 3.0, 4.0)
|
||||||
|
assert db.execute("SELECT vec_length(?)", [ok_blob]).fetchone()[0] == 4
|
||||||
|
|
||||||
|
|
||||||
def test_vec_distance_l2():
|
def test_vec_distance_l2():
|
||||||
vec_distance_l2 = lambda *args, a="?", b="?": db.execute(
|
vec_distance_l2 = lambda *args, a="?", b="?": db.execute(
|
||||||
f"select vec_distance_l2({a}, {b})", args
|
f"select vec_distance_l2({a}, {b})", args
|
||||||
|
|
@ -381,11 +409,17 @@ def test_vec_distance_l2():
|
||||||
|
|
||||||
x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform)
|
x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform)
|
||||||
y = npy_l2(np.array(a), np.array(b))
|
y = npy_l2(np.array(a), np.array(b))
|
||||||
assert isclose(x, y, abs_tol=1e-6)
|
assert isclose(x, y, rel_tol=1e-5, abs_tol=1e-6)
|
||||||
|
|
||||||
check([1.2, 0.1], [0.4, -0.4])
|
check([1.2, 0.1], [0.4, -0.4])
|
||||||
check([-1.2, -0.1], [-0.4, 0.4])
|
check([-1.2, -0.1], [-0.4, 0.4])
|
||||||
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
|
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
|
||||||
|
# Extreme int8 values: diff=255, squared=65025 which overflows i16
|
||||||
|
# This tests the NEON widening multiply fix (slight float rounding expected)
|
||||||
|
check([-128] * 8, [127] * 8, dtype=np.int8)
|
||||||
|
check([-128] * 16, [127] * 16, dtype=np.int8)
|
||||||
|
check([-128, 127, -128, 127, -128, 127, -128, 127],
|
||||||
|
[127, -128, 127, -128, 127, -128, 127, -128], dtype=np.int8)
|
||||||
|
|
||||||
|
|
||||||
def test_vec_length():
|
def test_vec_length():
|
||||||
|
|
|
||||||
173
tests/test-rename.py
Normal file
173
tests/test-rename.py
Normal file
|
|
@ -0,0 +1,173 @@
|
||||||
|
import sqlite3
|
||||||
|
import pytest
|
||||||
|
from helpers import _f32
|
||||||
|
|
||||||
|
|
||||||
|
def _shadow_tables(db, prefix):
|
||||||
|
"""Return sorted list of shadow table names for a given prefix."""
|
||||||
|
return sorted([
|
||||||
|
row[0] for row in db.execute(
|
||||||
|
r"select name from sqlite_master where name like ? escape '\' and type='table' order by 1",
|
||||||
|
[f"{prefix}\\__%"],
|
||||||
|
).fetchall()
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_basic(db):
|
||||||
|
"""ALTER TABLE RENAME should rename vec0 table and all shadow tables."""
|
||||||
|
db.execute("create virtual table v using vec0(a float[2], chunk_size=8)")
|
||||||
|
db.execute("insert into v(rowid, a) values (1, ?)", [_f32([0.1, 0.2])])
|
||||||
|
db.execute("insert into v(rowid, a) values (2, ?)", [_f32([0.3, 0.4])])
|
||||||
|
|
||||||
|
assert _shadow_tables(db, "v") == [
|
||||||
|
"v_chunks",
|
||||||
|
"v_info",
|
||||||
|
"v_rowids",
|
||||||
|
"v_vector_chunks00",
|
||||||
|
]
|
||||||
|
|
||||||
|
db.execute("ALTER TABLE v RENAME TO v2")
|
||||||
|
|
||||||
|
# Old name should no longer work
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.execute("select * from v")
|
||||||
|
|
||||||
|
# New name should work and return the same data
|
||||||
|
rows = db.execute(
|
||||||
|
"select rowid, distance from v2 where a match ? and k=10",
|
||||||
|
[_f32([0.1, 0.2])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 2
|
||||||
|
assert rows[0][0] == 1 # closest match
|
||||||
|
|
||||||
|
# Shadow tables should all be renamed
|
||||||
|
assert _shadow_tables(db, "v2") == [
|
||||||
|
"v2_chunks",
|
||||||
|
"v2_info",
|
||||||
|
"v2_rowids",
|
||||||
|
"v2_vector_chunks00",
|
||||||
|
]
|
||||||
|
|
||||||
|
# No old shadow tables should remain
|
||||||
|
assert _shadow_tables(db, "v") == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_insert_after(db):
|
||||||
|
"""Inserts and queries should work after rename."""
|
||||||
|
db.execute("create virtual table v using vec0(a float[2], chunk_size=8)")
|
||||||
|
db.execute("insert into v(rowid, a) values (1, ?)", [_f32([0.1, 0.2])])
|
||||||
|
db.execute("ALTER TABLE v RENAME TO v2")
|
||||||
|
|
||||||
|
# Insert into renamed table
|
||||||
|
db.execute("insert into v2(rowid, a) values (2, ?)", [_f32([0.3, 0.4])])
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"select rowid from v2 where a match ? and k=10",
|
||||||
|
[_f32([0.3, 0.4])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 2
|
||||||
|
assert rows[0][0] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_delete_after(db):
|
||||||
|
"""Deletes should work after rename."""
|
||||||
|
db.execute("create virtual table v using vec0(a float[2], chunk_size=8)")
|
||||||
|
db.execute("insert into v(rowid, a) values (1, ?)", [_f32([0.1, 0.2])])
|
||||||
|
db.execute("insert into v(rowid, a) values (2, ?)", [_f32([0.3, 0.4])])
|
||||||
|
db.execute("ALTER TABLE v RENAME TO v2")
|
||||||
|
|
||||||
|
db.execute("delete from v2 where rowid = 1")
|
||||||
|
rows = db.execute(
|
||||||
|
"select rowid from v2 where a match ? and k=10",
|
||||||
|
[_f32([0.3, 0.4])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0][0] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_with_auxiliary(db):
|
||||||
|
"""Rename should also rename the _auxiliary shadow table."""
|
||||||
|
db.execute(
|
||||||
|
"create virtual table v using vec0(a float[2], +name text, chunk_size=8)"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"insert into v(rowid, a, name) values (1, ?, 'hello')",
|
||||||
|
[_f32([0.1, 0.2])],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _shadow_tables(db, "v") == [
|
||||||
|
"v_auxiliary",
|
||||||
|
"v_chunks",
|
||||||
|
"v_info",
|
||||||
|
"v_rowids",
|
||||||
|
"v_vector_chunks00",
|
||||||
|
]
|
||||||
|
|
||||||
|
db.execute("ALTER TABLE v RENAME TO v2")
|
||||||
|
|
||||||
|
# Auxiliary data should be accessible
|
||||||
|
rows = db.execute(
|
||||||
|
"select rowid, name from v2 where a match ? and k=10",
|
||||||
|
[_f32([0.1, 0.2])],
|
||||||
|
).fetchall()
|
||||||
|
assert rows[0][0] == 1
|
||||||
|
assert rows[0][1] == "hello"
|
||||||
|
|
||||||
|
assert _shadow_tables(db, "v2") == [
|
||||||
|
"v2_auxiliary",
|
||||||
|
"v2_chunks",
|
||||||
|
"v2_info",
|
||||||
|
"v2_rowids",
|
||||||
|
"v2_vector_chunks00",
|
||||||
|
]
|
||||||
|
assert _shadow_tables(db, "v") == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_with_metadata(db):
|
||||||
|
"""Rename should also rename metadata shadow tables."""
|
||||||
|
db.execute(
|
||||||
|
"create virtual table v using vec0(a float[2], tag text, chunk_size=8)"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"insert into v(rowid, a, tag) values (1, ?, 'a')",
|
||||||
|
[_f32([0.1, 0.2])],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _shadow_tables(db, "v") == [
|
||||||
|
"v_chunks",
|
||||||
|
"v_info",
|
||||||
|
"v_metadatachunks00",
|
||||||
|
"v_metadatatext00",
|
||||||
|
"v_rowids",
|
||||||
|
"v_vector_chunks00",
|
||||||
|
]
|
||||||
|
|
||||||
|
db.execute("ALTER TABLE v RENAME TO v2")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"select rowid, tag from v2 where a match ? and k=10",
|
||||||
|
[_f32([0.1, 0.2])],
|
||||||
|
).fetchall()
|
||||||
|
assert rows[0][0] == 1
|
||||||
|
assert rows[0][1] == "a"
|
||||||
|
|
||||||
|
assert _shadow_tables(db, "v2") == [
|
||||||
|
"v2_chunks",
|
||||||
|
"v2_info",
|
||||||
|
"v2_metadatachunks00",
|
||||||
|
"v2_metadatatext00",
|
||||||
|
"v2_rowids",
|
||||||
|
"v2_vector_chunks00",
|
||||||
|
]
|
||||||
|
assert _shadow_tables(db, "v") == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_drop_after(db):
|
||||||
|
"""DROP TABLE should work on a renamed table."""
|
||||||
|
db.execute("create virtual table v using vec0(a float[2], chunk_size=8)")
|
||||||
|
db.execute("insert into v(rowid, a) values (1, ?)", [_f32([0.1, 0.2])])
|
||||||
|
db.execute("ALTER TABLE v RENAME TO v2")
|
||||||
|
db.execute("DROP TABLE v2")
|
||||||
|
|
||||||
|
# Nothing should remain
|
||||||
|
assert _shadow_tables(db, "v2") == []
|
||||||
|
|
@ -32,15 +32,18 @@ def unpack_float_vec(blob):
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def test_create_error_with_aux_column(db):
|
def test_create_with_aux_column(db):
|
||||||
"""Rescore should reject auxiliary columns."""
|
"""Rescore should support auxiliary columns."""
|
||||||
with pytest.raises(sqlite3.OperationalError, match="Auxiliary columns"):
|
db.execute(
|
||||||
db.execute(
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
"CREATE VIRTUAL TABLE t USING vec0("
|
" embedding float[128] indexed by rescore(quantizer=bit),"
|
||||||
" embedding float[8] indexed by rescore(quantizer=bit),"
|
" +extra text"
|
||||||
" +extra text"
|
")"
|
||||||
")"
|
)
|
||||||
)
|
tables = [r[0] for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE name LIKE 't_%' ORDER BY 1"
|
||||||
|
).fetchall()]
|
||||||
|
assert "t_auxiliary" in tables
|
||||||
|
|
||||||
|
|
||||||
def test_create_error_with_metadata_column(db):
|
def test_create_error_with_metadata_column(db):
|
||||||
|
|
@ -443,6 +446,104 @@ def test_insert_batch_recall(db):
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_interleaved_with_knn(db):
|
||||||
|
"""Delete rows one at a time, running KNN after each delete to verify correctness."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
N = 30
|
||||||
|
random.seed(42)
|
||||||
|
vecs = {i: [random.gauss(0, 1) for _ in range(8)] for i in range(1, N + 1)}
|
||||||
|
for rowid, vec in vecs.items():
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||||
|
[rowid, float_vec(vec)],
|
||||||
|
)
|
||||||
|
|
||||||
|
alive = set(vecs.keys())
|
||||||
|
query = [0.0] * 8
|
||||||
|
|
||||||
|
for to_del in [5, 10, 15, 20, 25]:
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [to_del])
|
||||||
|
alive.discard(to_del)
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
|
||||||
|
[float_vec(query)],
|
||||||
|
).fetchall()
|
||||||
|
returned = {r["rowid"] for r in rows}
|
||||||
|
# All returned rows must be alive (not deleted)
|
||||||
|
assert returned.issubset(alive), f"Deleted rowid found in KNN after deleting {to_del}"
|
||||||
|
# Count should match alive set (up to k)
|
||||||
|
assert len(rows) == min(10, len(alive))
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_with_rowid_in_constraint(db):
|
||||||
|
"""Delete rows and verify KNN with rowid_in filter excludes deleted rows."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[8] indexed by rescore(quantizer=int8)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
for i in range(1, 11):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||||
|
[i, float_vec([float(i)] * 8)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete rows 3, 5, 7
|
||||||
|
for r in [3, 5, 7]:
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [r])
|
||||||
|
|
||||||
|
# KNN with rowid IN (1,2,3,4,5) — should only return 1, 2, 4 (3 and 5 deleted)
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? AND k = 5 AND rowid IN (1, 2, 3, 4, 5)",
|
||||||
|
[float_vec([1.0] * 8)],
|
||||||
|
).fetchall()
|
||||||
|
returned = {r["rowid"] for r in rows}
|
||||||
|
assert 3 not in returned
|
||||||
|
assert 5 not in returned
|
||||||
|
assert returned.issubset({1, 2, 4})
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_then_reinsert_batch(db):
|
||||||
|
"""Delete all rows, reinsert a new batch, verify KNN only returns new rows."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
# First batch
|
||||||
|
for i in range(1, 21):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||||
|
[i, float_vec([float(i)] * 8)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete all
|
||||||
|
for i in range(1, 21):
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
||||||
|
|
||||||
|
assert db.execute("SELECT count(*) FROM t").fetchone()[0] == 0
|
||||||
|
|
||||||
|
# Second batch with different rowids and vectors
|
||||||
|
for i in range(100, 110):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||||
|
[i, float_vec([float(i - 100)] * 8)],
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
|
||||||
|
[float_vec([0.0] * 8)],
|
||||||
|
).fetchall()
|
||||||
|
returned = {r["rowid"] for r in rows}
|
||||||
|
# All returned rowids should be from the second batch
|
||||||
|
assert returned.issubset(set(range(100, 110)))
|
||||||
|
|
||||||
|
|
||||||
def test_knn_int8_cosine(db):
|
def test_knn_int8_cosine(db):
|
||||||
"""Rescore with quantizer=int8 and distance_metric=cosine."""
|
"""Rescore with quantizer=int8 and distance_metric=cosine."""
|
||||||
db.execute(
|
db.execute(
|
||||||
|
|
|
||||||
|
|
@ -566,3 +566,162 @@ def test_multiple_vector_columns(db):
|
||||||
[float_vec([1.0] * 8)],
|
[float_vec([1.0] * 8)],
|
||||||
).fetchall()
|
).fetchall()
|
||||||
assert rows[0]["rowid"] == 2
|
assert rows[0]["rowid"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_corrupt_zeroblob_validity(db):
|
||||||
|
"""KNN should error (not crash) when rescore chunk rowids blob is zeroed out."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
|
||||||
|
[float_vec([1, 0, 0, 0, 0, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
|
||||||
|
[float_vec([0, 1, 0, 0, 0, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Corrupt: replace rowids with a truncated blob (wrong size)
|
||||||
|
db.execute("UPDATE t_chunks SET rowids = x'00'")
|
||||||
|
|
||||||
|
# Should error, not crash — blob size validation catches the mismatch
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
|
||||||
|
[float_vec([1, 0, 0, 0, 0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
def test_corrupt_truncated_validity_blob(db):
|
||||||
|
"""KNN should error when rescore chunk validity blob is truncated."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
for i in range(5):
|
||||||
|
import random
|
||||||
|
random.seed(i)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||||
|
[i + 1, float_vec([random.gauss(0, 1) for _ in range(128)])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Corrupt: truncate validity blob to 1 byte (should be chunk_size/8 = 128 bytes)
|
||||||
|
db.execute("UPDATE t_chunks SET validity = x'FF'")
|
||||||
|
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
|
||||||
|
[float_vec([1.0] * 128)],
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_text_pk_insert_knn_delete(db):
|
||||||
|
"""Rescore with text primary key: insert, KNN, delete, KNN again."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" id text primary key,"
|
||||||
|
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
|
||||||
|
import random
|
||||||
|
random.seed(99)
|
||||||
|
vecs = {}
|
||||||
|
for name in ["alpha", "beta", "gamma", "delta", "epsilon"]:
|
||||||
|
v = [random.gauss(0, 1) for _ in range(128)]
|
||||||
|
vecs[name] = v
|
||||||
|
db.execute("INSERT INTO t(id, embedding) VALUES (?, ?)", [name, float_vec(v)])
|
||||||
|
|
||||||
|
# KNN should return text IDs
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT id, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3",
|
||||||
|
[float_vec(vecs["alpha"])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) >= 1
|
||||||
|
ids = [r["id"] for r in rows]
|
||||||
|
assert "alpha" in ids
|
||||||
|
|
||||||
|
# Delete and verify
|
||||||
|
db.execute("DELETE FROM t WHERE id = 'alpha'")
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT id FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3",
|
||||||
|
[float_vec(vecs["alpha"])],
|
||||||
|
).fetchall()
|
||||||
|
ids = [r["id"] for r in rows]
|
||||||
|
assert "alpha" not in ids
|
||||||
|
assert len(rows) >= 1 # other results still returned
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_oversample(db):
|
||||||
|
"""oversample can be changed at query time via FTS5-style command."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[128] indexed by rescore(quantizer=bit, oversample=2)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
random.seed(200)
|
||||||
|
for i in range(200):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||||
|
[i + 1, float_vec([random.gauss(0, 1) for _ in range(128)])],
|
||||||
|
)
|
||||||
|
|
||||||
|
query = float_vec([random.gauss(0, 1) for _ in range(128)])
|
||||||
|
|
||||||
|
# KNN with default oversample=2 (low)
|
||||||
|
rows_low = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows_low) == 10
|
||||||
|
|
||||||
|
# Change oversample at runtime to high value
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('oversample=32')")
|
||||||
|
|
||||||
|
# KNN with oversample=32 (high) — same or better recall
|
||||||
|
rows_high = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows_high) == 10
|
||||||
|
|
||||||
|
# Reset to original
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('oversample=2')")
|
||||||
|
|
||||||
|
rows_reset = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows_reset) == 10
|
||||||
|
# After reset, should match the original low-oversample results
|
||||||
|
assert [r["rowid"] for r in rows_reset] == [r["rowid"] for r in rows_low]
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_oversample_error(db):
|
||||||
|
"""Invalid oversample values should error."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"):
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('oversample=0')")
|
||||||
|
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"):
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('oversample=-5')")
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_command_errors(db):
|
||||||
|
"""Unknown command strings should produce a clear error."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="unknown vec0 command"):
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('not_a_real_command')")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue