sqlite-vec/tests/test-rescore-mutations.py
Alex Garcia b00865429b Filter deleted nodes from DiskANN search results and add delete tests
DiskANN's delete repair only fixes forward edges (nodes the deleted
node pointed to). Stale reverse edges can cause deleted rowids to
appear in search results. Fix: track a 'confirmed' flag on each
search candidate, set when the full-precision vector is successfully
read during re-ranking. Only confirmed candidates are included in
output. Zero additional SQL queries — piggybacks on the existing
re-rank vector read.

Also adds delete hardening tests:
- Rescore: interleaved delete+KNN, rowid_in after deletes, full
  delete+reinsert cycle
- DiskANN: delete+reinsert cycles with KNN verification, interleaved
  delete+KNN

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 17:13:29 -07:00

568 lines
18 KiB
Python

"""Mutation and edge-case tests for the rescore index feature."""
import struct
import sqlite3
import pytest
import math
import random
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows come back as ``sqlite3.Row`` so tests can index columns by name.
    Extension loading is enabled only around the ``load_extension`` call and
    disabled again right after it. The connection is closed once the test
    finishes, so no handle is left open between tests.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        # Runs at fixture teardown, and also if loading the extension fails.
        conn.close()
def float_vec(values):
    """Serialize a sequence of floats as packed 32-bit floats.

    Uses ``struct``'s native byte order; the resulting bytes object is what
    the tests bind as a vector parameter.
    """
    layout = "%df" % len(values)
    return struct.pack(layout, *values)
def unpack_float_vec(blob):
    """Decode a blob of packed 32-bit floats into a list of Python floats."""
    count = len(blob) // 4  # each packed float occupies 4 bytes
    floats = struct.unpack("%df" % count, blob)
    return list(floats)
# ============================================================================
# Error cases: rescore + aux/metadata/partition
# ============================================================================
def test_create_error_with_aux_column(db):
    """Creating a rescore-indexed table with an auxiliary (+) column must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " +extra text"
        ")"
    )
    # The error message is expected to mention "Auxiliary columns".
    expected_failure = pytest.raises(
        sqlite3.OperationalError, match="Auxiliary columns"
    )
    with expected_failure:
        db.execute(ddl)
def test_create_error_with_metadata_column(db):
    """Creating a rescore-indexed table with a metadata column must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " genre text"
        ")"
    )
    # The error message is expected to mention "Metadata columns".
    expected_failure = pytest.raises(
        sqlite3.OperationalError, match="Metadata columns"
    )
    with expected_failure:
        db.execute(ddl)
