vtab_in handling

This commit is contained in:
Alex Garcia 2024-11-18 22:43:24 -08:00
parent 0db2e52974
commit 7b67c78530
6 changed files with 646 additions and 9 deletions

View file

@ -3806,3 +3806,260 @@
}),
})
# ---
# name: test_vtab_in[allow-int-all]
OrderedDict({
'sql': "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)",
'rows': list([
OrderedDict({
'rowid': 1,
'n': 999,
'distance': 1.0,
}),
OrderedDict({
'rowid': 2,
'n': 555,
'distance': 2.0,
}),
OrderedDict({
'rowid': 3,
'n': 999,
'distance': 3.0,
}),
OrderedDict({
'rowid': 4,
'n': 555,
'distance': 4.0,
}),
OrderedDict({
'rowid': 5,
'n': 999,
'distance': 5.0,
}),
OrderedDict({
'rowid': 6,
'n': 555,
'distance': 6.0,
}),
OrderedDict({
'rowid': 7,
'n': 999,
'distance': 7.0,
}),
OrderedDict({
'rowid': 8,
'n': 555,
'distance': 8.0,
}),
]),
})
# ---
# name: test_vtab_in[allow-int-superfluous]
OrderedDict({
'sql': "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)",
'rows': list([
OrderedDict({
'rowid': 2,
'n': 555,
'distance': 2.0,
}),
OrderedDict({
'rowid': 4,
'n': 555,
'distance': 4.0,
}),
OrderedDict({
'rowid': 6,
'n': 555,
'distance': 6.0,
}),
OrderedDict({
'rowid': 8,
'n': 555,
'distance': 8.0,
}),
]),
})
# ---
# name: test_vtab_in[allow-text-all]
OrderedDict({
'sql': "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')",
'rows': list([
OrderedDict({
'rowid': 1,
't': 'aaaa',
'distance': 1.0,
}),
OrderedDict({
'rowid': 2,
't': 'aaaa',
'distance': 2.0,
}),
OrderedDict({
'rowid': 3,
't': 'aaaa',
'distance': 3.0,
}),
OrderedDict({
'rowid': 4,
't': 'aaaa',
'distance': 4.0,
}),
OrderedDict({
'rowid': 5,
't': 'zzzz',
'distance': 5.0,
}),
OrderedDict({
'rowid': 6,
't': 'zzzz',
'distance': 6.0,
}),
OrderedDict({
'rowid': 7,
't': 'zzzz',
'distance': 7.0,
}),
OrderedDict({
'rowid': 8,
't': 'zzzz',
'distance': 8.0,
}),
]),
})
# ---
# name: test_vtab_in[allow-text-superfluous]
OrderedDict({
'sql': "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')",
'rows': list([
OrderedDict({
'rowid': 1,
't': 'aaaa',
'distance': 1.0,
}),
OrderedDict({
'rowid': 2,
't': 'aaaa',
'distance': 2.0,
}),
OrderedDict({
'rowid': 3,
't': 'aaaa',
'distance': 3.0,
}),
OrderedDict({
'rowid': 4,
't': 'aaaa',
'distance': 4.0,
}),
]),
})
# ---
# name: test_vtab_in[block-bool]
dict({
'error': 'OperationalError',
'message': "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.",
})
# ---
# name: test_vtab_in[block-float]
dict({
'error': 'OperationalError',
'message': "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.",
})
# ---
# name: test_vtab_in_long_text[all]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))",
'rows': list([
OrderedDict({
'rowid': 1,
't': 'aaaa',
}),
OrderedDict({
'rowid': 2,
't': 'aaaaaaaaaaaa_aaa',
}),
OrderedDict({
'rowid': 3,
't': 'bbbb',
}),
OrderedDict({
'rowid': 4,
't': 'bbbbbbbbbbbb_bbb',
}),
OrderedDict({
'rowid': 5,
't': 'cccc',
}),
OrderedDict({
'rowid': 6,
't': 'cccccccccccc_ccc',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-aaaa]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 1,
't': 'aaaa',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-aaaaaaaaaaaa_aaa]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 2,
't': 'aaaaaaaaaaaa_aaa',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-bbbb]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 3,
't': 'bbbb',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-bbbbbbbbbbbb_bbb]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 4,
't': 'bbbbbbbbbbbb_bbb',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-cccc]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 5,
't': 'cccc',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-cccccccccccc_ccc]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 6,
't': 'cccccccccccc_ccc',
}),
]),
})
# ---

