mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
Replace the old INSERT INTO t(rowid) VALUES('command') hack with a
proper hidden command column named after the table (FTS5 pattern):
INSERT INTO t(t) VALUES ('oversample=16')
The command column is the first hidden column (before distance and k)
so that table-valued function argument use remains possible in the future.
Schema: CREATE TABLE x(rowid, <cols>, "<table>" hidden, distance hidden, k hidden)
For backwards compat, pre-v0.1.10 tables (detected via _info shadow
table version) skip the command column to avoid name conflicts with
user columns that may share the table's name. Verified with legacy
fixture DB generated by sqlite-vec v0.1.6.
Changes:
- Add hidden command column to sqlite3_declare_vtab for new tables
- Version-gate via _info shadow table for existing tables
- Validate at CREATE time that no column name matches table name
- Add rescore_handle_command() with oversample=N support
- rescore_knn() prefers runtime oversample_search over CREATE default
- Remove old rowid-based command dispatch
- Migrate all DiskANN/IVF/fuzz tests and benchmarks to new syntax
- Add legacy DB fixture (v0.1.6) and 9 backwards-compat tests
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
727 lines
23 KiB
Python
727 lines
23 KiB
Python
"""Tests for the rescore index feature in sqlite-vec."""
|
|
import struct
|
|
import sqlite3
|
|
import pytest
|
|
import math
|
|
import random
|
|
|
|
|
|
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows are produced as ``sqlite3.Row`` so tests can read columns by name.
    Extension loading is enabled only around the ``load_extension`` call and
    disabled again before the connection is handed to the test. The
    connection is closed when the test finishes (or if setup fails).
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        # Release the connection instead of leaking it past the test.
        conn.close()
|
|
|
|
|
|
def float_vec(values):
    """Serialize a sequence of numbers as a blob of packed 32-bit floats.

    The ``struct`` format carries no byte-order prefix, so the platform's
    native byte order is used.
    """
    count = len(values)
    layout = "%df" % count
    return struct.pack(layout, *values)
|
|
|
|
|
|
def unpack_float_vec(blob):
    """Decode a blob of packed 32-bit floats into a list of Python floats."""
    # Four bytes per element; struct.unpack itself rejects a size mismatch.
    count, _ = divmod(len(blob), 4)
    return [*struct.unpack(f"{count}f", blob)]