def test_create_error_with_partition_key(db):
    """Creating a rescore-indexed table with a partition key column must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " user_id integer partition key"
        ")"
    )
    # The error message is expected to mention "Partition key".
    expected_failure = pytest.raises(
        sqlite3.OperationalError, match="Partition key"
    )
    with expected_failure:
        db.execute(ddl)
# ============================================================================
# Insert / batch / delete / update mutations
# ============================================================================
def test_insert_single_verify_knn(db):
    """A lone inserted row is returned as the nearest neighbour of its own vector."""
    ones = float_vec([1.0] * 8)
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [ones])

    cursor = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [ones],
    )
    hits = cursor.fetchall()

    assert len(hits) == 1
    only_hit = hits[0]
    assert only_hit["rowid"] == 1
    # Querying with the stored vector itself should give a near-zero distance.
    assert only_hit["distance"] < 0.01
def test_insert_large_batch(db):
    """Insert 200 random vectors one by one, then check the row count and a KNN query.

    Uses 16-dimensional vectors with the int8 quantizer; the RNG is seeded so
    the data is reproducible.

    NOTE(review): the previous docstring said this covers "multiple chunks
    with default chunk_size=1024", but 200 rows cannot exceed a chunk of that
    size. Raise ``n`` or lower the chunk size if multi-chunk coverage is the
    intent — TODO confirm.
    """
    dim = 16
    n = 200
    random.seed(99)
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        f" embedding float[{dim}] indexed by rescore(quantizer=int8)"
        ")"
    )
    for rowid in range(1, n + 1):
        v = [random.gauss(0, 1) for _ in range(dim)]
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec(v)],
        )

    row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert row["cnt"] == n

    # The query vector is drawn from the same seeded RNG, after all inserts.
    query = float_vec([random.gauss(0, 1) for _ in range(dim)])
    rows = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [query],
    ).fetchall()
    # With 200 rows available, LIMIT 10 must be fully satisfied.
    assert len(rows) == 10
def test_delete_all_rows(db):
    """After deleting all 20 inserted rows, count(*) is 0 and KNN yields nothing."""
    count_sql = "SELECT count(*) as cnt FROM t"
    rowids = range(1, 21)

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    for rowid in rowids:
        # Row r stores the constant vector [r - 1] * 8.
        blob = float_vec([float(rowid - 1)] * 8)
        db.execute("INSERT INTO t(rowid, embedding) VALUES (?, ?)", [rowid, blob])
    assert db.execute(count_sql).fetchone()["cnt"] == 20

    for rowid in rowids:
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])
    assert db.execute(count_sql).fetchone()["cnt"] == 0

    leftovers = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
        [float_vec([0.0] * 8)],
    ).fetchall()
    assert not leftovers
def test_delete_then_reinsert_same_rowid(db):
    """Re-inserting a deleted rowid with a new vector changes the KNN outcome.

    Rowid 1 starts near the origin and wins the query; after it is deleted
    and re-inserted far from the origin, rowid 2 must win instead.
    """

    def closest_to_origin():
        # Rowid of the single nearest neighbour to the zero vector.
        found = db.execute(
            "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
            [float_vec([0.0] * 8)],
        ).fetchall()
        return found[0]["rowid"]

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    near_origin = float_vec([0.1] * 8)
    far_away = float_vec([100.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [near_origin])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [far_away])
    assert closest_to_origin() == 1

    # Replace rowid 1 with a vector even farther out than rowid 2.
    db.execute("DELETE FROM t WHERE rowid = 1")
    farther_away = float_vec([200.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [farther_away])
    assert closest_to_origin() == 2
def test_update_vector(db):
    """An UPDATE of the embedding column is reflected in later KNN queries."""
    origin = float_vec([0.0] * 8)
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [origin])
    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
        [float_vec([10.0] * 8)],
    )

    # Move rowid 1 from the origin to a point beyond rowid 2.
    moved = float_vec([100.0] * 8)
    db.execute("UPDATE t SET embedding = ? WHERE rowid = 1", [moved])

    # Rowid 2 is now the row nearest to the origin.
    nearest = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [origin],
    ).fetchall()
    assert nearest[0]["rowid"] == 2
def test_knn_after_delete_all_but_one(db):
    """With 49 of 50 rows deleted, a LIMIT 10 KNN query returns only the survivor."""
    survivor = 25
    all_rowids = range(1, 51)

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    for rowid in all_rowids:
        # Row r stores the constant vector [r - 1] * 8.
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec([float(rowid - 1)] * 8)],
        )

    doomed = [rowid for rowid in all_rowids if rowid != survivor]
    for rowid in doomed:
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])

    remaining = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert remaining["cnt"] == 1

    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [float_vec([0.0] * 8)],
    ).fetchall()
    assert len(hits) == 1
    assert hits[0]["rowid"] == 25
# ============================================================================
# Edge cases
# ============================================================================
def test_single_row_knn(db):
    """A one-row table yields exactly one hit for both LIMIT 1 and LIMIT 5."""
    ones = float_vec([1.0] * 8)
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [ones])

    knn_queries = (
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
    )
    for knn_sql in knn_queries:
        # A limit larger than the table must not produce extra rows.
        found = db.execute(knn_sql, [ones]).fetchall()
        assert len(found) == 1
def test_knn_with_all_identical_vectors(db):
    """Ten copies of one vector all come back with near-zero distance to it."""
    shared = float_vec([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    for rowid in range(1, 11):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, shared],
        )

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [shared],
    ).fetchall()
    assert len(hits) == 10
    # Every stored vector equals the query, so each distance should be ~0.
    assert all(hit["distance"] < 0.01 for hit in hits)
def test_zero_vector_insert(db):
    """The all-zero vector can be inserted with both the bit and int8 quantizers."""
    zeros = float_vec([0.0] * 8)
    # One (create, insert, count) triple per quantizer: bit first, then int8.
    scenarios = (
        (
            (
                "CREATE VIRTUAL TABLE t USING vec0("
                " embedding float[8] indexed by rescore(quantizer=bit)"
                ")"
            ),
            "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
            "SELECT count(*) as cnt FROM t",
        ),
        (
            (
                "CREATE VIRTUAL TABLE t2 USING vec0("
                " embedding float[8] indexed by rescore(quantizer=int8)"
                ")"
            ),
            "INSERT INTO t2(rowid, embedding) VALUES (1, ?)",
            "SELECT count(*) as cnt FROM t2",
        ),
    )
    for create_sql, insert_sql, count_sql in scenarios:
        db.execute(create_sql)
        db.execute(insert_sql, [zeros])
        stored = db.execute(count_sql).fetchone()
        assert stored["cnt"] == 1
def test_very_large_values(db):
    """Vectors with components of magnitude 1e30 insert cleanly under int8."""
    huge = 1e30
    all_positive = [huge] * 8
    alternating = [huge, -huge] * 4  # +1e30, -1e30, +1e30, ...

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
        [float_vec(all_positive)],
    )
    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
        [float_vec(alternating)],
    )

    stored = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert stored["cnt"] == 2
def test_negative_values(db):
    """All-negative vectors can be stored and ranked with the bit quantizer.

    NOTE(review): presumably every component quantizes to the same bit here,
    so ordering has to come from the full-precision re-rank — confirm against
    the quantizer implementation.
    """
    far = [-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0]
    near = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8]

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(far)])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(near)])

    stored = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert stored["cnt"] == 2

    # Query with rowid 2's own vector: both rows return, rowid 2 ranked first.
    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
        [float_vec(near)],
    ).fetchall()
    assert len(hits) == 2
    assert hits[0]["rowid"] == 2
def test_single_dimension(db):
    """Two constant 8-dimensional vectors under int8: the exact match wins.

    NOTE(review): despite the test name, the table is declared ``float[8]``,
    so a genuinely one-dimensional vector is not exercised here. Either switch
    the column to ``float[1]`` (if rescore accepts it) or rename the test —
    TODO confirm the intended coverage.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    ones = float_vec([1.0] * 8)
    fives = float_vec([5.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [ones])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [fives])

    # Querying with rowid 1's own vector must rank rowid 1 first.
    rows = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [ones],
    ).fetchall()
    assert rows[0]["rowid"] == 1