View file

@ -1,5 +1,7 @@
import pytest
import sqlite3
from collections import OrderedDict
import json
def test_constructor_limit(db, snapshot):
@ -284,6 +286,87 @@ def test_knn(db, snapshot):
)
# True when the linked SQLite library is new enough (>= 3.38) to hand
# `x in (...)` constraints to virtual tables. The full version tuple is
# compared rather than only the minor component, so a different major
# version is never misclassified.
SUPPORTS_VTAB_IN = sqlite3.sqlite_version_info >= (3, 38, 0)
@pytest.mark.skipif(
    not SUPPORTS_VTAB_IN, reason="requires vtab `x in (...)` support in SQLite >=3.38"
)
def test_vtab_in(db, snapshot):
    """Snapshot KNN queries that filter metadata columns with `in (...)`.

    Builds a vec0 table with one metadata column per type (int, text,
    boolean, float), inserts eight rows, and then records the outcome of an
    `in (...)` filter on each column. The boolean and float queries are
    stored under "block-*" snapshot names, the int and text queries under
    "allow-*" names.
    """
    db.execute(
        "create virtual table v using vec0(vector float[1], n int, t text, b boolean, f float, chunk_size=8)"
    )

    # n alternates 999/555; t is "aaaa" for rowids 1-4 and "zzzz" for 5-8;
    # b and f are constant across all rows.
    fixture_rows = [
        (1, "[1]", 999, "aaaa", 0, 1.1),
        (2, "[2]", 555, "aaaa", 0, 1.1),
        (3, "[3]", 999, "aaaa", 0, 1.1),
        (4, "[4]", 555, "aaaa", 0, 1.1),
        (5, "[5]", 999, "zzzz", 0, 1.1),
        (6, "[6]", 555, "zzzz", 0, 1.1),
        (7, "[7]", 999, "zzzz", 0, 1.1),
        (8, "[8]", 555, "zzzz", 0, 1.1),
    ]
    db.executemany(
        "insert into v(rowid, vector, n, t, b, f) values (?, ?, ?, ?, ?, ?)",
        fixture_rows,
    )

    # (snapshot name, query) pairs, checked in this order.
    cases = (
        (
            "block-bool",
            "select * from v where vector match '[0]' and k = 8 and b in (1, 0)",
        ),
        (
            "block-float",
            "select * from v where vector match '[0]' and k = 8 and f in (1.1, 0.0)",
        ),
        (
            "allow-int-all",
            "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)",
        ),
        (
            "allow-int-superfluous",
            "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)",
        ),
        (
            "allow-text-all",
            "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')",
        ),
        (
            "allow-text-superfluous",
            "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')",
        ),
    )
    for label, sql in cases:
        assert exec(db, sql) == snapshot(name=label)
def test_vtab_in_long_text(db, snapshot):
    """Snapshot `t in (...)` lookups over 4- and 16-character text values.

    Each stored text is first looked up on its own (paired with a value that
    matches nothing), then all of them are looked up at once through a
    `json_each` subquery.
    """
    db.execute(
        "create virtual table v using vec0(vector float[1], t text, chunk_size=8)"
    )

    # Values alternate between 4-character and 16-character strings sharing
    # a common prefix.
    # NOTE(review): presumably meant to exercise short vs. long text storage
    # in vec0 - confirm against the implementation.
    data = [
        (1, "aaaa"),
        (2, "aaaaaaaaaaaa_aaa"),
        (3, "bbbb"),
        (4, "bbbbbbbbbbbb_bbb"),
        (5, "cccc"),
        (6, "cccccccccccc_ccc"),
    ]
    texts = [text for _, text in data]

    # The `:vector` parameter carries the text stored in column `t`; the
    # vector itself is derived from the rowid via printf.
    insert_params = []
    for rowid, text in data:
        insert_params.append({"rowid": rowid, "vector": text})
    db.executemany(
        "insert into v(rowid, vector, t) values (:rowid, printf('[%d]', :rowid), :vector)",
        insert_params,
    )

    individual_sql = "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')"
    for lookup in texts:
        result = exec(db, individual_sql, [lookup])
        assert result == snapshot(name=f"individual-{lookup}")

    all_sql = "select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))"
    assert exec(db, all_sql, [json.dumps(texts)]) == snapshot(name="all")
def test_idxstr(db, snapshot):
db.execute(
"""