mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-26 09:16:22 +02:00
Add rescore index for ANN queries
Add rescore index type: stores full-precision float vectors in a rowid-keyed shadow table, quantizes to int8 for fast initial scan, then rescores top candidates with original vectors. Includes config parser, shadow table management, insert/delete support, KNN integration, compile flag (SQLITE_VEC_ENABLE_RESCORE), fuzz targets, and tests.
This commit is contained in:
parent
bf2455f2ba
commit
ba0db0b6d6
19 changed files with 3378 additions and 8 deletions
470
tests/test-rescore-mutations.py
Normal file
470
tests/test-rescore-mutations.py
Normal file
|
|
@ -0,0 +1,470 @@
|
|||
"""Mutation and edge-case tests for the rescore index feature."""
|
||||
import struct
|
||||
import sqlite3
|
||||
import pytest
|
||||
import math
|
||||
import random
|
||||
|
||||
|
||||
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows come back as ``sqlite3.Row`` so tests can index columns by name.
    Extension loading is switched off again once ``dist/vec0`` has been
    loaded, and the connection is closed when the requesting test finishes.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        # Teardown: release the connection even if setup or the test failed.
        conn.close()
|
||||
|
||||
|
||||
def float_vec(values):
    """Serialize a sequence of floats into a blob of packed 4-byte floats."""
    count = len(values)
    layout = "%df" % count
    return struct.pack(layout, *values)
|
||||
|
||||
|
||||
def unpack_float_vec(blob):
    """Decode a blob of packed 4-byte floats back into a list of floats."""
    count, _ = divmod(len(blob), 4)
    return [*struct.unpack("%df" % count, blob)]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Error cases: rescore + aux/metadata/partition
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_create_error_with_aux_column(db):
    """A rescore-indexed table that also declares an auxiliary column must fail to create."""
    create_sql = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " +extra text"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError, match="Auxiliary columns"):
        db.execute(create_sql)
|
||||
|
||||
|
||||
def test_create_error_with_metadata_column(db):
    """A rescore-indexed table that also declares a metadata column must fail to create."""
    create_sql = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " genre text"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError, match="Metadata columns"):
        db.execute(create_sql)
|
||||
|
||||
|
||||
def test_create_error_with_partition_key(db):
    """A rescore-indexed table that also declares a partition key must fail to create."""
    create_sql = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " user_id integer partition key"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError, match="Partition key"):
        db.execute(create_sql)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert / batch / delete / update mutations
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_insert_single_verify_knn(db):
    """A lone inserted row is returned as the nearest neighbour of an identical query."""
    ones = float_vec([1.0] * 8)

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [ones])

    cursor = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [ones],
    )
    hits = cursor.fetchall()

    assert len(hits) == 1
    nearest = hits[0]
    assert nearest["rowid"] == 1
    # Query equals the stored vector, so the reported distance should be ~0.
    assert nearest["distance"] < 0.01
|
||||
|
||||
|
||||
def test_insert_large_batch(db):
    """Insert 200 rows one at a time, then verify the row count and a 10-NN query.

    NOTE(review): 200 rows is below the 1024-row default chunk size this test
    was originally described against, so it presumably exercises a single
    chunk only -- confirm, and raise ``n`` if multi-chunk coverage is wanted.
    """
    dim = 16
    n = 200
    random.seed(99)  # fixed seed keeps the generated vectors reproducible
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        f" embedding float[{dim}] indexed by rescore(quantizer=int8)"
        ")"
    )
    for i in range(n):
        v = [random.gauss(0, 1) for _ in range(dim)]
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [i + 1, float_vec(v)],
        )
    row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert row["cnt"] == n

    # KNN should return a full page of results for a random query.
    query = float_vec([random.gauss(0, 1) for _ in range(dim)])
    rows = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [query],
    ).fetchall()
    assert len(rows) == 10
|
||||
|
||||
|
||||
def test_delete_all_rows(db):
    """Remove every row one by one; the table must end up empty and KNN must return nothing."""

    def row_count():
        return db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"]

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )

    # Rowids run 1..20; each vector repeats (rowid - 1) across all 8 slots.
    for rowid in range(1, 21):
        payload = float_vec([float(rowid - 1)] * 8)
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, payload],
        )
    assert row_count() == 20

    for rowid in range(1, 21):
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])
    assert row_count() == 0

    leftovers = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
        [float_vec([0.0] * 8)],
    ).fetchall()
    assert len(leftovers) == 0
