mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
vtab_in handling
This commit is contained in:
parent
0db2e52974
commit
7b67c78530
6 changed files with 646 additions and 9 deletions
|
|
@ -3806,3 +3806,260 @@
|
|||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in[allow-int-all]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'n': 999,
|
||||
'distance': 1.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
'n': 555,
|
||||
'distance': 2.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'n': 999,
|
||||
'distance': 3.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'n': 555,
|
||||
'distance': 4.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
'n': 999,
|
||||
'distance': 5.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'n': 555,
|
||||
'distance': 6.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 7,
|
||||
'n': 999,
|
||||
'distance': 7.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 8,
|
||||
'n': 555,
|
||||
'distance': 8.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in[allow-int-superfluous]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
'n': 555,
|
||||
'distance': 2.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
'n': 555,
|
||||
'distance': 4.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
'n': 555,
|
||||
'distance': 6.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 8,
|
||||
'n': 555,
|
||||
'distance': 8.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in[allow-text-all]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
't': 'aaaa',
|
||||
'distance': 1.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
't': 'aaaa',
|
||||
'distance': 2.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
't': 'aaaa',
|
||||
'distance': 3.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
't': 'aaaa',
|
||||
'distance': 4.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
't': 'zzzz',
|
||||
'distance': 5.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
't': 'zzzz',
|
||||
'distance': 6.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 7,
|
||||
't': 'zzzz',
|
||||
'distance': 7.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 8,
|
||||
't': 'zzzz',
|
||||
'distance': 8.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in[allow-text-superfluous]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
't': 'aaaa',
|
||||
'distance': 1.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
't': 'aaaa',
|
||||
'distance': 2.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
't': 'aaaa',
|
||||
'distance': 3.0,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
't': 'aaaa',
|
||||
'distance': 4.0,
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in[block-bool]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.",
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in[block-float]
|
||||
dict({
|
||||
'error': 'OperationalError',
|
||||
'message': "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.",
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in_long_text[all]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
't': 'aaaa',
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
't': 'aaaaaaaaaaaa_aaa',
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
't': 'bbbb',
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
't': 'bbbbbbbbbbbb_bbb',
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
't': 'cccc',
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
't': 'cccccccccccc_ccc',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in_long_text[individual-aaaa]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
't': 'aaaa',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in_long_text[individual-aaaaaaaaaaaa_aaa]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 2,
|
||||
't': 'aaaaaaaaaaaa_aaa',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in_long_text[individual-bbbb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
't': 'bbbb',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in_long_text[individual-bbbbbbbbbbbb_bbb]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 4,
|
||||
't': 'bbbbbbbbbbbb_bbb',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in_long_text[individual-cccc]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 5,
|
||||
't': 'cccc',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_vtab_in_long_text[individual-cccccccccccc_ccc]
|
||||
OrderedDict({
|
||||
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 6,
|
||||
't': 'cccccccccccc_ccc',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import pytest
|
||||
import sqlite3
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
|
||||
def test_constructor_limit(db, snapshot):
|
||||
|
|
@ -284,6 +286,87 @@ def test_knn(db, snapshot):
|
|||
)
|
||||
|
||||
|
||||
# True when the linked SQLite library is at least 3.38.
# Compare the whole version tuple instead of only the minor component so the
# check does not depend on the major version happening to be 3.
SUPPORTS_VTAB_IN = sqlite3.sqlite_version_info >= (3, 38)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not SUPPORTS_VTAB_IN, reason="requires vtab `x in (...)` support in SQLite >=3.38"
)
def test_vtab_in(db, snapshot):
    """Snapshot `col in (...)` constraints on vec0 metadata columns in KNN queries.

    The boolean and float columns are recorded under the ``block-*`` snapshot
    names, the integer and text columns under the ``allow-*`` names.
    """
    db.execute(
        "create virtual table v using vec0(vector float[1], n int, t text, b boolean, f float, chunk_size=8)"
    )

    # Eight rows: `n` alternates 999/555 (odd/even rowid), `t` is "aaaa" for
    # rowids 1-4 and "zzzz" for 5-8, while `b` and `f` are constant.
    fixture_rows = [
        (
            rowid,
            f"[{rowid}]",
            999 if rowid % 2 else 555,
            "aaaa" if rowid <= 4 else "zzzz",
            0,
            1.1,
        )
        for rowid in range(1, 9)
    ]
    db.executemany(
        "insert into v(rowid, vector, n, t, b, f) values (?, ?, ?, ?, ?, ?)",
        fixture_rows,
    )

    # (snapshot name, query) pairs, evaluated in this order.
    cases = [
        (
            "block-bool",
            "select * from v where vector match '[0]' and k = 8 and b in (1, 0)",
        ),
        (
            "block-float",
            "select * from v where vector match '[0]' and k = 8 and f in (1.1, 0.0)",
        ),
        (
            "allow-int-all",
            "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)",
        ),
        (
            "allow-int-superfluous",
            "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)",
        ),
        (
            "allow-text-all",
            "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')",
        ),
        (
            "allow-text-superfluous",
            "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')",
        ),
    ]
    for case_name, sql in cases:
        outcome = exec(db, sql)
        assert outcome == snapshot(name=case_name)
|
||||
|
||||
|
||||
def test_vtab_in_long_text(db, snapshot):
    """Snapshot `t in (...)` lookups over short and longer text metadata values.

    Each stored text value is looked up on its own (paired with a value that
    matches nothing), then all values are looked up at once through
    ``json_each``.
    """
    db.execute(
        "create virtual table v using vec0(vector float[1], t text, chunk_size=8)"
    )

    data = [
        (1, "aaaa"),
        (2, "aaaaaaaaaaaa_aaa"),
        (3, "bbbb"),
        (4, "bbbbbbbbbbbb_bbb"),
        (5, "cccc"),
        (6, "cccccccccccc_ccc"),
    ]

    # NOTE: despite its name, the `:vector` placeholder carries the text that
    # is stored in column `t`; the actual vector is derived from `:rowid`.
    insert_params = [{"rowid": rowid, "vector": text} for rowid, text in data]
    db.executemany(
        "insert into v(rowid, vector, t) values (:rowid, printf('[%d]', :rowid), :vector)",
        insert_params,
    )

    single_lookup_sql = "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')"
    for _, lookup in data:
        found = exec(db, single_lookup_sql, [lookup])
        assert found == snapshot(name=f"individual-{lookup}")

    all_texts_json = json.dumps([text for _, text in data])
    found_all = exec(
        db,
        "select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))",
        [all_texts_json],
    )
    assert found_all == snapshot(name="all")
|
||||
|
||||
|
||||
def test_idxstr(db, snapshot):
|
||||
db.execute(
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue