diff --git a/test.sql b/test.sql index 9d615a7..8cd3f30 100644 --- a/test.sql +++ b/test.sql @@ -1,5 +1,5 @@ -.load dist/vec0main +.load dist/vec0 .bail on .mode qbox diff --git a/tests/__snapshots__/test-knn-distance-constraints.ambr b/tests/__snapshots__/test-knn-distance-constraints.ambr new file mode 100644 index 0000000..87695a1 --- /dev/null +++ b/tests/__snapshots__/test-knn-distance-constraints.ambr @@ -0,0 +1,273 @@ +# serializer version: 1 +# name: test_normal + OrderedDict({ + 'sql': 'SELECT * FROM v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'embedding': b'\x00\x00\x80?', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 2, + 'embedding': b'\x00\x00\x00@', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 3, + 'embedding': b'\x00\x00@@', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 4, + 'embedding': b'\x00\x00\x80@', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 5, + 'embedding': b'\x00\x00\xa0@', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 6, + 'embedding': b'\x00\x00\xc0@', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 7, + 'embedding': b'\x00\x00\xe0@', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 8, + 'embedding': b'\x00\x00\x00A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 9, + 'embedding': b'\x00\x00\x10A', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 10, + 'embedding': b'\x00\x00 A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 11, + 'embedding': b'\x00\x000A', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 12, + 'embedding': b'\x00\x00@A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 13, + 'embedding': b'\x00\x00PA', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 14, + 'embedding': b'\x00\x00`A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 15, + 'embedding': b'\x00\x00pA', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 16, + 'embedding': b'\x00\x00\x80A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 17, + 'embedding': b'\x00\x00\x88A', + 'is_odd': 1, + }), + ]), + }) +# --- +# name: test_normal.1 + 
OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? ', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'distance': 0.0, + }), + OrderedDict({ + 'rowid': 2, + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 3, + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 4, + 'distance': 3.0, + }), + OrderedDict({ + 'rowid': 5, + 'distance': 4.0, + }), + ]), + }) +# --- +# name: test_normal.2 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance > 5', + 'rows': list([ + OrderedDict({ + 'rowid': 7, + 'distance': 6.0, + }), + OrderedDict({ + 'rowid': 8, + 'distance': 7.0, + }), + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 10, + 'distance': 9.0, + }), + OrderedDict({ + 'rowid': 11, + 'distance': 10.0, + }), + ]), + }) +# --- +# name: test_normal.3 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance >= 5', + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'distance': 5.0, + }), + OrderedDict({ + 'rowid': 7, + 'distance': 6.0, + }), + OrderedDict({ + 'rowid': 8, + 'distance': 7.0, + }), + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 10, + 'distance': 9.0, + }), + ]), + }) +# --- +# name: test_normal.4 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance < 3', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'distance': 0.0, + }), + OrderedDict({ + 'rowid': 2, + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 3, + 'distance': 2.0, + }), + ]), + }) +# --- +# name: test_normal.5 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? 
AND distance <= 3', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'distance': 0.0, + }), + OrderedDict({ + 'rowid': 2, + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 3, + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 4, + 'distance': 3.0, + }), + ]), + }) +# --- +# name: test_normal.6 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance > 7 AND distance <= 10', + 'rows': list([ + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 10, + 'distance': 9.0, + }), + OrderedDict({ + 'rowid': 11, + 'distance': 10.0, + }), + ]), + }) +# --- +# name: test_normal.7 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance BETWEEN 7 AND 10', + 'rows': list([ + OrderedDict({ + 'rowid': 8, + 'distance': 7.0, + }), + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 10, + 'distance': 9.0, + }), + OrderedDict({ + 'rowid': 11, + 'distance': 10.0, + }), + ]), + }) +# --- +# name: test_normal.8 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? 
AND is_odd == TRUE AND distance BETWEEN 7 AND 10', + 'rows': list([ + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 11, + 'distance': 10.0, + }), + ]), + }) +# --- diff --git a/tests/test-knn-distance-constraints.py b/tests/test-knn-distance-constraints.py new file mode 100644 index 0000000..ed2d9ec --- /dev/null +++ b/tests/test-knn-distance-constraints.py @@ -0,0 +1,82 @@ +import sqlite3 +from collections import OrderedDict + + +def test_normal(db, snapshot): + db.execute("create virtual table v using vec0(embedding float[1], is_odd boolean, chunk_size=8)") + db.executemany( + "insert into v(rowid, is_odd, embedding) values (?1, ?1 % 2, ?2)", + [ + [1, "[1]"], + [2, "[2]"], + [3, "[3]"], + [4, "[4]"], + [5, "[5]"], + [6, "[6]"], + [7, "[7]"], + [8, "[8]"], + [9, "[9]"], + [10, "[10]"], + [11, "[11]"], + [12, "[12]"], + [13, "[13]"], + [14, "[14]"], + [15, "[15]"], + [16, "[16]"], + [17, "[17]"], + ], + ) + assert exec(db,"SELECT * FROM v") == snapshot() + + BASE_KNN = "select rowid, distance from v where embedding match ? and k = ? 
class Row:
    """Minimal attribute bag whose ``repr`` shows its instance attributes.

    Nothing in the visible test module instantiates this class; it is kept so
    that any external user of the name keeps working.
    """

    def __init__(self):
        pass

    def __repr__(self) -> str:
        # The previous body called ``repr()`` with no argument, which always
        # raises ``TypeError``; render the instance attributes instead.
        return repr(self.__dict__)