|
||||
|
||||
|
||||
def test_delete_then_reinsert_same_rowid(db):
    """Re-using a deleted rowid with a new vector must make KNN rank by the new vector."""

    def closest_to_origin():
        # Single nearest neighbour of the all-zero query vector.
        return db.execute(
            "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
            [float_vec([0.0] * 8)],
        ).fetchall()

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )

    # rowid 1 starts near the origin, rowid 2 far away from it.
    near_origin = float_vec([0.1] * 8)
    far_away = float_vec([100.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [near_origin])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [far_away])

    # Before the mutation rowid 1 is the closest to the origin.
    assert closest_to_origin()[0]["rowid"] == 1

    # Drop rowid 1 and bring it back even farther out than rowid 2.
    db.execute("DELETE FROM t WHERE rowid = 1")
    farther_away = float_vec([200.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [farther_away])

    # The re-inserted vector must be the one ranked, so rowid 2 now wins.
    assert closest_to_origin()[0]["rowid"] == 2
|
||||
|
||||
|
||||
def test_update_vector(db):
    """An UPDATE of the vector column must be reflected in subsequent KNN results."""
    origin = float_vec([0.0] * 8)

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [origin])
    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
        [float_vec([10.0] * 8)],
    )

    # Move rowid 1 from the origin to a point farther out than rowid 2.
    relocated = float_vec([100.0] * 8)
    db.execute("UPDATE t SET embedding = ? WHERE rowid = 1", [relocated])

    # Querying the origin must now prefer rowid 2.
    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [origin],
    ).fetchall()
    assert hits[0]["rowid"] == 2
