sqlite-vec/tests/test-rescore.py
Alex Garcia 82f4eb08bf Add NULL checks after sqlite3_column_blob in rescore and DiskANN
sqlite3_column_blob() returns NULL for zero-length blobs or on OOM.
Several call sites in rescore KNN and DiskANN node/vector read passed
the result directly to memcpy without checking, risking NULL deref on
corrupt or empty databases. IVF already had proper NULL checks.

Adds corruption regression tests that truncate shadow table blobs and
verify the query errors cleanly instead of crashing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 14:31:49 -07:00

597 lines
19 KiB
Python

"""Tests for the rescore index feature in sqlite-vec."""
import struct
import sqlite3
import pytest
import math
import random
@pytest.fixture()
def db():
db = sqlite3.connect(":memory:")
db.row_factory = sqlite3.Row
db.enable_load_extension(True)
db.load_extension("dist/vec0")
db.enable_load_extension(False)
return db
def float_vec(values):
"""Pack a list of floats into a blob for sqlite-vec."""
return struct.pack(f"{len(values)}f", *values)
def unpack_float_vec(blob):
"""Unpack a float vector blob."""
n = len(blob) // 4
return list(struct.unpack(f"{n}f", blob))
# ============================================================================
# Creation tests
# ============================================================================
def test_create_bit(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit)"
")"
)
# Table exists and has the right structure
row = db.execute(
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
).fetchone()
assert row["cnt"] > 0
def test_create_int8(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=int8)"
")"
)
row = db.execute(
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
).fetchone()
assert row["cnt"] > 0
def test_create_with_oversample(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit, oversample=16)"
")"
)
row = db.execute(
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
).fetchone()
assert row["cnt"] > 0
def test_create_with_distance_metric(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] distance_metric=cosine indexed by rescore(quantizer=bit)"
")"
)
row = db.execute(
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
).fetchone()
assert row["cnt"] > 0
def test_create_error_missing_quantizer(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(oversample=8)"
")"
)
def test_create_error_invalid_quantizer(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=float)"
")"
)
def test_create_error_on_bit_column(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding bit[1024] indexed by rescore(quantizer=bit)"
")"
)
def test_create_error_on_int8_column(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding int8[128] indexed by rescore(quantizer=bit)"
")"
)
def test_create_error_bad_oversample_zero(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit, oversample=0)"
")"
)
def test_create_error_bad_oversample_too_large(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit, oversample=999)"
")"
)
def test_create_error_bit_dim_not_divisible_by_8(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[100] indexed by rescore(quantizer=bit)"
")"
)
# ============================================================================
# Shadow table tests
# ============================================================================
def test_shadow_tables_exist(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit)"
")"
)
tables = [
r[0]
for r in db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name"
).fetchall()
]
assert "t_rescore_chunks00" in tables
assert "t_rescore_vectors00" in tables
# Rescore columns don't create _vector_chunks
assert "t_vector_chunks00" not in tables
def test_drop_cleans_up(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit)"
")"
)
db.execute("DROP TABLE t")
tables = [
r[0]
for r in db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%'"
).fetchall()
]
assert len(tables) == 0
# ============================================================================
# Insert tests
# ============================================================================
def test_insert_single(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
assert row["cnt"] == 1
def test_insert_multiple(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=int8)"
")"
)
for i in range(10):
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec([float(i)] * 8)],
)
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
assert row["cnt"] == 10
# ============================================================================
# Delete tests
# ============================================================================
def test_delete_single(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
db.execute("DELETE FROM t WHERE rowid = 1")
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
assert row["cnt"] == 0
def test_delete_and_reinsert(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
db.execute("DELETE FROM t WHERE rowid = 1")
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)]
)
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
assert row["cnt"] == 1
def test_point_query_returns_float(db):
"""SELECT by rowid should return the original float vector, not quantized."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(vals)])
row = db.execute("SELECT embedding FROM t WHERE rowid = 1").fetchone()
result = unpack_float_vec(row["embedding"])
for a, b in zip(result, vals):
assert abs(a - b) < 1e-6
# ============================================================================
# KNN tests
# ============================================================================
def test_knn_basic_bit(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
# Insert vectors where [1,0,0,...] is closest to query [1,0,0,...]
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (3, ?)",
[float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
).fetchall()
assert len(rows) == 1
assert rows[0]["rowid"] == 1
def test_knn_basic_int8(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=int8)"
")"
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (3, ?)",
[float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
).fetchall()
assert len(rows) == 1
assert rows[0]["rowid"] == 1
def test_knn_returns_float_distances(db):
"""KNN should return float-precision distances, not quantized distances."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
v1 = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
v2 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(v1)])
db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(v2)])
query = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
[float_vec(query)],
).fetchall()
# First result should be exact match with distance ~0
assert rows[0]["rowid"] == 1
assert rows[0]["distance"] < 0.01
# Second result should have a float distance
# sqrt((1-0)^2 + (0.5-0)^2 + (0-1)^2) = sqrt(2.25) = 1.5
assert abs(rows[1]["distance"] - 1.5) < 0.01
def test_knn_recall(db):
"""With enough vectors, rescore should achieve good recall (>0.9)."""
dim = 32
n = 1000
k = 10
random.seed(42)
db.execute(
"CREATE VIRTUAL TABLE t_rescore USING vec0("
f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample=16)"
")"
)
db.execute(
f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])"
)
vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
for i, v in enumerate(vectors):
blob = float_vec(v)
db.execute(
"INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob]
)
db.execute(
"INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob]
)
query = float_vec([random.gauss(0, 1) for _ in range(dim)])
rescore_rows = db.execute(
"SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
[query, k],
).fetchall()
flat_rows = db.execute(
"SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
[query, k],
).fetchall()
rescore_ids = {r["rowid"] for r in rescore_rows}
flat_ids = {r["rowid"] for r in flat_rows}
recall = len(rescore_ids & flat_ids) / k
assert recall >= 0.7, f"Recall too low: {recall}"
def test_knn_cosine(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] distance_metric=cosine indexed by rescore(quantizer=bit)"
")"
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
).fetchall()
assert rows[0]["rowid"] == 1
# cosine distance of identical vectors should be ~0
assert rows[0]["distance"] < 0.01
def test_knn_empty_table(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
rows = db.execute(
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
[float_vec([1.0] * 8)],
).fetchall()
assert len(rows) == 0
def test_knn_k_larger_than_n(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)])
rows = db.execute(
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
[float_vec([1.0] * 8)],
).fetchall()
assert len(rows) == 2
# ============================================================================
# Integration / edge case tests
# ============================================================================
def test_knn_with_rowid_in(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
for i in range(5):
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec([float(i)] * 8)],
)
# Only search within rowids 1, 3, 5
rows = db.execute(
"SELECT rowid FROM t WHERE embedding MATCH ? AND rowid IN (1, 3, 5) ORDER BY distance LIMIT 3",
[float_vec([0.0] * 8)],
).fetchall()
result_ids = {r["rowid"] for r in rows}
assert result_ids <= {1, 3, 5}
def test_knn_after_deletes(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=int8)"
")"
)
for i in range(10):
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec([float(i)] * 8)],
)
# Delete the closest match (rowid 1 = [0,0,...])
db.execute("DELETE FROM t WHERE rowid = 1")
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
[float_vec([0.0] * 8)],
).fetchall()
# Verify ordering: rowid 2 ([1]*8) should be closest, then 3 ([2]*8), etc.
assert len(rows) >= 2
assert rows[0]["distance"] <= rows[1]["distance"]
# rowid 2 = [1,1,...] → L2 = sqrt(8) ≈ 2.83, rowid 3 = [2,2,...] → L2 = sqrt(32) ≈ 5.66
assert rows[0]["rowid"] == 2, f"Expected rowid 2, got {rows[0]['rowid']} with dist={rows[0]['distance']}"
def test_oversample_effect(db):
"""Higher oversample should give equal or better recall."""
dim = 32
n = 500
k = 10
random.seed(123)
vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
query = float_vec([random.gauss(0, 1) for _ in range(dim)])
recalls = []
for oversample in [2, 16]:
tname = f"t_os{oversample}"
db.execute(
f"CREATE VIRTUAL TABLE {tname} USING vec0("
f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample={oversample})"
")"
)
for i, v in enumerate(vectors):
db.execute(
f"INSERT INTO {tname}(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec(v)],
)
rows = db.execute(
f"SELECT rowid FROM {tname} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
[query, k],
).fetchall()
recalls.append({r["rowid"] for r in rows})
# Also get ground truth
db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")
for i, v in enumerate(vectors):
db.execute(
"INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec(v)],
)
gt_rows = db.execute(
"SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
[query, k],
).fetchall()
gt_ids = {r["rowid"] for r in gt_rows}
recall_low = len(recalls[0] & gt_ids) / k
recall_high = len(recalls[1] & gt_ids) / k
assert recall_high >= recall_low
def test_multiple_vector_columns(db):
"""One column with rescore, one without."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" v1 float[8] indexed by rescore(quantizer=bit),"
" v2 float[8]"
")"
)
db.execute(
"INSERT INTO t(rowid, v1, v2) VALUES (1, ?, ?)",
[float_vec([1.0] * 8), float_vec([0.0] * 8)],
)
db.execute(
"INSERT INTO t(rowid, v1, v2) VALUES (2, ?, ?)",
[float_vec([0.0] * 8), float_vec([1.0] * 8)],
)
# KNN on v1 (rescore path)
rows = db.execute(
"SELECT rowid FROM t WHERE v1 MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0] * 8)],
).fetchall()
assert rows[0]["rowid"] == 1
# KNN on v2 (normal path)
rows = db.execute(
"SELECT rowid FROM t WHERE v2 MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0] * 8)],
).fetchall()
assert rows[0]["rowid"] == 2
def test_corrupt_zeroblob_validity(db):
"""KNN should error (not crash) when rescore chunk rowids blob is zeroed out."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
[float_vec([1, 0, 0, 0, 0, 0, 0, 0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
[float_vec([0, 1, 0, 0, 0, 0, 0, 0])],
)
# Corrupt: replace rowids with a truncated blob (wrong size)
db.execute("UPDATE t_chunks SET rowids = x'00'")
# Should not crash — may return wrong results or error
try:
rows = db.execute(
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1, 0, 0, 0, 0, 0, 0, 0])],
).fetchall()
except sqlite3.OperationalError:
pass # Error is acceptable — crash is not