diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile index a631478..9ae456e 100644 --- a/benchmarks-ann/Makefile +++ b/benchmarks-ann/Makefile @@ -4,9 +4,9 @@ EXT = ../dist/vec0 # --- Baseline (brute-force) configs --- BASELINES = \ - "brute-float:type=baseline,variant=float" \ - "brute-int8:type=baseline,variant=int8" \ - "brute-bit:type=baseline,variant=bit" + "brute-float:type=vec0-flat,variant=float" \ + "brute-int8:type=vec0-flat,variant=int8" \ + "brute-bit:type=vec0-flat,variant=bit" # --- IVF configs --- IVF_CONFIGS = \ @@ -43,7 +43,7 @@ ground-truth: seed # --- Quick smoke test --- bench-smoke: seed $(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \ - "brute-float:type=baseline,variant=float" \ + "brute-float:type=vec0-flat,variant=float" \ "ivf-quick:type=ivf,nlist=16,nprobe=4" \ "diskann-quick:type=diskann,R=48,L=64,quantizer=binary" diff --git a/tests/test-diskann.py b/tests/test-diskann.py index f2a56a1..d3f3e86 100644 --- a/tests/test-diskann.py +++ b/tests/test-diskann.py @@ -1246,3 +1246,46 @@ def test_diskann_delete_interleaved_with_knn(db): returned = {r["rowid"] for r in rows} assert returned.issubset(alive), \ f"Deleted rowid {to_del} found in KNN results" + + +# ====================================================================== +# Text primary key + DiskANN +# ====================================================================== + + +def test_diskann_text_pk_insert_knn_delete(db): + """DiskANN with text primary key: insert, KNN, delete, KNN again.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + id text primary key, + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + + vecs = { + "alpha": [1, 0, 0, 0, 0, 0, 0, 0], + "beta": [0, 1, 0, 0, 0, 0, 0, 0], + "gamma": [0, 0, 1, 0, 0, 0, 0, 0], + "delta": [0, 0, 0, 1, 0, 0, 0, 0], + "epsilon": [0, 0, 0, 0, 1, 0, 0, 0], + } + for name, vec in vecs.items(): + db.execute("INSERT INTO t(id, emb) VALUES (?, ?)", [name, _f32(vec)]) + + # KNN should return text IDs + rows = db.execute( + "SELECT id, distance FROM t WHERE emb MATCH ? AND k=3", + [_f32([1, 0, 0, 0, 0, 0, 0, 0])], + ).fetchall() + assert len(rows) >= 1 + ids = [r["id"] for r in rows] + assert "alpha" in ids # closest to query + + # Delete and verify + db.execute("DELETE FROM t WHERE id = 'alpha'") + rows = db.execute( + "SELECT id FROM t WHERE emb MATCH ? AND k=3", + [_f32([1, 0, 0, 0, 0, 0, 0, 0])], + ).fetchall() + ids = [r["id"] for r in rows] + assert "alpha" not in ids diff --git a/tests/test-insert-delete.py b/tests/test-insert-delete.py index eb34f84..7e97ea2 100644 --- a/tests/test-insert-delete.py +++ b/tests/test-insert-delete.py @@ -483,3 +483,57 @@ def test_delete_one_chunk_of_two_shrinks_pages(tmp_path): row = db.execute("select emb from v where rowid = ?", [i]).fetchone() assert row[0] == _f32([float(i)] * dims) db.close() + + +def test_wal_concurrent_reader_during_write(tmp_path): + """In WAL mode, a reader should see a consistent snapshot while a writer inserts.""" + dims = 4 + db_path = str(tmp_path / "test.db") + + # Writer: create table, insert initial rows, enable WAL + writer = sqlite3.connect(db_path) + writer.enable_load_extension(True) + writer.load_extension("dist/vec0") + writer.execute("PRAGMA journal_mode=WAL") + writer.execute( + f"CREATE VIRTUAL TABLE v USING vec0(emb float[{dims}])" + ) + for i in range(1, 11): + writer.execute("INSERT INTO v(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * dims)]) + writer.commit() + + # Reader: open separate connection, start read + reader = sqlite3.connect(db_path) + reader.enable_load_extension(True) + reader.load_extension("dist/vec0") + + # Reader sees 10 rows + count_before = reader.execute("SELECT count(*) FROM v").fetchone()[0] + assert count_before == 10 + + # Writer inserts more rows (not yet committed) + writer.execute("BEGIN") + for i in range(11, 21): + writer.execute("INSERT INTO v(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * dims)]) + + # Reader still sees 10 (WAL snapshot isolation) + count_during = reader.execute("SELECT count(*) FROM v").fetchone()[0] + assert count_during == 10 + + # KNN during writer's transaction should work on reader's snapshot + rows = reader.execute( + "SELECT rowid FROM v WHERE emb MATCH ? AND k = 5", + [_f32([1.0] * dims)], + ).fetchall() + assert len(rows) == 5 + assert all(r[0] <= 10 for r in rows) # only original rows + + # Writer commits + writer.commit() + + # Reader sees new rows after re-query (new snapshot) + count_after = reader.execute("SELECT count(*) FROM v").fetchone()[0] + assert count_after == 20 + + writer.close() + reader.close() diff --git a/tests/test-rescore.py b/tests/test-rescore.py index 1dc6cd7..7c9c669 100644 --- a/tests/test-rescore.py +++ b/tests/test-rescore.py @@ -595,3 +595,40 @@ def test_corrupt_zeroblob_validity(db): ).fetchall() except sqlite3.OperationalError: pass # Error is acceptable — crash is not + + +def test_rescore_text_pk_insert_knn_delete(db): + """Rescore with text primary key: insert, KNN, delete, KNN again.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " id text primary key," + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + + import random + random.seed(99) + vecs = {} + for name in ["alpha", "beta", "gamma", "delta", "epsilon"]: + v = [random.gauss(0, 1) for _ in range(128)] + vecs[name] = v + db.execute("INSERT INTO t(id, embedding) VALUES (?, ?)", [name, float_vec(v)]) + + # KNN should return text IDs + rows = db.execute( + "SELECT id, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3", + [float_vec(vecs["alpha"])], + ).fetchall() + assert len(rows) >= 1 + ids = [r["id"] for r in rows] + assert "alpha" in ids + + # Delete and verify + db.execute("DELETE FROM t WHERE id = 'alpha'") + rows = db.execute( + "SELECT id FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 3", + [float_vec(vecs["alpha"])], + ).fetchall() + ids = [r["id"] for r in rows] + assert "alpha" not in ids + assert len(rows) >= 1 # other results still returned