|
|
|
|
|
|
# ============================================================================
|
|
# Creation tests
|
|
# ============================================================================
|
|
|
|
|
|
def test_create_bit(db):
    """A float[128] column can be declared with rescore(quantizer=bit)."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    # Creation must leave at least one schema entry named like 't_%'.
    # (`_` is a single-character LIKE wildcard, so this matches any name
    # that starts with "t" and has at least one more character.)
    matches = db.execute(
        "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchone()["cnt"]
    assert matches > 0
|
|
|
|
|
|
def test_create_int8(db):
    """A float[128] column can be declared with rescore(quantizer=int8)."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute(ddl)

    # Creation must leave at least one schema entry named like 't_%'.
    matches = db.execute(
        "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchone()["cnt"]
    assert matches > 0
|
|
|
|
|
|
def test_create_with_oversample(db):
    """rescore(...) accepts an explicit oversample=16 option at CREATE time."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit, oversample=16)"
        ")"
    )
    db.execute(ddl)

    # Creation must leave at least one schema entry named like 't_%'.
    matches = db.execute(
        "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchone()["cnt"]
    assert matches > 0
|
|
|
|
|
|
def test_create_with_distance_metric(db):
    """distance_metric=cosine can be combined with a rescore index."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] distance_metric=cosine indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    # Creation must leave at least one schema entry named like 't_%'.
    matches = db.execute(
        "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchone()["cnt"]
    assert matches > 0
|
|
|
|
|
|
def test_create_error_missing_quantizer(db):
    """rescore(...) without a quantizer option must fail with OperationalError."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(oversample=8)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
|
|
|
|
|
def test_create_error_invalid_quantizer(db):
    """quantizer=float must fail with OperationalError at CREATE time."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=float)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
|
|
|
|
|
def test_create_error_on_bit_column(db):
    """A rescore index declared on a bit[1024] column must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding bit[1024] indexed by rescore(quantizer=bit)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
|
|
|
|
|
def test_create_error_on_int8_column(db):
    """A rescore index declared on an int8[128] column must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding int8[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
|
|
|
|
|
def test_create_error_bad_oversample_zero(db):
    """oversample=0 in the CREATE statement must fail with OperationalError."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit, oversample=0)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
|
|
|
|
|
def test_create_error_bad_oversample_too_large(db):
    """oversample=999 in the CREATE statement must fail with OperationalError."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit, oversample=999)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
|
|
|
|
|
def test_create_error_bit_dim_not_divisible_by_8(db):
    """A bit-quantized rescore index on float[100] (100 % 8 != 0) must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[100] indexed by rescore(quantizer=bit)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
|
|
|
|
|
# ============================================================================
|
|
# Shadow table tests
|
|
# ============================================================================
|
|
|
|
|
|
def test_shadow_tables_exist(db):
    """Creating a rescore column produces the rescore shadow tables.

    Expects ``t_rescore_chunks00`` and ``t_rescore_vectors00`` to be present
    and ``t_vector_chunks00`` to be absent.
    """
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    listing = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name"
    ).fetchall()
    names = [entry[0] for entry in listing]

    for expected in ("t_rescore_chunks00", "t_rescore_vectors00"):
        assert expected in names
    # A rescore column is not expected to get a _vector_chunks table.
    assert "t_vector_chunks00" not in names
|
|
|
|
|
|
def test_drop_cleans_up(db):
    """DROP TABLE leaves no table whose name matches 't_%' behind."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)
    db.execute("DROP TABLE t")

    leftovers = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%'"
    ).fetchall()
    remaining = [entry[0] for entry in leftovers]
    assert remaining == []
|
|
|
|
|
|
# ============================================================================
|
|
# Insert tests
|
|
# ============================================================================
|
|
|
|
|
|
def test_insert_single(db):
    """One inserted row is visible through count(*)."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    ones = float_vec([1.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", (ones,))

    total = db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"]
    assert total == 1
|
|
|
|
|
|
def test_insert_multiple(db):
    """Ten rows inserted into an int8-quantized table are all counted."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute(ddl)

    insert_sql = "INSERT INTO t(rowid, embedding) VALUES (?, ?)"
    # rowid r (1..10) stores the constant vector [r - 1] * 8.
    for rowid in range(1, 11):
        payload = float_vec([float(rowid - 1)] * 8)
        db.execute(insert_sql, (rowid, payload))

    total = db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"]
    assert total == 10
|
|
|
|
|
|
# ============================================================================
|
|
# Delete tests
|
|
# ============================================================================
|
|
|
|
|
|
def test_delete_single(db):
    """Deleting the only row leaves the table empty."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    ones = float_vec([1.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", (ones,))
    db.execute("DELETE FROM t WHERE rowid = 1")

    total = db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"]
    assert total == 0
|
|
|
|
|
|
def test_delete_and_reinsert(db):
    """After deleting rowid 1 and inserting rowid 2, exactly one row remains."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    first = float_vec([1.0] * 8)
    second = float_vec([2.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", (first,))
    db.execute("DELETE FROM t WHERE rowid = 1")
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", (second,))

    total = db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"]
    assert total == 1
|
|
|
|
|
|
def test_point_query_returns_float(db):
    """SELECT by rowid should return the original float vector, not quantized.

    Every returned component must be within 1e-6 of the inserted value, and
    the returned vector must have the same number of components.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(vals)])

    row = db.execute("SELECT embedding FROM t WHERE rowid = 1").fetchone()
    result = unpack_float_vec(row["embedding"])

    # zip() stops at the shorter input, so without this check a truncated or
    # empty result would pass the element-wise comparison vacuously.
    assert len(result) == len(vals)
    for got, want in zip(result, vals):
        assert abs(got - want) < 1e-6
|
|
|
|
|
|
# ============================================================================
|
|
# KNN tests
|
|
# ============================================================================
|
|
|
|
|
|
def test_knn_basic_bit(db):
    """With bit quantization, the row identical to the query is the top-1 hit."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    # Three unit vectors along the first three axes; rowid 1 equals the query.
    axis0 = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    axis1 = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    axis2 = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    statements = (
        ("INSERT INTO t(rowid, embedding) VALUES (1, ?)", axis0),
        ("INSERT INTO t(rowid, embedding) VALUES (2, ?)", axis1),
        ("INSERT INTO t(rowid, embedding) VALUES (3, ?)", axis2),
    )
    for sql, vector in statements:
        db.execute(sql, (float_vec(vector),))

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        (float_vec(axis0),),
    ).fetchall()
    assert len(hits) == 1
    assert hits[0]["rowid"] == 1
|
|
|
|
|
|
def test_knn_basic_int8(db):
    """With int8 quantization, the row identical to the query is the top-1 hit."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute(ddl)

    # Three unit vectors along the first three axes; rowid 1 equals the query.
    axis0 = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    axis1 = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    axis2 = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    statements = (
        ("INSERT INTO t(rowid, embedding) VALUES (1, ?)", axis0),
        ("INSERT INTO t(rowid, embedding) VALUES (2, ?)", axis1),
        ("INSERT INTO t(rowid, embedding) VALUES (3, ?)", axis2),
    )
    for sql, vector in statements:
        db.execute(sql, (float_vec(vector),))

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        (float_vec(axis0),),
    ).fetchall()
    assert len(hits) == 1
    assert hits[0]["rowid"] == 1
|
|
|
|
|
|
def test_knn_returns_float_distances(db):
    """KNN distances are reported in float precision, not quantized precision."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    near = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    far = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", (float_vec(near),))
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", (float_vec(far),))

    # The query is identical to the vector stored at rowid 1.
    target = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
        (float_vec(target),),
    ).fetchall()

    best, runner_up = hits[0], hits[1]

    # Exact match comes first with a distance of ~0.
    assert best["rowid"] == 1
    assert best["distance"] < 0.01

    # Distance to rowid 2:
    # sqrt((1-0)^2 + (0.5-0)^2 + (0-1)^2) = sqrt(2.25) = 1.5
    assert abs(runner_up["distance"] - 1.5) < 0.01
|
|
|
|
|
|
def test_knn_recall(db):
    """Rescore KNN must recover at least 70% of the flat table's top-k rows.

    The same 1000 seeded Gaussian vectors (dim 32) are inserted into a
    bit-quantized rescore table (oversample=16) and into a plain vec0 table
    used as the reference. Recall is the size of the overlap between the two
    top-10 result sets divided by k.
    """
    dim = 32
    n = 1000
    k = 10
    random.seed(42)

    db.execute(
        "CREATE VIRTUAL TABLE t_rescore USING vec0("
        f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample=16)"
        ")"
    )
    db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")

    vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
    for rowid, vector in enumerate(vectors, start=1):
        # One packed blob feeds both tables so they hold identical data.
        blob = float_vec(vector)
        db.execute(
            "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [rowid, blob]
        )
        db.execute(
            "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [rowid, blob]
        )

    query = float_vec([random.gauss(0, 1) for _ in range(dim)])

    rescore_rows = db.execute(
        "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        [query, k],
    ).fetchall()
    flat_rows = db.execute(
        "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        [query, k],
    ).fetchall()

    rescore_ids = {r["rowid"] for r in rescore_rows}
    flat_ids = {r["rowid"] for r in flat_rows}
    recall = len(rescore_ids & flat_ids) / k
    assert recall >= 0.7, f"Recall too low: {recall}"