# ============================================================================
# vec_debug() verification
# ============================================================================
def test_vec_debug_contains_rescore(db):
    """The text returned by vec_debug() must mention 'rescore'."""
    debug_text = db.execute("SELECT vec_debug() as d").fetchone()["d"]
    assert "rescore" in debug_text
# ============================================================================
# Insert batch recall test
# ============================================================================
def test_insert_batch_recall(db):
    """Rescore KNN over 150 rows must share at least 60% of its top 10 with a flat table.

    The same seeded random vectors go into an int8 rescore table
    (oversample=16) and an un-indexed vec0 table; the flat table's top-k is
    the reference result.
    """
    dim = 16
    n = 150
    k = 10
    random.seed(77)

    db.execute(
        "CREATE VIRTUAL TABLE t_rescore USING vec0("
        f" embedding float[{dim}] indexed by rescore(quantizer=int8, oversample=16)"
        ")"
    )
    db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")

    # Generate every vector first so the RNG stream matches row order.
    blobs = [
        float_vec([random.gauss(0, 1) for _ in range(dim)]) for _ in range(n)
    ]
    insert_statements = (
        "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)",
        "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)",
    )
    for rowid, blob in enumerate(blobs, start=1):
        for statement in insert_statements:
            db.execute(statement, [rowid, blob])

    query = float_vec([random.gauss(0, 1) for _ in range(dim)])

    def top_k_rowids(knn_sql):
        # Set of rowids returned by a KNN statement for the shared query.
        return {hit["rowid"] for hit in db.execute(knn_sql, [query, k]).fetchall()}

    rescore_ids = top_k_rowids(
        "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?"
    )
    flat_ids = top_k_rowids(
        "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?"
    )

    recall = len(rescore_ids & flat_ids) / k
    assert recall >= 0.6, f"Recall too low: {recall}"
