mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
Add rescore index for ANN queries
Add rescore index type: stores full-precision float vectors in a rowid-keyed shadow table, quantizes to int8 for fast initial scan, then rescores top candidates with original vectors. Includes config parser, shadow table management, insert/delete support, KNN integration, compile flag (SQLITE_VEC_ENABLE_RESCORE), fuzz targets, and tests.
This commit is contained in:
parent
bf2455f2ba
commit
ba0db0b6d6
19 changed files with 3378 additions and 8 deletions
568
tests/test-rescore.py
Normal file
568
tests/test-rescore.py
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
"""Tests for the rescore index feature in sqlite-vec."""
|
||||
import struct
|
||||
import sqlite3
|
||||
import pytest
|
||||
import math
|
||||
import random
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db():
|
||||
db = sqlite3.connect(":memory:")
|
||||
db.row_factory = sqlite3.Row
|
||||
db.enable_load_extension(True)
|
||||
db.load_extension("dist/vec0")
|
||||
db.enable_load_extension(False)
|
||||
return db
|
||||
|
||||
|
||||
def float_vec(values):
|
||||
"""Pack a list of floats into a blob for sqlite-vec."""
|
||||
return struct.pack(f"{len(values)}f", *values)
|
||||
|
||||
|
||||
def unpack_float_vec(blob):
|
||||
"""Unpack a float vector blob."""
|
||||
n = len(blob) // 4
|
||||
return list(struct.unpack(f"{n}f", blob))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Creation tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_create_bit(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
# Table exists and has the right structure
|
||||
row = db.execute(
|
||||
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
|
||||
).fetchone()
|
||||
assert row["cnt"] > 0
|
||||
|
||||
|
||||
def test_create_int8(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(quantizer=int8)"
|
||||
")"
|
||||
)
|
||||
row = db.execute(
|
||||
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
|
||||
).fetchone()
|
||||
assert row["cnt"] > 0
|
||||
|
||||
|
||||
def test_create_with_oversample(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(quantizer=bit, oversample=16)"
|
||||
")"
|
||||
)
|
||||
row = db.execute(
|
||||
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
|
||||
).fetchone()
|
||||
assert row["cnt"] > 0
|
||||
|
||||
|
||||
def test_create_with_distance_metric(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] distance_metric=cosine indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
row = db.execute(
|
||||
"SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
|
||||
).fetchone()
|
||||
assert row["cnt"] > 0
|
||||
|
||||
|
||||
def test_create_error_missing_quantizer(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(oversample=8)"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
def test_create_error_invalid_quantizer(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(quantizer=float)"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
def test_create_error_on_bit_column(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding bit[1024] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
def test_create_error_on_int8_column(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding int8[128] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
def test_create_error_bad_oversample_zero(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(quantizer=bit, oversample=0)"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
def test_create_error_bad_oversample_too_large(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(quantizer=bit, oversample=999)"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
def test_create_error_bit_dim_not_divisible_by_8(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[100] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Shadow table tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_shadow_tables_exist(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name"
|
||||
).fetchall()
|
||||
]
|
||||
assert "t_rescore_chunks00" in tables
|
||||
assert "t_rescore_vectors00" in tables
|
||||
# Rescore columns don't create _vector_chunks
|
||||
assert "t_vector_chunks00" not in tables
|
||||
|
||||
|
||||
def test_drop_cleans_up(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
db.execute("DROP TABLE t")
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%'"
|
||||
).fetchall()
|
||||
]
|
||||
assert len(tables) == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_insert_single(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
|
||||
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
|
||||
assert row["cnt"] == 1
|
||||
|
||||
|
||||
def test_insert_multiple(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=int8)"
|
||||
")"
|
||||
)
|
||||
for i in range(10):
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||
[i + 1, float_vec([float(i)] * 8)],
|
||||
)
|
||||
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
|
||||
assert row["cnt"] == 10
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Delete tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_delete_single(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
|
||||
db.execute("DELETE FROM t WHERE rowid = 1")
|
||||
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
|
||||
assert row["cnt"] == 0
|
||||
|
||||
|
||||
def test_delete_and_reinsert(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
|
||||
db.execute("DELETE FROM t WHERE rowid = 1")
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)]
|
||||
)
|
||||
row = db.execute("SELECT count(*) as cnt FROM t").fetchone()
|
||||
assert row["cnt"] == 1
|
||||
|
||||
|
||||
def test_point_query_returns_float(db):
|
||||
"""SELECT by rowid should return the original float vector, not quantized."""
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
|
||||
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(vals)])
|
||||
row = db.execute("SELECT embedding FROM t WHERE rowid = 1").fetchone()
|
||||
result = unpack_float_vec(row["embedding"])
|
||||
for a, b in zip(result, vals):
|
||||
assert abs(a - b) < 1e-6
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# KNN tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_knn_basic_bit(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
# Insert vectors where [1,0,0,...] is closest to query [1,0,0,...]
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
|
||||
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
|
||||
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (3, ?)",
|
||||
[float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
)
|
||||
rows = db.execute(
|
||||
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
|
||||
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
).fetchall()
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["rowid"] == 1
|
||||
|
||||
|
||||
def test_knn_basic_int8(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=int8)"
|
||||
")"
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
|
||||
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
|
||||
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (3, ?)",
|
||||
[float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
)
|
||||
rows = db.execute(
|
||||
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
|
||||
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
).fetchall()
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["rowid"] == 1
|
||||
|
||||
|
||||
def test_knn_returns_float_distances(db):
|
||||
"""KNN should return float-precision distances, not quantized distances."""
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
v1 = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
v2 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
|
||||
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(v1)])
|
||||
db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(v2)])
|
||||
|
||||
query = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
rows = db.execute(
|
||||
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
|
||||
[float_vec(query)],
|
||||
).fetchall()
|
||||
|
||||
# First result should be exact match with distance ~0
|
||||
assert rows[0]["rowid"] == 1
|
||||
assert rows[0]["distance"] < 0.01
|
||||
|
||||
# Second result should have a float distance
|
||||
# sqrt((1-0)^2 + (0.5-0)^2 + (0-1)^2) = sqrt(2.25) = 1.5
|
||||
assert abs(rows[1]["distance"] - 1.5) < 0.01
|
||||
|
||||
|
||||
def test_knn_recall(db):
|
||||
"""With enough vectors, rescore should achieve good recall (>0.9)."""
|
||||
dim = 32
|
||||
n = 1000
|
||||
k = 10
|
||||
random.seed(42)
|
||||
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t_rescore USING vec0("
|
||||
f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample=16)"
|
||||
")"
|
||||
)
|
||||
db.execute(
|
||||
f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])"
|
||||
)
|
||||
|
||||
vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
|
||||
for i, v in enumerate(vectors):
|
||||
blob = float_vec(v)
|
||||
db.execute(
|
||||
"INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob]
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob]
|
||||
)
|
||||
|
||||
query = float_vec([random.gauss(0, 1) for _ in range(dim)])
|
||||
|
||||
rescore_rows = db.execute(
|
||||
"SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
|
||||
[query, k],
|
||||
).fetchall()
|
||||
flat_rows = db.execute(
|
||||
"SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
|
||||
[query, k],
|
||||
).fetchall()
|
||||
|
||||
rescore_ids = {r["rowid"] for r in rescore_rows}
|
||||
flat_ids = {r["rowid"] for r in flat_rows}
|
||||
recall = len(rescore_ids & flat_ids) / k
|
||||
assert recall >= 0.7, f"Recall too low: {recall}"
|
||||
|
||||
|
||||
def test_knn_cosine(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] distance_metric=cosine indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (1, ?)",
|
||||
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (2, ?)",
|
||||
[float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
)
|
||||
rows = db.execute(
|
||||
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
|
||||
[float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
|
||||
).fetchall()
|
||||
assert rows[0]["rowid"] == 1
|
||||
# cosine distance of identical vectors should be ~0
|
||||
assert rows[0]["distance"] < 0.01
|
||||
|
||||
|
||||
def test_knn_empty_table(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
rows = db.execute(
|
||||
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
|
||||
[float_vec([1.0] * 8)],
|
||||
).fetchall()
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
def test_knn_k_larger_than_n(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])
|
||||
db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)])
|
||||
rows = db.execute(
|
||||
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
|
||||
[float_vec([1.0] * 8)],
|
||||
).fetchall()
|
||||
assert len(rows) == 2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration / edge case tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_knn_with_rowid_in(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=bit)"
|
||||
")"
|
||||
)
|
||||
for i in range(5):
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||
[i + 1, float_vec([float(i)] * 8)],
|
||||
)
|
||||
# Only search within rowids 1, 3, 5
|
||||
rows = db.execute(
|
||||
"SELECT rowid FROM t WHERE embedding MATCH ? AND rowid IN (1, 3, 5) ORDER BY distance LIMIT 3",
|
||||
[float_vec([0.0] * 8)],
|
||||
).fetchall()
|
||||
result_ids = {r["rowid"] for r in rows}
|
||||
assert result_ids <= {1, 3, 5}
|
||||
|
||||
|
||||
def test_knn_after_deletes(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" embedding float[8] indexed by rescore(quantizer=int8)"
|
||||
")"
|
||||
)
|
||||
for i in range(10):
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||
[i + 1, float_vec([float(i)] * 8)],
|
||||
)
|
||||
# Delete the closest match (rowid 1 = [0,0,...])
|
||||
db.execute("DELETE FROM t WHERE rowid = 1")
|
||||
rows = db.execute(
|
||||
"SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
|
||||
[float_vec([0.0] * 8)],
|
||||
).fetchall()
|
||||
# Verify ordering: rowid 2 ([1]*8) should be closest, then 3 ([2]*8), etc.
|
||||
assert len(rows) >= 2
|
||||
assert rows[0]["distance"] <= rows[1]["distance"]
|
||||
# rowid 2 = [1,1,...] → L2 = sqrt(8) ≈ 2.83, rowid 3 = [2,2,...] → L2 = sqrt(32) ≈ 5.66
|
||||
assert rows[0]["rowid"] == 2, f"Expected rowid 2, got {rows[0]['rowid']} with dist={rows[0]['distance']}"
|
||||
|
||||
|
||||
def test_oversample_effect(db):
|
||||
"""Higher oversample should give equal or better recall."""
|
||||
dim = 32
|
||||
n = 500
|
||||
k = 10
|
||||
random.seed(123)
|
||||
|
||||
vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
|
||||
query = float_vec([random.gauss(0, 1) for _ in range(dim)])
|
||||
|
||||
recalls = []
|
||||
for oversample in [2, 16]:
|
||||
tname = f"t_os{oversample}"
|
||||
db.execute(
|
||||
f"CREATE VIRTUAL TABLE {tname} USING vec0("
|
||||
f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample={oversample})"
|
||||
")"
|
||||
)
|
||||
for i, v in enumerate(vectors):
|
||||
db.execute(
|
||||
f"INSERT INTO {tname}(rowid, embedding) VALUES (?, ?)",
|
||||
[i + 1, float_vec(v)],
|
||||
)
|
||||
rows = db.execute(
|
||||
f"SELECT rowid FROM {tname} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
|
||||
[query, k],
|
||||
).fetchall()
|
||||
recalls.append({r["rowid"] for r in rows})
|
||||
|
||||
# Also get ground truth
|
||||
db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")
|
||||
for i, v in enumerate(vectors):
|
||||
db.execute(
|
||||
"INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)",
|
||||
[i + 1, float_vec(v)],
|
||||
)
|
||||
gt_rows = db.execute(
|
||||
"SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
|
||||
[query, k],
|
||||
).fetchall()
|
||||
gt_ids = {r["rowid"] for r in gt_rows}
|
||||
|
||||
recall_low = len(recalls[0] & gt_ids) / k
|
||||
recall_high = len(recalls[1] & gt_ids) / k
|
||||
assert recall_high >= recall_low
|
||||
|
||||
|
||||
def test_multiple_vector_columns(db):
|
||||
"""One column with rescore, one without."""
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
" v1 float[8] indexed by rescore(quantizer=bit),"
|
||||
" v2 float[8]"
|
||||
")"
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, v1, v2) VALUES (1, ?, ?)",
|
||||
[float_vec([1.0] * 8), float_vec([0.0] * 8)],
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO t(rowid, v1, v2) VALUES (2, ?, ?)",
|
||||
[float_vec([0.0] * 8), float_vec([1.0] * 8)],
|
||||
)
|
||||
|
||||
# KNN on v1 (rescore path)
|
||||
rows = db.execute(
|
||||
"SELECT rowid FROM t WHERE v1 MATCH ? ORDER BY distance LIMIT 1",
|
||||
[float_vec([1.0] * 8)],
|
||||
).fetchall()
|
||||
assert rows[0]["rowid"] == 1
|
||||
|
||||
# KNN on v2 (normal path)
|
||||
rows = db.execute(
|
||||
"SELECT rowid FROM t WHERE v2 MATCH ? ORDER BY distance LIMIT 1",
|
||||
[float_vec([1.0] * 8)],
|
||||
).fetchall()
|
||||
assert rows[0]["rowid"] == 2
|
||||
Loading…
Add table
Add a link
Reference in a new issue