|
|
|
|
|
|
def test_knn_cosine(db):
    """With distance_metric=cosine, an identical vector ranks first at ~0."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    axis0 = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    axis1 = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", (float_vec(axis0),))
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", (float_vec(axis1),))

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        (float_vec(axis0),),
    ).fetchall()
    top = hits[0]
    assert top["rowid"] == 1
    # Identical vectors: cosine distance is expected to be close to zero.
    assert top["distance"] < 0.01
|
|
|
|
|
|
def test_knn_empty_table(db):
    """A KNN query against an empty rescore table returns no rows."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    probe = float_vec([1.0] * 8)
    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
        (probe,),
    ).fetchall()
    assert len(hits) == 0
|
|
|
|
|
|
def test_knn_k_larger_than_n(db):
    """LIMIT 10 on a two-row table returns exactly the two stored rows."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    ones = float_vec([1.0] * 8)
    twos = float_vec([2.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", (ones,))
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", (twos,))

    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        (float_vec([1.0] * 8),),
    ).fetchall()
    assert len(hits) == 2
|
|
|
|
|
|
# ============================================================================
|
|
# Integration / edge case tests
|
|
# ============================================================================
|
|
|
|
|
|
def test_knn_with_rowid_in(db):
    """KNN combined with ``rowid IN (1, 3, 5)`` only returns those rowids.

    Five rows are stored (rowid r holds the constant vector [r - 1] * 8) and
    the query is restricted to rowids 1, 3 and 5.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    for i in range(5):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [i + 1, float_vec([float(i)] * 8)],
        )

    # Only search within rowids 1, 3, 5
    rows = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? AND rowid IN (1, 3, 5) ORDER BY distance LIMIT 3",
        [float_vec([0.0] * 8)],
    ).fetchall()

    # An empty result would satisfy the subset check below vacuously, so
    # require at least one hit before checking the filter.
    assert rows, "KNN with rowid IN filter returned no rows"
    result_ids = {r["rowid"] for r in rows}
    assert result_ids <= {1, 3, 5}