# ============================================================================
# Delete hardening and distance metric variants
# ============================================================================
def test_delete_interleaved_with_knn(db):
    """Run a KNN query after each of five deletes and check only live rows return."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    total = 30
    random.seed(42)
    vectors = {
        rowid: [random.gauss(0, 1) for _ in range(8)]
        for rowid in range(1, total + 1)
    }
    for rowid, components in vectors.items():
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec(components)],
        )

    alive = set(vectors)
    origin = float_vec([0.0] * 8)
    for to_del in (5, 10, 15, 20, 25):
        db.execute("DELETE FROM t WHERE rowid = ?", [to_del])
        alive -= {to_del}

        rows = db.execute(
            "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
            [origin],
        ).fetchall()
        returned = {hit["rowid"] for hit in rows}

        # A deleted rowid must never show up in the results.
        assert returned <= alive, f"Deleted rowid found in KNN after deleting {to_del}"
        # The result is full unless fewer than 10 rows remain.
        assert len(rows) == min(10, len(alive))
def test_delete_with_rowid_in_constraint(db):
    """KNN with a ``rowid IN (...)`` filter returns exactly the live listed rows.

    Rows 1..10 are inserted, rows 3, 5 and 7 are deleted, and the query is
    restricted to rowids 1..5 with k = 5. Only 1, 2 and 4 remain eligible, so
    all three must come back and nothing else.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    for i in range(1, 11):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [i, float_vec([float(i)] * 8)],
        )

    for deleted in (3, 5, 7):
        db.execute("DELETE FROM t WHERE rowid = ?", [deleted])

    rows = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? AND k = 5 AND rowid IN (1, 2, 3, 4, 5)",
        [float_vec([1.0] * 8)],
    ).fetchall()
    returned = {r["rowid"] for r in rows}

    # Deleted rows must not surface even though the IN list names them.
    assert returned.isdisjoint({3, 5})
    # k exceeds the three eligible rows, so every one must be returned. An
    # equality check (rather than a subset check) also fails on an empty
    # result, which the subset form would silently accept.
    assert returned == {1, 2, 4}
    # No rowid may be reported twice.
    assert len(rows) == len(returned)
def test_delete_all_then_reinsert_batch(db):
    """After emptying the table and inserting a new batch, KNN sees only the new rows.

    The first batch uses rowids 1..20 and is fully deleted; the second batch
    uses rowids 100..109. A LIMIT 5 query must return five rows, all from the
    second batch.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    first_batch = range(1, 21)
    second_batch = range(100, 110)

    for i in first_batch:
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [i, float_vec([float(i)] * 8)],
        )
    for i in first_batch:
        db.execute("DELETE FROM t WHERE rowid = ?", [i])
    assert db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"] == 0

    # New rowids and new vectors: row i stores the constant vector [i - 100] * 8.
    for i in second_batch:
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [i, float_vec([float(i - 100)] * 8)],
        )

    rows = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
        [float_vec([0.0] * 8)],
    ).fetchall()
    returned = {r["rowid"] for r in rows}

    # Ten live rows exist, so LIMIT 5 must be filled; without this check an
    # empty result would satisfy the subset assertion below.
    assert len(rows) == 5
    # Nothing from the deleted first batch may reappear.
    assert returned.issubset(set(second_batch))
def test_knn_int8_cosine(db):
    """With cosine distance and the int8 quantizer, an exact match ranks first."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=int8)"
        ")"
    )
    x_axis = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    y_axis = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    near_x_axis = [1.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    fixtures = (
        ("INSERT INTO t(rowid, embedding) VALUES (1, ?)", x_axis),
        ("INSERT INTO t(rowid, embedding) VALUES (2, ?)", y_axis),
        ("INSERT INTO t(rowid, embedding) VALUES (3, ?)", near_x_axis),
    )
    for insert_sql, components in fixtures:
        db.execute(insert_sql, [float_vec(components)])

    # Query with rowid 1's own vector.
    rows = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
        [float_vec(x_axis)],
    ).fetchall()
    assert rows[0]["rowid"] == 1
    assert rows[0]["distance"] < 0.01