def exec(db, sql, parameters=()):
    """Run *sql* on *db* and return the outcome in a snapshot-friendly shape.

    NOTE: this name shadows the ``exec`` builtin inside this module. It is
    kept unchanged because the tests call it by this name.

    Args:
        db: Object whose ``execute(sql, parameters)`` returns a cursor with
            ``fetchall()``. Each fetched row must provide ``keys()`` and
            lookup by key (for sqlite3 that means
            ``row_factory = sqlite3.Row``).
        sql: Statement text; echoed back under the ``"sql"`` key.
        parameters: Values bound to the statement's placeholders. Defaults to
            an empty, immutable tuple rather than a shared mutable list.

    Returns:
        On success, an ``OrderedDict`` with ``"sql"`` (the statement text) and
        ``"rows"`` (a list with one ``OrderedDict`` per row, mapping column
        name to value, in column order).
        If executing or fetching raises ``sqlite3.DatabaseError`` (which
        includes ``sqlite3.OperationalError``), a plain dict with ``"error"``
        (the exception class name) and ``"message"`` (``str`` of the
        exception).
    """
    try:
        rows = db.execute(sql, parameters).fetchall()
    except sqlite3.DatabaseError as e:
        # OperationalError is a subclass of DatabaseError, so this single
        # clause covers both exception types the original listed.
        return {
            "error": e.__class__.__name__,
            "message": str(e),
        }
    result = OrderedDict()
    result["sql"] = sql
    result["rows"] = [OrderedDict((k, row[k]) for k in row.keys()) for row in rows]
    return result


def vec0_shadow_table_contents(db, v):
    """Dump every ``sqlite_master`` entry named ``<v>_*`` except ``*_info``.

    Args:
        db: Connection passed to ``db.execute`` and to :func:`exec`.
        v: Name prefix; entries whose name starts with ``v`` followed by a
            literal underscore are selected.

    Returns:
        A dict mapping each selected name (sorted by name) to the result of
        ``exec(db, "select * from <name>")``. Names ending in ``"_info"`` are
        skipped.
    """
    # "_" and "%" are wildcards in LIKE. Escape them so that only names that
    # really begin with "<v>_" match (previously "v_%" also matched "vxfoo").
    escaped = v.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    shadow_tables = [
        row[0]
        for row in db.execute(
            "select name from sqlite_master where name like ? escape '\\' order by 1",
            [escaped + "\\_%"],
        ).fetchall()
    ]
    return {
        shadow_table: exec(db, f"select * from {shadow_table}")
        for shadow_table in shadow_tables
        if not shadow_table.endswith("_info")
    }