|
|
|
|
|
|
def test_knn_after_deletes(db):
    """After deleting the nearest row, KNN ranks the next-nearest row first."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute(ddl)

    insert_sql = "INSERT INTO t(rowid, embedding) VALUES (?, ?)"
    # rowid r (1..10) stores the constant vector [r - 1] * 8.
    for rowid in range(1, 11):
        db.execute(insert_sql, (rowid, float_vec([float(rowid - 1)] * 8)))

    # rowid 1 holds [0,0,...], which is identical to the query below.
    db.execute("DELETE FROM t WHERE rowid = 1")

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
        (float_vec([0.0] * 8),),
    ).fetchall()

    # Results must be ordered by distance: rowid 2 first, then rowid 3, ...
    assert len(hits) >= 2
    best = hits[0]
    assert best["distance"] <= hits[1]["distance"]
    # rowid 2 = [1,1,...] -> L2 = sqrt(8) ~ 2.83; rowid 3 = [2,2,...] -> L2 = sqrt(32) ~ 5.66
    assert best["rowid"] == 2, f"Expected rowid 2, got {best['rowid']} with dist={best['distance']}"
|
|
|
|
|
|
def test_oversample_effect(db):
    """Higher oversample should give equal or better recall.

    The same 500 seeded Gaussian vectors (dim 32) are loaded into two
    bit-quantized rescore tables (oversample=2 and oversample=16) and into a
    plain vec0 table used as ground truth. Recall of each rescore table is
    the overlap of its top-10 rowids with the ground-truth top-10.
    """
    dim = 32
    n = 500
    k = 10
    random.seed(123)

    vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
    query = float_vec([random.gauss(0, 1) for _ in range(dim)])

    # Pack every vector once; the same (rowid, blob) pairs feed all tables
    # instead of re-packing each vector for every table.
    packed = [(i + 1, float_vec(v)) for i, v in enumerate(vectors)]

    def load_and_query(table):
        """Insert all packed vectors into *table* and return its top-k rowids."""
        for rowid, blob in packed:
            db.execute(
                f"INSERT INTO {table}(rowid, embedding) VALUES (?, ?)",
                [rowid, blob],
            )
        found = db.execute(
            f"SELECT rowid FROM {table} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
            [query, k],
        ).fetchall()
        return {r["rowid"] for r in found}

    recalls = []
    for oversample in [2, 16]:
        tname = f"t_os{oversample}"
        db.execute(
            f"CREATE VIRTUAL TABLE {tname} USING vec0("
            f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample={oversample})"
            ")"
        )
        recalls.append(load_and_query(tname))

    # Also get ground truth
    db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")
    gt_ids = load_and_query("t_flat")

    recall_low = len(recalls[0] & gt_ids) / k
    recall_high = len(recalls[1] & gt_ids) / k
    assert recall_high >= recall_low
|
|
|
|
|
|
def test_multiple_vector_columns(db):
    """A rescore column and a plain vector column coexist in one table."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " v1 float[8] indexed by rescore(quantizer=bit),"
        " v2 float[8]"
        ")"
    )
    db.execute(ddl)

    ones = float_vec([1.0] * 8)
    zeros = float_vec([0.0] * 8)
    # Row 1 has ones in v1, row 2 has ones in v2.
    db.execute("INSERT INTO t(rowid, v1, v2) VALUES (1, ?, ?)", (ones, zeros))
    db.execute("INSERT INTO t(rowid, v1, v2) VALUES (2, ?, ?)", (zeros, ones))

    # Query against v1, the rescore-indexed column.
    v1_hits = db.execute(
        "SELECT rowid FROM t WHERE v1 MATCH ? ORDER BY distance LIMIT 1",
        (ones,),
    ).fetchall()
    assert v1_hits[0]["rowid"] == 1

    # Query against v2, the column without a rescore index.
    v2_hits = db.execute(
        "SELECT rowid FROM t WHERE v2 MATCH ? ORDER BY distance LIMIT 1",
        (ones,),
    ).fetchall()
    assert v2_hits[0]["rowid"] == 2