|
||||
|
||||
|
||||
def test_knn_after_delete_all_but_one(db):
    """After deleting 49 of 50 rows, KNN must return only the surviving row."""
    survivor = 25

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    for rowid in range(1, 51):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec([float(rowid - 1)] * 8)],
        )

    # Remove everything except the designated survivor.
    for rowid in range(1, 51):
        if rowid == survivor:
            continue
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])

    remaining = db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"]
    assert remaining == 1

    # Asking for 10 neighbours can only ever yield the one row that is left.
    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [float_vec([0.0] * 8)],
    ).fetchall()
    assert len(hits) == 1
    assert hits[0]["rowid"] == 25
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge cases
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_single_row_knn(db):
    """With exactly one row stored, LIMIT 1 and LIMIT 5 both return that row only."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])

    # The limit is inlined as a literal so each query text matches the
    # hand-written "LIMIT 1" / "LIMIT 5" statements exactly.
    for limit in (1, 5):
        rows = db.execute(
            f"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT {limit}",
            [float_vec([1.0] * 8)],
        ).fetchall()
        assert len(rows) == 1
        # Check identity too, not just the count: the only stored row is rowid 1.
        assert rows[0]["rowid"] == 1
|
||||
|
||||
|
||||
def test_knn_with_all_identical_vectors(db):
    """Ten copies of one vector, queried with that vector, all come back at ~0 distance."""
    shared = float_vec([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    for rowid in range(1, 11):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, shared],
        )

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [shared],
    ).fetchall()
    assert len(hits) == 10

    # Every stored vector equals the query, so each distance should be ~0.
    for hit in hits:
        assert hit["distance"] < 0.01
|
||||
|
||||
|
||||
def test_zero_vector_insert(db):
    """Inserting the all-zero vector must succeed for both the bit and int8 quantizers."""
    zeros = float_vec([0.0] * 8)

    # Table ``t`` uses the bit quantizer, ``t2`` the int8 quantizer.
    for table, quantizer in (("t", "bit"), ("t2", "int8")):
        db.execute(
            f"CREATE VIRTUAL TABLE {table} USING vec0("
            f" embedding float[8] indexed by rescore(quantizer={quantizer})"
            ")"
        )
        db.execute(
            f"INSERT INTO {table}(rowid, embedding) VALUES (1, ?)",
            [zeros],
        )
        counted = db.execute(f"SELECT count(*) as cnt FROM {table}").fetchone()
        assert counted["cnt"] == 1
|
||||
|
||||
|
||||
def test_very_large_values(db):
    """Vectors with magnitudes around 1e30 must insert without breaking quantization."""
    insert_sql = "INSERT INTO t(rowid, embedding) VALUES ({}, ?)"
    all_huge = [1e30] * 8
    alternating = [1e30, -1e30] * 4  # +1e30 / -1e30 repeated across 8 slots

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute(insert_sql.format(1), [float_vec(all_huge)])
    db.execute(insert_sql.format(2), [float_vec(alternating)])

    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == 2
|
||||
|
||||
|
||||
def test_negative_values(db):
    """All-negative vectors must insert and still rank correctly under KNN.

    NOTE(review): the bit quantizer presumably maps every negative component
    to the same bit, leaving the final ordering to the rescoring step --
    confirm against the quantizer implementation.
    """
    large_negatives = float_vec([-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0])
    small_negatives = float_vec([-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8])

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [large_negatives])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [small_negatives])

    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == 2

    # The query equals rowid 2's vector, so rowid 2 must be ranked first.
    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
        [small_negatives],
    ).fetchall()
    assert len(hits) == 2
    assert hits[0]["rowid"] == 2
|
||||
|
||||
|
||||
def test_single_dimension(db):
    """Constant-valued 8-dim vectors under the int8 quantizer: the exact match ranks first.

    NOTE(review): despite the test name, the table is declared ``float[8]``,
    so a true one-dimensional vector is not exercised here -- confirm whether
    ``float[1]`` was intended.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([5.0] * 8)])
    rows = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [float_vec([1.0] * 8)],
    ).fetchall()
    # Guard first so an empty result fails as an assertion, not an IndexError.
    assert len(rows) == 1
    assert rows[0]["rowid"] == 1
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# vec_debug() verification
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_vec_debug_contains_rescore(db):
    """The text returned by vec_debug() must mention 'rescore' (expected when built with SQLITE_VEC_ENABLE_RESCORE)."""
    debug_text = db.execute("SELECT vec_debug() as d").fetchone()["d"]
    assert "rescore" in debug_text
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert batch recall test
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_insert_batch_recall(db):
    """Load 150 rows into a rescore table and a plain table; top-10 recall must be at least 0.6."""
    dim = 16
    n = 150
    k = 10
    random.seed(77)

    def top_ids(table):
        # Rowids of the k nearest neighbours of ``query`` in the given table.
        found = db.execute(
            f"SELECT rowid FROM {table} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
            [query, k],
        ).fetchall()
        return {hit["rowid"] for hit in found}

    db.execute(
        "CREATE VIRTUAL TABLE t_rescore USING vec0("
        f" embedding float[{dim}] indexed by rescore(quantizer=int8, oversample=16)"
        ")"
    )
    db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")

    # Both tables receive identical data so the plain table serves as the baseline.
    vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
    for rowid, vector in enumerate(vectors, start=1):
        payload = float_vec(vector)
        db.execute(
            "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [rowid, payload]
        )
        db.execute(
            "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [rowid, payload]
        )

    query = float_vec([random.gauss(0, 1) for _ in range(dim)])

    rescore_ids = top_ids("t_rescore")
    flat_ids = top_ids("t_flat")

    # Recall: fraction of the baseline's top-k that the rescore index also returned.
    recall = len(rescore_ids & flat_ids) / k
    assert recall >= 0.6, f"Recall too low: {recall}"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Distance metric variants
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_knn_int8_cosine(db):
    """With quantizer=int8 and distance_metric=cosine, an identical vector ranks first at ~0 distance."""
    unit_x = [1.0] + [0.0] * 7
    unit_y = [0.0, 1.0] + [0.0] * 6
    near_x = [1.0, 0.1] + [0.0] * 6

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(unit_x)])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(unit_y)])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (3, ?)", [float_vec(near_x)])

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
        [float_vec(unit_x)],
    ).fetchall()

    # The query equals rowid 1's vector, so it must come first with ~0 distance.
    best = hits[0]
    assert best["rowid"] == 1
    assert best["distance"] < 0.01
|
||||
Loading…
Add table
Add a link
Reference in a new issue