sqlite-vec/tests/test-ivf.py
Alex Garcia 3358e127f6 Add IVF index for vec0 virtual table
Add inverted file (IVF) index type: partitions vectors into clusters via
k-means, quantizes to int8, and scans only the nearest nprobe partitions at
query time. Includes shadow table management, insert/delete, KNN integration,
compile flag (SQLITE_VEC_ENABLE_IVF), fuzz targets, and tests. Removes
superseded ivf-benchmarks/ directory.
2026-03-31 01:18:47 -07:00

426 lines
13 KiB
Python

import pytest
import sqlite3
import struct
import math
from helpers import _f32, exec
@pytest.fixture()
def db():
db = sqlite3.connect(":memory:")
db.row_factory = sqlite3.Row
db.enable_load_extension(True)
db.load_extension("dist/vec0")
db.enable_load_extension(False)
return db
def ivf_total_vectors(db, table="t", col=0):
"""Count total vectors across all IVF cells."""
return db.execute(
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d}"
).fetchone()[0]
def ivf_unassigned_count(db, table="t", col=0):
"""Count vectors in unassigned cells (centroid_id=-1)."""
return db.execute(
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id = -1"
).fetchone()[0]
def ivf_assigned_count(db, table="t", col=0):
"""Count vectors in trained cells (centroid_id >= 0)."""
return db.execute(
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id >= 0"
).fetchone()[0]
# ============================================================================
# Parser tests
# ============================================================================
def test_ivf_create_defaults(db):
"""ivf() with no args uses defaults."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())"
)
tables = [
r[0]
for r in db.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
).fetchall()
]
assert "t_ivf_centroids00" in tables
assert "t_ivf_cells00" in tables
assert "t_ivf_rowid_map00" in tables
def test_ivf_create_custom_params(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=64, nprobe=8))"
)
tables = [
r[0]
for r in db.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
).fetchall()
]
assert "t_ivf_centroids00" in tables
assert "t_ivf_cells00" in tables
def test_ivf_create_with_distance_metric(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] distance_metric=cosine indexed by ivf(nlist=16))"
)
tables = [
r[0]
for r in db.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
).fetchall()
]
assert "t_ivf_centroids00" in tables
def test_ivf_create_error_unknown_key(db):
result = exec(
db,
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(bogus=1))",
)
assert "error" in result
def test_ivf_create_error_nprobe_gt_nlist(db):
result = exec(
db,
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=10))",
)
assert "error" in result
# ============================================================================
# Shadow table tests
# ============================================================================
def test_ivf_shadow_tables_created(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8))"
)
trained = db.execute(
"SELECT value FROM t_info WHERE key = 'ivf_trained_0'"
).fetchone()[0]
assert str(trained) == "0"
# No cells yet (created lazily on first insert)
count = db.execute(
"SELECT count(*) FROM t_ivf_cells00"
).fetchone()[0]
assert count == 0
def test_ivf_drop_cleans_up(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
)
db.execute("DROP TABLE t")
tables = [
r[0]
for r in db.execute(
"SELECT name FROM sqlite_master WHERE type='table'"
).fetchall()
]
assert not any("ivf" in t for t in tables)
# ============================================================================
# Insert tests (flat mode)
# ============================================================================
def test_ivf_insert_flat_mode(db):
"""Before training, vectors go to unassigned cell."""
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
)
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])])
assert ivf_unassigned_count(db) == 2
assert ivf_assigned_count(db) == 0
# rowid_map should have 2 entries
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 2
def test_ivf_delete_flat_mode(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
)
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])])
db.execute("DELETE FROM t WHERE rowid = 1")
assert ivf_unassigned_count(db) == 1
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 1
# ============================================================================
# KNN flat mode tests
# ============================================================================
def test_ivf_knn_flat_mode(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
)
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([2, 0, 0, 0])])
db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([9, 0, 0, 0])])
rows = db.execute(
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 2",
[_f32([1.5, 0, 0, 0])],
).fetchall()
assert len(rows) == 2
rowids = {r[0] for r in rows}
assert rowids == {1, 2}
def test_ivf_knn_flat_empty(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
)
rows = db.execute(
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
[_f32([1, 0, 0, 0])],
).fetchall()
assert len(rows) == 0
# ============================================================================
# compute-centroids tests
# ============================================================================
def test_compute_centroids(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
)
for i in range(40):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)",
[i, _f32([i % 10, i // 10, 0, 0])],
)
assert ivf_unassigned_count(db) == 40
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
# After training: unassigned cell should be gone (or empty), vectors in trained cells
assert ivf_unassigned_count(db) == 0
assert ivf_assigned_count(db) == 40
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 4
trained = db.execute(
"SELECT value FROM t_info WHERE key='ivf_trained_0'"
).fetchone()[0]
assert str(trained) == "1"
def test_compute_centroids_recompute(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
)
for i in range(20):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
)
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
assert ivf_assigned_count(db) == 20
# ============================================================================
# Insert after training (assigned mode)
# ============================================================================
def test_ivf_insert_after_training(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
)
for i in range(20):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
)
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
db.execute(
"INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])]
)
# Should be in a trained cell, not unassigned
row = db.execute(
"SELECT m.cell_id, c.centroid_id FROM t_ivf_rowid_map00 m "
"JOIN t_ivf_cells00 c ON c.rowid = m.cell_id "
"WHERE m.rowid = 100"
).fetchone()
assert row is not None
assert row[1] >= 0 # centroid_id >= 0 means trained cell
# ============================================================================
# KNN after training (IVF probe mode)
# ============================================================================
def test_ivf_knn_after_training(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
)
for i in range(100):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
)
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
rows = db.execute(
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
[_f32([50.0, 0, 0, 0])],
).fetchall()
assert len(rows) == 5
assert rows[0][0] == 50
assert rows[0][1] < 0.01
def test_ivf_knn_k_larger_than_n(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
)
for i in range(5):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
)
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
rows = db.execute(
"SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
[_f32([0, 0, 0, 0])],
).fetchall()
assert len(rows) == 5
# ============================================================================
# Manual centroid import (set-centroid, assign-vectors)
# ============================================================================
def test_set_centroid_and_assign(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
)
for i in range(20):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
)
db.execute(
"INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
[_f32([5, 0, 0, 0])],
)
db.execute(
"INSERT INTO t(rowid, v) VALUES ('set-centroid:1', ?)",
[_f32([15, 0, 0, 0])],
)
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
assert ivf_unassigned_count(db) == 0
assert ivf_assigned_count(db) == 20
# ============================================================================
# clear-centroids
# ============================================================================
def test_clear_centroids(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
)
for i in range(20):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
)
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
db.execute("INSERT INTO t(rowid) VALUES ('clear-centroids')")
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0
assert ivf_unassigned_count(db) == 20
trained = db.execute(
"SELECT value FROM t_info WHERE key='ivf_trained_0'"
).fetchone()[0]
assert str(trained) == "0"
# ============================================================================
# Delete after training
# ============================================================================
def test_ivf_delete_after_training(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
)
for i in range(10):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
)
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
assert ivf_assigned_count(db) == 10
db.execute("DELETE FROM t WHERE rowid = 5")
assert ivf_assigned_count(db) == 9
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 9
# ============================================================================
# Recall test
# ============================================================================
def test_ivf_recall_nprobe_equals_nlist(db):
db.execute(
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
)
for i in range(100):
db.execute(
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
)
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
rows = db.execute(
"SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
[_f32([50.0, 0, 0, 0])],
).fetchall()
rowids = {r[0] for r in rows}
# 45 and 55 are equidistant from 50, so either may appear in top 10
assert 50 in rowids
assert len(rowids) == 10
assert all(45 <= r <= 55 for r in rowids)