|
|
|
|
|
|
def test_corrupt_zeroblob_validity(db):
    """KNN raises OperationalError when t_chunks.rowids is a single zero byte."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    axis0 = [1, 0, 0, 0, 0, 0, 0, 0]
    axis1 = [0, 1, 0, 0, 0, 0, 0, 0]
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", (float_vec(axis0),))
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", (float_vec(axis1),))

    # Corrupt the shadow table: overwrite rowids with a one-byte blob.
    db.execute("UPDATE t_chunks SET rowids = x'00'")

    # The query is expected to fail with an error rather than crash.
    probe = float_vec([1, 0, 0, 0, 0, 0, 0, 0])
    with pytest.raises(sqlite3.OperationalError):
        db.execute(
            "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
            (probe,),
        ).fetchall()
|
|
|
|
|
|
def test_corrupt_truncated_validity_blob(db):
    """KNN should error when rescore chunk validity blob is truncated.

    Five seeded Gaussian vectors are inserted, the ``validity`` column of
    ``t_chunks`` is overwritten with a one-byte blob, and the following KNN
    query must raise ``sqlite3.OperationalError``.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    for i in range(5):
        # Re-seed per row so each vector is reproducible on its own.
        # (Uses the module-level `random` import; no per-iteration import.)
        random.seed(i)
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [i + 1, float_vec([random.gauss(0, 1) for _ in range(128)])],
        )

    # Corrupt: truncate validity blob to 1 byte (should be chunk_size/8 = 128 bytes)
    db.execute("UPDATE t_chunks SET validity = x'FF'")

    with pytest.raises(sqlite3.OperationalError):
        db.execute(
            "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
            [float_vec([1.0] * 128)],
        ).fetchall()
|
|
|
|
|
|
def test_rescore_text_pk_insert_knn_delete(db):
    """Rescore with text primary key: insert, KNN, delete, KNN again.

    Five named rows with seeded Gaussian vectors are inserted. Querying with
    the vector stored under "alpha" must return "alpha" among the top 3;
    after deleting it, the same query must still return rows but no longer
    include "alpha".
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " id text primary key,"
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )

    # Uses the module-level `random` import; a local re-import is not needed.
    random.seed(99)
    vecs = {}
    for name in ["alpha", "beta", "gamma", "delta", "epsilon"]:
        v = [random.gauss(0, 1) for _ in range(128)]
        vecs[name] = v
        db.execute("INSERT INTO t(id, embedding) VALUES (?, ?)", [name, float_vec(v)])

    # KNN should return text IDs
    rows = db.execute(
        "SELECT id, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3",
        [float_vec(vecs["alpha"])],
    ).fetchall()
    assert len(rows) >= 1
    ids = [r["id"] for r in rows]
    assert "alpha" in ids

    # Delete and verify
    db.execute("DELETE FROM t WHERE id = 'alpha'")
    rows = db.execute(
        "SELECT id FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3",
        [float_vec(vecs["alpha"])],
    ).fetchall()
    ids = [r["id"] for r in rows]
    assert "alpha" not in ids
    assert len(rows) >= 1  # other results still returned
|
|
|
|
|
|
def test_runtime_oversample(db):
    """oversample can be overridden at runtime via ``INSERT INTO t(t)``.

    The table is created with oversample=2. The same KNN query is run three
    times: with the CREATE-time value, after setting oversample=32, and after
    setting it back to 2. Each run must return 10 rows, and the last run must
    return the same rowids in the same order as the first.
    """
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit, oversample=2)"
        ")"
    )
    db.execute(ddl)

    random.seed(200)
    insert_sql = "INSERT INTO t(rowid, embedding) VALUES (?, ?)"
    for rowid in range(1, 201):
        vector = [random.gauss(0, 1) for _ in range(128)]
        db.execute(insert_sql, (rowid, float_vec(vector)))

    query = float_vec([random.gauss(0, 1) for _ in range(128)])

    def top10_rowids():
        """Run the KNN query, check it yields 10 rows, return their rowids."""
        found = db.execute(
            "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
            (query,),
        ).fetchall()
        assert len(found) == 10
        return [hit["rowid"] for hit in found]

    # First pass uses the CREATE-time oversample=2.
    baseline = top10_rowids()

    # Raise oversample through the command column and query again.
    db.execute("INSERT INTO t(t) VALUES ('oversample=32')")
    top10_rowids()

    # Restore the original value; results must match the first pass exactly.
    db.execute("INSERT INTO t(t) VALUES ('oversample=2')")
    after_reset = top10_rowids()
    assert after_reset == baseline
|
|
|
|
|
|
def test_runtime_oversample_error(db):
    """Runtime oversample values of 0 and -5 are rejected with a clear error."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    bad_commands = (
        "INSERT INTO t(t) VALUES ('oversample=0')",
        "INSERT INTO t(t) VALUES ('oversample=-5')",
    )
    for command in bad_commands:
        with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"):
            db.execute(command)
|
|
|
|
|
|
def test_unknown_command_errors(db):
    """An unrecognized command string raises an "unknown vec0 command" error."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)

    bogus = "INSERT INTO t(t) VALUES ('not_a_real_command')"
    with pytest.raises(sqlite3.OperationalError, match="unknown vec0 command"):
        db.execute(bogus)
|