mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
Add text PK, WAL concurrency tests, and fix bench-smoke config
Infrastructure improvements:
- Fix benchmarks-ann Makefile: type=baseline -> type=vec0-flat (baseline was never a valid INDEX_REGISTRY key)
- Add DiskANN + text primary key test: insert, KNN, delete, KNN
- Add rescore + text primary key test: insert, KNN, delete, KNN
- Add WAL concurrency test: reader sees snapshot isolation while writer has an open transaction, KNN works on reader's snapshot

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
d684178a12
commit
f2c9fb8f08
4 changed files with 138 additions and 4 deletions
|
|
@ -4,9 +4,9 @@ EXT = ../dist/vec0
|
||||||
|
|
||||||
# --- Baseline (brute-force) configs ---
|
# --- Baseline (brute-force) configs ---
|
||||||
BASELINES = \
|
BASELINES = \
|
||||||
"brute-float:type=baseline,variant=float" \
|
"brute-float:type=vec0-flat,variant=float" \
|
||||||
"brute-int8:type=baseline,variant=int8" \
|
"brute-int8:type=vec0-flat,variant=int8" \
|
||||||
"brute-bit:type=baseline,variant=bit"
|
"brute-bit:type=vec0-flat,variant=bit"
|
||||||
|
|
||||||
# --- IVF configs ---
|
# --- IVF configs ---
|
||||||
IVF_CONFIGS = \
|
IVF_CONFIGS = \
|
||||||
|
|
@ -43,7 +43,7 @@ ground-truth: seed
|
||||||
# --- Quick smoke test ---
|
# --- Quick smoke test ---
|
||||||
bench-smoke: seed
|
bench-smoke: seed
|
||||||
$(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \
|
$(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \
|
||||||
"brute-float:type=baseline,variant=float" \
|
"brute-float:type=vec0-flat,variant=float" \
|
||||||
"ivf-quick:type=ivf,nlist=16,nprobe=4" \
|
"ivf-quick:type=ivf,nlist=16,nprobe=4" \
|
||||||
"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"
|
"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1246,3 +1246,46 @@ def test_diskann_delete_interleaved_with_knn(db):
|
||||||
returned = {r["rowid"] for r in rows}
|
returned = {r["rowid"] for r in rows}
|
||||||
assert returned.issubset(alive), \
|
assert returned.issubset(alive), \
|
||||||
f"Deleted rowid {to_del} found in KNN results"
|
f"Deleted rowid {to_del} found in KNN results"
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# Text primary key + DiskANN
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_text_pk_insert_knn_delete(db):
    """Check a DiskANN-indexed vec0 table that uses a text primary key.

    Five one-hot vectors are inserted under text ids. A KNN query for the
    first basis vector must return the id "alpha"; after that row is
    deleted, the same query must no longer return it.
    """
    db.execute("""
        CREATE VIRTUAL TABLE t USING vec0(
            id text primary key,
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
        )
    """)

    # Row i gets the i-th standard basis vector of the 8-dim space.
    names = ("alpha", "beta", "gamma", "delta", "epsilon")
    for position, name in enumerate(names):
        one_hot = [0] * 8
        one_hot[position] = 1
        db.execute("INSERT INTO t(id, emb) VALUES (?, ?)", [name, _f32(one_hot)])

    # The probe equals the vector stored under "alpha".
    probe = _f32([1, 0, 0, 0, 0, 0, 0, 0])

    # KNN results must be addressable by their text id.
    hits = db.execute(
        "SELECT id, distance FROM t WHERE emb MATCH ? AND k=3",
        [probe],
    ).fetchall()
    assert len(hits) >= 1
    found = [hit["id"] for hit in hits]
    assert "alpha" in found  # closest to query

    # Once removed by text id, the row must vanish from KNN output.
    db.execute("DELETE FROM t WHERE id = 'alpha'")
    hits = db.execute(
        "SELECT id FROM t WHERE emb MATCH ? AND k=3",
        [probe],
    ).fetchall()
    found = [hit["id"] for hit in hits]
    assert "alpha" not in found
|
||||||
|
|
|
||||||
|
|
@ -483,3 +483,57 @@ def test_delete_one_chunk_of_two_shrinks_pages(tmp_path):
|
||||||
row = db.execute("select emb from v where rowid = ?", [i]).fetchone()
|
row = db.execute("select emb from v where rowid = ?", [i]).fetchone()
|
||||||
assert row[0] == _f32([float(i)] * dims)
|
assert row[0] == _f32([float(i)] * dims)
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wal_concurrent_reader_during_write(tmp_path):
    """In WAL mode, a reader must not observe a writer's uncommitted rows.

    A writer connection creates a vec0 table with 10 committed rows and then
    opens a transaction that inserts 10 more. While that transaction is
    open, a second (reader) connection must still count 10 rows and its KNN
    query must only return the original rowids. After the writer commits,
    the reader must see all 20 rows.

    Args:
        tmp_path: pytest-provided temporary directory for the database file.
    """
    dims = 4
    db_path = str(tmp_path / "test.db")

    def _connect():
        """Open a connection to the test database with vec0 loaded."""
        conn = sqlite3.connect(db_path)
        try:
            conn.enable_load_extension(True)
            conn.load_extension("dist/vec0")
        except Exception:
            # Do not leak the handle if the extension cannot be loaded.
            conn.close()
            raise
        return conn

    # Writer: enable WAL, create table, insert initial rows.
    writer = _connect()
    try:
        # The pragma reports the journal mode actually in effect; verify it so
        # the test cannot silently run in rollback-journal mode.
        journal_mode = writer.execute("PRAGMA journal_mode=WAL").fetchone()[0]
        assert journal_mode.lower() == "wal", \
            f"expected WAL journal mode, got {journal_mode!r}"

        writer.execute(
            f"CREATE VIRTUAL TABLE v USING vec0(emb float[{dims}])"
        )
        for i in range(1, 11):
            writer.execute(
                "INSERT INTO v(rowid, emb) VALUES (?, ?)",
                [i, _f32([float(i)] * dims)],
            )
        writer.commit()

        # Reader: separate connection to the same database file.
        reader = _connect()
        try:
            # Reader sees the 10 committed rows.
            count_before = reader.execute("SELECT count(*) FROM v").fetchone()[0]
            assert count_before == 10

            # Writer inserts more rows inside a transaction (not yet committed).
            writer.execute("BEGIN")
            for i in range(11, 21):
                writer.execute(
                    "INSERT INTO v(rowid, emb) VALUES (?, ?)",
                    [i, _f32([float(i)] * dims)],
                )

            # Reader still sees 10: uncommitted WAL frames are not visible.
            count_during = reader.execute("SELECT count(*) FROM v").fetchone()[0]
            assert count_during == 10

            # KNN during the writer's transaction must work and must only
            # return rows that were committed before it started.
            rows = reader.execute(
                "SELECT rowid FROM v WHERE emb MATCH ? AND k = 5",
                [_f32([1.0] * dims)],
            ).fetchall()
            assert len(rows) == 5
            assert all(r[0] <= 10 for r in rows)  # only original rows

            # Writer commits.
            writer.commit()

            # A fresh query on the reader now observes the new rows.
            count_after = reader.execute("SELECT count(*) FROM v").fetchone()[0]
            assert count_after == 20
        finally:
            reader.close()
    finally:
        # Closing also rolls back the writer's transaction if an assertion
        # failed before the commit.
        writer.close()
|
||||||
|
|
|
||||||
|
|
@ -595,3 +595,40 @@ def test_corrupt_zeroblob_validity(db):
|
||||||
).fetchall()
|
).fetchall()
|
||||||
except sqlite3.OperationalError:
|
except sqlite3.OperationalError:
|
||||||
pass # Error is acceptable — crash is not
|
pass # Error is acceptable — crash is not
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_text_pk_insert_knn_delete(db):
    """Rescore with text primary key: insert, KNN, delete, KNN again.

    Inserts five pseudo-random 128-dim vectors under text ids, checks that a
    KNN query using the vector stored under "alpha" returns that id, then
    deletes "alpha" and checks that it disappears from the results while
    other rows are still returned.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " id text primary key,"
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )

    import random

    # Use a private generator rather than random.seed(): same deterministic
    # sequence for seed 99, but the process-wide RNG state that other tests
    # may rely on is left untouched.
    rng = random.Random(99)
    vecs = {}
    for name in ["alpha", "beta", "gamma", "delta", "epsilon"]:
        v = [rng.gauss(0, 1) for _ in range(128)]
        vecs[name] = v
        db.execute("INSERT INTO t(id, embedding) VALUES (?, ?)", [name, float_vec(v)])

    # KNN should return text IDs
    rows = db.execute(
        "SELECT id, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3",
        [float_vec(vecs["alpha"])],
    ).fetchall()
    assert len(rows) >= 1
    ids = [r["id"] for r in rows]
    assert "alpha" in ids

    # Delete and verify
    db.execute("DELETE FROM t WHERE id = 'alpha'")
    rows = db.execute(
        "SELECT id FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3",
        [float_vec(vecs["alpha"])],
    ).fetchall()
    ids = [r["id"] for r in rows]
    assert "alpha" not in ids
    assert len(rows) >= 1  # other results still returned
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue