"""Tests for the rescore index feature in sqlite-vec.""" import struct import sqlite3 import pytest import math import random @pytest.fixture() def db(): db = sqlite3.connect(":memory:") db.row_factory = sqlite3.Row db.enable_load_extension(True) db.load_extension("dist/vec0") db.enable_load_extension(False) return db def float_vec(values): """Pack a list of floats into a blob for sqlite-vec.""" return struct.pack(f"{len(values)}f", *values) def unpack_float_vec(blob): """Unpack a float vector blob.""" n = len(blob) // 4 return list(struct.unpack(f"{n}f", blob)) # ============================================================================ # Creation tests # ============================================================================ def test_create_bit(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit)" ")" ) # Table exists and has the right structure row = db.execute( "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" ).fetchone() assert row["cnt"] > 0 def test_create_int8(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=int8)" ")" ) row = db.execute( "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" ).fetchone() assert row["cnt"] > 0 def test_create_with_oversample(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit, oversample=16)" ")" ) row = db.execute( "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" ).fetchone() assert row["cnt"] > 0 def test_create_with_distance_metric(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] distance_metric=cosine indexed by rescore(quantizer=bit)" ")" ) row = db.execute( "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" ).fetchone() assert row["cnt"] > 0 def test_create_error_missing_quantizer(db): with pytest.raises(sqlite3.OperationalError): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(oversample=8)" ")" ) def test_create_error_invalid_quantizer(db): with pytest.raises(sqlite3.OperationalError): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=float)" ")" ) def test_create_error_on_bit_column(db): with pytest.raises(sqlite3.OperationalError): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding bit[1024] indexed by rescore(quantizer=bit)" ")" ) def test_create_error_on_int8_column(db): with pytest.raises(sqlite3.OperationalError): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding int8[128] indexed by rescore(quantizer=bit)" ")" ) def test_create_error_bad_oversample_zero(db): with pytest.raises(sqlite3.OperationalError): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit, oversample=0)" ")" ) def test_create_error_bad_oversample_too_large(db): with pytest.raises(sqlite3.OperationalError): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit, oversample=999)" ")" ) def test_create_error_bit_dim_not_divisible_by_8(db): with pytest.raises(sqlite3.OperationalError): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[100] indexed by rescore(quantizer=bit)" ")" ) # ============================================================================ # Shadow table tests # ============================================================================ def test_shadow_tables_exist(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit)" ")" ) tables = [ r[0] for r in db.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name" ).fetchall() ] assert "t_rescore_chunks00" in tables assert "t_rescore_vectors00" in tables # Rescore columns don't create _vector_chunks assert "t_vector_chunks00" not in tables def test_drop_cleans_up(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit)" ")" ) db.execute("DROP TABLE t") tables = [ r[0] for r in db.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%'" ).fetchall() ] assert len(tables) == 0 # ============================================================================ # Insert tests # ============================================================================ def test_insert_single(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) row = db.execute("SELECT count(*) as cnt FROM t").fetchone() assert row["cnt"] == 1 def test_insert_multiple(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=int8)" ")" ) for i in range(10): db.execute( "INSERT INTO t(rowid, embedding) VALUES (?, ?)", [i + 1, float_vec([float(i)] * 8)], ) row = db.execute("SELECT count(*) as cnt FROM t").fetchone() assert row["cnt"] == 10 # ============================================================================ # Delete tests # ============================================================================ def test_delete_single(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) db.execute("DELETE FROM t WHERE rowid = 1") row = db.execute("SELECT count(*) as cnt FROM t").fetchone() assert row["cnt"] == 0 def test_delete_and_reinsert(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) db.execute("DELETE FROM t WHERE rowid = 1") db.execute( "INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)] ) row = db.execute("SELECT count(*) as cnt FROM t").fetchone() assert row["cnt"] == 1 def test_point_query_returns_float(db): """SELECT by rowid should return the original float vector, not quantized.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(vals)]) row = db.execute("SELECT embedding FROM t WHERE rowid = 1").fetchone() result = unpack_float_vec(row["embedding"]) for a, b in zip(result, vals): assert abs(a - b) < 1e-6 # ============================================================================ # KNN tests # ============================================================================ def test_knn_basic_bit(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) # Insert vectors where [1,0,0,...] is closest to query [1,0,0,...] db.execute( "INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (3, ?)", [float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ) rows = db.execute( "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ).fetchall() assert len(rows) == 1 assert rows[0]["rowid"] == 1 def test_knn_basic_int8(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=int8)" ")" ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (3, ?)", [float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ) rows = db.execute( "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ).fetchall() assert len(rows) == 1 assert rows[0]["rowid"] == 1 def test_knn_returns_float_distances(db): """KNN should return float-precision distances, not quantized distances.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) v1 = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] v2 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(v1)]) db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(v2)]) query = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] rows = db.execute( "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2", [float_vec(query)], ).fetchall() # First result should be exact match with distance ~0 assert rows[0]["rowid"] == 1 assert rows[0]["distance"] < 0.01 # Second result should have a float distance # sqrt((1-0)^2 + (0.5-0)^2 + (0-1)^2) = sqrt(2.25) = 1.5 assert abs(rows[1]["distance"] - 1.5) < 0.01 def test_knn_recall(db): """With enough vectors, rescore should achieve good recall (>0.9).""" dim = 32 n = 1000 k = 10 random.seed(42) db.execute( "CREATE VIRTUAL TABLE t_rescore USING vec0(" f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample=16)" ")" ) db.execute( f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])" ) vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)] for i, v in enumerate(vectors): blob = float_vec(v) db.execute( "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob] ) db.execute( "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob] ) query = float_vec([random.gauss(0, 1) for _ in range(dim)]) rescore_rows = db.execute( "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?", [query, k], ).fetchall() flat_rows = db.execute( "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?", [query, k], ).fetchall() rescore_ids = {r["rowid"] for r in rescore_rows} flat_ids = {r["rowid"] for r in flat_rows} recall = len(rescore_ids & flat_ids) / k assert recall >= 0.7, f"Recall too low: {recall}" def test_knn_cosine(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=bit)" ")" ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ) rows = db.execute( "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], ).fetchall() assert rows[0]["rowid"] == 1 # cosine distance of identical vectors should be ~0 assert rows[0]["distance"] < 0.01 def test_knn_empty_table(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) rows = db.execute( "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", [float_vec([1.0] * 8)], ).fetchall() assert len(rows) == 0 def test_knn_k_larger_than_n(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)]) rows = db.execute( "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", [float_vec([1.0] * 8)], ).fetchall() assert len(rows) == 2 # ============================================================================ # Integration / edge case tests # ============================================================================ def test_knn_with_rowid_in(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) for i in range(5): db.execute( "INSERT INTO t(rowid, embedding) VALUES (?, ?)", [i + 1, float_vec([float(i)] * 8)], ) # Only search within rowids 1, 3, 5 rows = db.execute( "SELECT rowid FROM t WHERE embedding MATCH ? AND rowid IN (1, 3, 5) ORDER BY distance LIMIT 3", [float_vec([0.0] * 8)], ).fetchall() result_ids = {r["rowid"] for r in rows} assert result_ids <= {1, 3, 5} def test_knn_after_deletes(db): db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=int8)" ")" ) for i in range(10): db.execute( "INSERT INTO t(rowid, embedding) VALUES (?, ?)", [i + 1, float_vec([float(i)] * 8)], ) # Delete the closest match (rowid 1 = [0,0,...]) db.execute("DELETE FROM t WHERE rowid = 1") rows = db.execute( "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", [float_vec([0.0] * 8)], ).fetchall() # Verify ordering: rowid 2 ([1]*8) should be closest, then 3 ([2]*8), etc. assert len(rows) >= 2 assert rows[0]["distance"] <= rows[1]["distance"] # rowid 2 = [1,1,...] → L2 = sqrt(8) ≈ 2.83, rowid 3 = [2,2,...] → L2 = sqrt(32) ≈ 5.66 assert rows[0]["rowid"] == 2, f"Expected rowid 2, got {rows[0]['rowid']} with dist={rows[0]['distance']}" def test_oversample_effect(db): """Higher oversample should give equal or better recall.""" dim = 32 n = 500 k = 10 random.seed(123) vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)] query = float_vec([random.gauss(0, 1) for _ in range(dim)]) recalls = [] for oversample in [2, 16]: tname = f"t_os{oversample}" db.execute( f"CREATE VIRTUAL TABLE {tname} USING vec0(" f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample={oversample})" ")" ) for i, v in enumerate(vectors): db.execute( f"INSERT INTO {tname}(rowid, embedding) VALUES (?, ?)", [i + 1, float_vec(v)], ) rows = db.execute( f"SELECT rowid FROM {tname} WHERE embedding MATCH ? ORDER BY distance LIMIT ?", [query, k], ).fetchall() recalls.append({r["rowid"] for r in rows}) # Also get ground truth db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])") for i, v in enumerate(vectors): db.execute( "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, float_vec(v)], ) gt_rows = db.execute( "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?", [query, k], ).fetchall() gt_ids = {r["rowid"] for r in gt_rows} recall_low = len(recalls[0] & gt_ids) / k recall_high = len(recalls[1] & gt_ids) / k assert recall_high >= recall_low def test_multiple_vector_columns(db): """One column with rescore, one without.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " v1 float[8] indexed by rescore(quantizer=bit)," " v2 float[8]" ")" ) db.execute( "INSERT INTO t(rowid, v1, v2) VALUES (1, ?, ?)", [float_vec([1.0] * 8), float_vec([0.0] * 8)], ) db.execute( "INSERT INTO t(rowid, v1, v2) VALUES (2, ?, ?)", [float_vec([0.0] * 8), float_vec([1.0] * 8)], ) # KNN on v1 (rescore path) rows = db.execute( "SELECT rowid FROM t WHERE v1 MATCH ? ORDER BY distance LIMIT 1", [float_vec([1.0] * 8)], ).fetchall() assert rows[0]["rowid"] == 1 # KNN on v2 (normal path) rows = db.execute( "SELECT rowid FROM t WHERE v2 MATCH ? ORDER BY distance LIMIT 1", [float_vec([1.0] * 8)], ).fetchall() assert rows[0]["rowid"] == 2 def test_corrupt_zeroblob_validity(db): """KNN should error (not crash) when rescore chunk rowids blob is zeroed out.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[8] indexed by rescore(quantizer=bit)" ")" ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1, 0, 0, 0, 0, 0, 0, 0])], ) db.execute( "INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([0, 1, 0, 0, 0, 0, 0, 0])], ) # Corrupt: replace rowids with a truncated blob (wrong size) db.execute("UPDATE t_chunks SET rowids = x'00'") # Should error, not crash — blob size validation catches the mismatch with pytest.raises(sqlite3.OperationalError): db.execute( "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", [float_vec([1, 0, 0, 0, 0, 0, 0, 0])], ).fetchall() def test_corrupt_truncated_validity_blob(db): """KNN should error when rescore chunk validity blob is truncated.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit)" ")" ) for i in range(5): import random random.seed(i) db.execute( "INSERT INTO t(rowid, embedding) VALUES (?, ?)", [i + 1, float_vec([random.gauss(0, 1) for _ in range(128)])], ) # Corrupt: truncate validity blob to 1 byte (should be chunk_size/8 = 128 bytes) db.execute("UPDATE t_chunks SET validity = x'FF'") with pytest.raises(sqlite3.OperationalError): db.execute( "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", [float_vec([1.0] * 128)], ).fetchall() def test_rescore_text_pk_insert_knn_delete(db): """Rescore with text primary key: insert, KNN, delete, KNN again.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " id text primary key," " embedding float[128] indexed by rescore(quantizer=bit)" ")" ) import random random.seed(99) vecs = {} for name in ["alpha", "beta", "gamma", "delta", "epsilon"]: v = [random.gauss(0, 1) for _ in range(128)] vecs[name] = v db.execute("INSERT INTO t(id, embedding) VALUES (?, ?)", [name, float_vec(v)]) # KNN should return text IDs rows = db.execute( "SELECT id, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3", [float_vec(vecs["alpha"])], ).fetchall() assert len(rows) >= 1 ids = [r["id"] for r in rows] assert "alpha" in ids # Delete and verify db.execute("DELETE FROM t WHERE id = 'alpha'") rows = db.execute( "SELECT id FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3", [float_vec(vecs["alpha"])], ).fetchall() ids = [r["id"] for r in rows] assert "alpha" not in ids assert len(rows) >= 1 # other results still returned def test_runtime_oversample(db): """oversample can be changed at query time via FTS5-style command.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit, oversample=2)" ")" ) random.seed(200) for i in range(200): db.execute( "INSERT INTO t(rowid, embedding) VALUES (?, ?)", [i + 1, float_vec([random.gauss(0, 1) for _ in range(128)])], ) query = float_vec([random.gauss(0, 1) for _ in range(128)]) # KNN with default oversample=2 (low) rows_low = db.execute( "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", [query], ).fetchall() assert len(rows_low) == 10 # Change oversample at runtime to high value db.execute("INSERT INTO t(t) VALUES ('oversample=32')") # KNN with oversample=32 (high) — same or better recall rows_high = db.execute( "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", [query], ).fetchall() assert len(rows_high) == 10 # Reset to original db.execute("INSERT INTO t(t) VALUES ('oversample=2')") rows_reset = db.execute( "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", [query], ).fetchall() assert len(rows_reset) == 10 # After reset, should match the original low-oversample results assert [r["rowid"] for r in rows_reset] == [r["rowid"] for r in rows_low] def test_runtime_oversample_error(db): """Invalid oversample values should error.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit)" ")" ) with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"): db.execute("INSERT INTO t(t) VALUES ('oversample=0')") with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"): db.execute("INSERT INTO t(t) VALUES ('oversample=-5')") def test_unknown_command_errors(db): """Unknown command strings should produce a clear error.""" db.execute( "CREATE VIRTUAL TABLE t USING vec0(" " embedding float[128] indexed by rescore(quantizer=bit)" ")" ) with pytest.raises(sqlite3.OperationalError, match="unknown vec0 command"): db.execute("INSERT INTO t(t) VALUES ('not_a_real_command')")