sqlite-vec/tests/test-rescore.py
Alex Garcia ba0db0b6d6 Add rescore index for ANN queries
Add rescore index type: stores full-precision float vectors in a rowid-keyed
shadow table, quantizes to int8 for fast initial scan, then rescores top
candidates with original vectors. Includes config parser, shadow table
management, insert/delete support, KNN integration, compile flag
(SQLITE_VEC_ENABLE_RESCORE), fuzz targets, and tests.
2026-03-29 19:45:54 -07:00

568 lines
18 KiB
Python

"""Tests for the rescore index feature in sqlite-vec."""
import struct
import sqlite3
import pytest
import math
import random
@pytest.fixture()
def db():
db = sqlite3.connect(":memory:")
db.row_factory = sqlite3.Row
db.enable_load_extension(True)
db.load_extension("dist/vec0")
db.enable_load_extension(False)
return db
def float_vec(values):
"""Pack a list of floats into a blob for sqlite-vec."""
return struct.pack(f"{len(values)}f", *values)
def unpack_float_vec(blob):
"""Unpack a float vector blob."""
n = len(blob) // 4
return list(struct.unpack(f"{n}f", blob))
# ============================================================================
# Creation tests
# ============================================================================
def test_create_bit(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit)"
")"
)
# Table exists and has the right structure
row = db.execute(
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
).fetchone()
assert row["cnt"] > 0
def test_create_int8(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=int8)"
")"
)
row = db.execute(
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
).fetchone()
assert row["cnt"] > 0
def test_create_with_oversample(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit, oversample=16)"
")"
)
row = db.execute(
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
).fetchone()
assert row["cnt"] > 0
def test_create_with_distance_metric(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] distance_metric=cosine indexed by rescore(quantizer=bit)"
")"
)
row = db.execute(
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
).fetchone()
assert row["cnt"] > 0
def test_create_error_missing_quantizer(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(oversample=8)"
")"
)
def test_create_error_invalid_quantizer(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=float)"
")"
)
def test_create_error_on_bit_column(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding bit[1024] indexed by rescore(quantizer=bit)"
")"
)
def test_create_error_on_int8_column(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding int8[128] indexed by rescore(quantizer=bit)"
")"
)
def test_create_error_bad_oversample_zero(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit, oversample=0)"
")"
)
def test_create_error_bad_oversample_too_large(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit, oversample=999)"
")"
)
def test_create_error_bit_dim_not_divisible_by_8(db):
with pytest.raises(sqlite3.OperationalError):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[100] indexed by rescore(quantizer=bit)"
")"
)
# ============================================================================
# Shadow table tests
# ============================================================================
def test_shadow_tables_exist(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit)"
")"
)
tables = [
r[0]
for r in db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name"
).fetchall()
]
assert "t_rescore_chunks00" in tables
assert "t_rescore_vectors00" in tables
# Rescore columns don't create _vector_chunks
assert "t_vector_chunks00" not in tables
def test_drop_cleans_up(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[128] indexed by rescore(quantizer=bit)"
")"
)
db.execute("DROP TABLE t")
tables = [
r[0]
for r in db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%'"
).fetchall()
]
assert len(tables) == 0
# ============================================================================
# Insert tests
# ============================================================================
def test_insert_single(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
assert row["cnt"] == 1
def test_insert_multiple(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=int8)"
")"
)
for i in range(10):
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec([float(i)] * 8)],
)
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
assert row["cnt"] == 10
# ============================================================================
# Delete tests
# ============================================================================
def test_delete_single(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
db.execute("DELETE FROM t WHERE rowid = 1")
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
assert row["cnt"] == 0
def test_delete_and_reinsert(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
db.execute("DELETE FROM t WHERE rowid = 1")
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)]
)
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
assert row["cnt"] == 1
def test_point_query_returns_float(db):
"""SELECT by rowid should return the original float vector, not quantized."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(vals)])
row = db.execute("SELECT embedding FROM t WHERE rowid = 1").fetchone()
result = unpack_float_vec(row["embedding"])
for a, b in zip(result, vals):
assert abs(a - b) < 1e-6
# ============================================================================
# KNN tests
# ============================================================================
def test_knn_basic_bit(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
# Insert vectors where [1,0,0,...] is closest to query [1,0,0,...]
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (3, ?)",
[float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
).fetchall()
assert len(rows) == 1
assert rows[0]["rowid"] == 1
def test_knn_basic_int8(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=int8)"
")"
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (3, ?)",
[float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
).fetchall()
assert len(rows) == 1
assert rows[0]["rowid"] == 1
def test_knn_returns_float_distances(db):
"""KNN should return float-precision distances, not quantized distances."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
v1 = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
v2 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(v1)])
db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(v2)])
query = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
[float_vec(query)],
).fetchall()
# First result should be exact match with distance ~0
assert rows[0]["rowid"] == 1
assert rows[0]["distance"] < 0.01
# Second result should have a float distance
# sqrt((1-0)^2 + (0.5-0)^2 + (0-1)^2) = sqrt(2.25) = 1.5
assert abs(rows[1]["distance"] - 1.5) < 0.01
def test_knn_recall(db):
"""With enough vectors, rescore should achieve good recall (>0.9)."""
dim = 32
n = 1000
k = 10
random.seed(42)
db.execute(
"CREATE VIRTUAL TABLE t_rescore USING vec0("
f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample=16)"
")"
)
db.execute(
f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])"
)
vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
for i, v in enumerate(vectors):
blob = float_vec(v)
db.execute(
"INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob]
)
db.execute(
"INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob]
)
query = float_vec([random.gauss(0, 1) for _ in range(dim)])
rescore_rows = db.execute(
"SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
[query, k],
).fetchall()
flat_rows = db.execute(
"SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
[query, k],
).fetchall()
rescore_ids = {r["rowid"] for r in rescore_rows}
flat_ids = {r["rowid"] for r in flat_rows}
recall = len(rescore_ids & flat_ids) / k
assert recall >= 0.7, f"Recall too low: {recall}"
def test_knn_cosine(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] distance_metric=cosine indexed by rescore(quantizer=bit)"
")"
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
)
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
).fetchall()
assert rows[0]["rowid"] == 1
# cosine distance of identical vectors should be ~0
assert rows[0]["distance"] < 0.01
def test_knn_empty_table(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
rows = db.execute(
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
[float_vec([1.0] * 8)],
).fetchall()
assert len(rows) == 0
def test_knn_k_larger_than_n(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)])
rows = db.execute(
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
[float_vec([1.0] * 8)],
).fetchall()
assert len(rows) == 2
# ============================================================================
# Integration / edge case tests
# ============================================================================
def test_knn_with_rowid_in(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=bit)"
")"
)
for i in range(5):
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec([float(i)] * 8)],
)
# Only search within rowids 1, 3, 5
rows = db.execute(
"SELECT rowid FROM t WHERE embedding MATCH ? AND rowid IN (1, 3, 5) ORDER BY distance LIMIT 3",
[float_vec([0.0] * 8)],
).fetchall()
result_ids = {r["rowid"] for r in rows}
assert result_ids <= {1, 3, 5}
def test_knn_after_deletes(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" embedding float[8] indexed by rescore(quantizer=int8)"
")"
)
for i in range(10):
db.execute(
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec([float(i)] * 8)],
)
# Delete the closest match (rowid 1 = [0,0,...])
db.execute("DELETE FROM t WHERE rowid = 1")
rows = db.execute(
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
[float_vec([0.0] * 8)],
).fetchall()
# Verify ordering: rowid 2 ([1]*8) should be closest, then 3 ([2]*8), etc.
assert len(rows) >= 2
assert rows[0]["distance"] <= rows[1]["distance"]
# rowid 2 = [1,1,...] → L2 = sqrt(8) ≈ 2.83, rowid 3 = [2,2,...] → L2 = sqrt(32) ≈ 5.66
assert rows[0]["rowid"] == 2, f"Expected rowid 2, got {rows[0]['rowid']} with dist={rows[0]['distance']}"
def test_oversample_effect(db):
"""Higher oversample should give equal or better recall."""
dim = 32
n = 500
k = 10
random.seed(123)
vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
query = float_vec([random.gauss(0, 1) for _ in range(dim)])
recalls = []
for oversample in [2, 16]:
tname = f"t_os{oversample}"
db.execute(
f"CREATE VIRTUAL TABLE {tname} USING vec0("
f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample={oversample})"
")"
)
for i, v in enumerate(vectors):
db.execute(
f"INSERT INTO {tname}(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec(v)],
)
rows = db.execute(
f"SELECT rowid FROM {tname} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
[query, k],
).fetchall()
recalls.append({r["rowid"] for r in rows})
# Also get ground truth
db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")
for i, v in enumerate(vectors):
db.execute(
"INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)",
[i + 1, float_vec(v)],
)
gt_rows = db.execute(
"SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
[query, k],
).fetchall()
gt_ids = {r["rowid"] for r in gt_rows}
recall_low = len(recalls[0] & gt_ids) / k
recall_high = len(recalls[1] & gt_ids) / k
assert recall_high >= recall_low
def test_multiple_vector_columns(db):
"""One column with rescore, one without."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0("
" v1 float[8] indexed by rescore(quantizer=bit),"
" v2 float[8]"
")"
)
db.execute(
"INSERT INTO t(rowid, v1, v2) VALUES (1, ?, ?)",
[float_vec([1.0] * 8), float_vec([0.0] * 8)],
)
db.execute(
"INSERT INTO t(rowid, v1, v2) VALUES (2, ?, ?)",
[float_vec([0.0] * 8), float_vec([1.0] * 8)],
)
# KNN on v1 (rescore path)
rows = db.execute(
"SELECT rowid FROM t WHERE v1 MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0] * 8)],
).fetchall()
assert rows[0]["rowid"] == 1
# KNN on v2 (normal path)
rows = db.execute(
"SELECT rowid FROM t WHERE v2 MATCH ? ORDER BY distance LIMIT 1",
[float_vec([1.0] * 8)],
).fetchall()
assert rows[0]["rowid"] == 2