vec0 point and knn error handling

This commit is contained in:
Alex Garcia 2024-06-28 15:29:13 -07:00
parent 2fdd760dd1
commit a5525c9a5d
2 changed files with 366 additions and 131 deletions

View file

@ -16,6 +16,8 @@ EXT_PATH = "./dist/vec0"
SUPPORTS_SUBTYPE = sqlite3.sqlite_version_info[1] > 38
SUPPORTS_DROP_COLUMN = sqlite3.sqlite_version_info[1] >= 35
SUPPORTS_VTAB_IN = sqlite3.sqlite_version_info[1] >= 38
SUPPORTS_VTAB_LIMIT = sqlite3.sqlite_version_info[1] >= 41
def bitmap_full(n: int) -> bytearray:
@ -1133,38 +1135,138 @@ def test_vec0_updates():
# ]
def test_vec0_point():
db = connect(EXT_PATH)
db.execute("CREATE VIRTUAL TABLE t USING vec0(a float[1], b float[1])")
db.execute(
"INSERT INTO t VALUES (1, X'AABBCCDD', X'00112233'), (2, X'AABBCCDD', X'99887766');"
)
assert execute_all(db, "select * from t where rowid = 1") == [
{
"a": b"\xaa\xbb\xcc\xdd",
"b": b'\x00\x11"3',
"rowid": 1,
}
]
assert execute_all(db, "select * from t where rowid = 999") == []
db.execute(
"CREATE VIRTUAL TABLE t2 USING vec0(id text primary key, a float[1], b float[1])"
)
db.execute(
"INSERT INTO t2 VALUES ('A', X'AABBCCDD', X'00112233'), ('B', X'AABBCCDD', X'99887766');"
)
assert execute_all(db, "select * from t2 where id = 'A'") == [
{
"a": b"\xaa\xbb\xcc\xdd",
"b": b'\x00\x11"3',
"id": "A",
}
]
assert execute_all(db, "select * from t2 where id = 'xxx'") == []
def test_vec0_text_pk():
db = connect(EXT_PATH)
db.execute(
"""
create virtual table t using vec0(
t_id text primary key,
aaa float[8],
bbb float8[8]
aaa float[1],
bbb float8[1]
);
"""
)
assert execute_all(db, "select * from t") == []
with _raises(
"The t virtual table was declared with a TEXT primary key, but a non-TEXT value was provided in an INSERT."
):
db.execute("INSERT INTO t VALUES (1, X'AABBCCDD', X'AABBCCDD')")
db.executemany(
"INSERT INTO t VALUES (:t_id, :aaa, :bbb)",
[
{
"t_id": "t_1",
"aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]",
"bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]",
"aaa": "[.1]",
"bbb": "[-.1]",
},
{
"t_id": "t_2",
"aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]",
"bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]",
"aaa": "[.2]",
"bbb": "[-.2]",
},
{
"t_id": "t_3",
"aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]",
"bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]",
"aaa": "[.3]",
"bbb": "[-.3]",
},
],
)
assert execute_all(db, "select * from t") == []
assert execute_all(db, "select t_id from t") == [
{"t_id": "t_1"},
{"t_id": "t_2"},
{"t_id": "t_3"},
]
assert execute_all(db, "select * from t") == [
{"t_id": "t_1", "aaa": _f32([0.1]), "bbb": _f32([-0.1])},
{"t_id": "t_2", "aaa": _f32([0.2]), "bbb": _f32([-0.2])},
{"t_id": "t_3", "aaa": _f32([0.3]), "bbb": _f32([-0.3])},
]
# EVIDENCE-OF: V09901_26739 vec0 full scan catches _rowid prep error
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_READ, "t_rowids", "rowid"))
with _raises(
"Error preparing rowid scan: access to t_rowids.rowid is prohibited",
sqlite3.DatabaseError,
):
db.execute("select * from t")
db.set_authorizer(None)
def test_vec0_best_index():
db = connect(EXT_PATH)
db.execute(
"""
create virtual table t using vec0(
aaa float[1],
bbb float8[1]
);
"""
)
with _raises("only 1 MATCH operator is allowed in a single vec0 query"):
db.execute("select * from t where aaa match NULL and bbb match NULL")
if SUPPORTS_VTAB_IN:
with _raises(
"only 1 'rowid in (..)' operator is allowed in a single vec0 query"
):
db.execute("select * from t where rowid in(4,5,6) and rowid in (1, 2,3)")
with _raises("A LIMIT or 'k = ?' constraint is required on vec0 knn queries."):
db.execute("select * from t where aaa MATCH ?")
with _raises("Only LIMIT or 'k =?' can be provided, not both"):
db.execute("select * from t where aaa MATCH ? and k = 10 limit 20")
with _raises(
"Only a single 'ORDER BY distance' clause is allowed on vec0 KNN queries"
):
db.execute(
"select * from t where aaa MATCH NULL and k = 10 order by distance, distance"
)
with _raises(
"Only ascending in ORDER BY distance clause is supported, DESC is not supported yet."
):
db.execute(
"select * from t where aaa MATCH NULL and k = 10 order by distance desc"
)
def authorizer_deny_on(operation, x1, x2=None):
@ -1610,6 +1712,13 @@ def test_smoke():
"select * from vec_xyz where a match X'' and k = 10 order by distance"
),
)
if SUPPORTS_VTAB_LIMIT:
assert re.match(
"SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 0:knn:",
explain_query_plan(
"select * from vec_xyz where a match X'' order by distance limit 10"
),
)
assert re.match(
"SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 0:fullscan",
explain_query_plan("select * from vec_xyz"),