Metadata filtering (#124)

* initial pass at PARTITION KEY support.

* Initial pass, allow auxiliary columns on vec0 virtual tables

* update TODO

* Initial pass at metadata filtering

* unit tests

* gha this PR branch

* fixup tests

* doc internal

* fix tests, KNN/rowids in

* define SQLITE_INDEX_CONSTRAINT_OFFSET

* whoops

* update tests, syrupy, use uv

* un ignore pyproject.toml

* dot

* tests/

* type error?

* win: .exe, update error name

* try fix macos python, paren around expr?

* win bash?

* dbg :(

* explicit error

* op

* dbg win

* win ./tests/.venv/Scripts/python.exe

* block UPDATEs on partition key values for now

* test this branch

* accidentally removed "partition key type mismatch" block during merge

* typo ugh

* bruv

* start aux snapshots

* drop aux shadow table on destroy

* enforce column types

* block WHERE constraints on auxiliary columns in KNN queries

* support delete

* support UPDATE on auxiliary columns

* test this PR

* dont inline that

* test-metadata.py

* memzero text buffer

* stress test

* more snpashot tests

* rm double/int32, just float/int64

* finish type checking

* long text support

* DELETE support

* UPDATE support

* fix snapshot names

* drop not-used in eqp

* small fixes

* boolean comparison handling

* ensure error is raised when long string constraint

* new version string for beta builds

* typo whoops

* ann-filtering-benchmark directory

* test-case

* updates

* fix aux column error when using non-default rowid values, needs test

* refactor some text knn filtering

* rowids blob read only on text metadata filters

* refactor

* add failing test causes for non eq text knn

* text knn NE

* test cases diff

* GT

* text knn GT/GE fixes

* text knn LT/LE

* clean

* vtab_in handling

* unblock aux failures for now

* guard sqlite3_vtab_in

* else in guard?

* fixes and tests

* add broken shadow table test

* rename _metadata_chunksNN shadow table to _metadatachunksNN, for proper shadowName detection

* _metadata_text_NN shadow tables to _metadatatextNN

* SQLITE_VEC_VERSION_MAJOR SQLITE_VEC_VERSION_MINOR and SQLITE_VEC_VERSION_PATCH in sqlite-vec.h

* _info shadow table

* forgot to update aux snapshot?

* fix aux tests
This commit is contained in:
Alex Garcia 2024-11-20 00:59:34 -08:00 committed by GitHub
parent 9bfeaa7842
commit 352f953fc0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 7361 additions and 105 deletions

View file

@ -316,7 +316,7 @@
'type': 'table',
'name': 'sqlite_sequence',
'tbl_name': 'sqlite_sequence',
'rootpage': 3,
'rootpage': 5,
'sql': 'CREATE TABLE sqlite_sequence(name,seq)',
}),
]),
@ -326,18 +326,25 @@
OrderedDict({
'sql': 'select * from sqlite_master order by name',
'rows': list([
OrderedDict({
'type': 'index',
'name': 'sqlite_autoindex_v_info_1',
'tbl_name': 'v_info',
'rootpage': 3,
'sql': None,
}),
OrderedDict({
'type': 'index',
'name': 'sqlite_autoindex_v_vector_chunks00_1',
'tbl_name': 'v_vector_chunks00',
'rootpage': 6,
'rootpage': 8,
'sql': None,
}),
OrderedDict({
'type': 'table',
'name': 'sqlite_sequence',
'tbl_name': 'sqlite_sequence',
'rootpage': 3,
'rootpage': 5,
'sql': 'CREATE TABLE sqlite_sequence(name,seq)',
}),
OrderedDict({
@ -351,28 +358,35 @@
'type': 'table',
'name': 'v_auxiliary',
'tbl_name': 'v_auxiliary',
'rootpage': 7,
'rootpage': 9,
'sql': 'CREATE TABLE "v_auxiliary"( rowid integer PRIMARY KEY , value00)',
}),
OrderedDict({
'type': 'table',
'name': 'v_chunks',
'tbl_name': 'v_chunks',
'rootpage': 2,
'rootpage': 4,
'sql': 'CREATE TABLE "v_chunks"(chunk_id INTEGER PRIMARY KEY AUTOINCREMENT,size INTEGER NOT NULL,validity BLOB NOT NULL,rowids BLOB NOT NULL)',
}),
OrderedDict({
'type': 'table',
'name': 'v_info',
'tbl_name': 'v_info',
'rootpage': 2,
'sql': 'CREATE TABLE "v_info" (key text primary key, value any)',
}),
OrderedDict({
'type': 'table',
'name': 'v_rowids',
'tbl_name': 'v_rowids',
'rootpage': 4,
'rootpage': 6,
'sql': 'CREATE TABLE "v_rowids"(rowid INTEGER PRIMARY KEY AUTOINCREMENT,id,chunk_id INTEGER,chunk_offset INTEGER)',
}),
OrderedDict({
'type': 'table',
'name': 'v_vector_chunks00',
'tbl_name': 'v_vector_chunks00',
'rootpage': 5,
'rootpage': 7,
'sql': 'CREATE TABLE "v_vector_chunks00"(rowid PRIMARY KEY,vectors BLOB NOT NULL)',
}),
]),
@ -409,25 +423,25 @@
# ---
# name: test_types.3
dict({
'error': 'OperationalError',
'error': 'IntegrityError',
'message': 'Auxiliary column type mismatch: The auxiliary column aux_int has type INTEGER, but TEXT was provided.',
})
# ---
# name: test_types.4
dict({
'error': 'OperationalError',
'error': 'IntegrityError',
'message': 'Auxiliary column type mismatch: The auxiliary column aux_float has type FLOAT, but TEXT was provided.',
})
# ---
# name: test_types.5
dict({
'error': 'OperationalError',
'error': 'IntegrityError',
'message': 'Auxiliary column type mismatch: The auxiliary column aux_text has type TEXT, but INTEGER was provided.',
})
# ---
# name: test_types.6
dict({
'error': 'OperationalError',
'error': 'IntegrityError',
'message': 'Auxiliary column type mismatch: The auxiliary column aux_blob has type BLOB, but INTEGER was provided.',
})
# ---

View file

@ -0,0 +1,184 @@
# serializer version: 1
# name: test_info
OrderedDict({
'sql': 'select key, typeof(value) from v_info order by 1',
'rows': list([
OrderedDict({
'key': 'CREATE_VERSION',
'typeof(value)': 'text',
}),
OrderedDict({
'key': 'CREATE_VERSION_MAJOR',
'typeof(value)': 'integer',
}),
OrderedDict({
'key': 'CREATE_VERSION_MINOR',
'typeof(value)': 'integer',
}),
OrderedDict({
'key': 'CREATE_VERSION_PATCH',
'typeof(value)': 'integer',
}),
]),
})
# ---
# name: test_shadow
OrderedDict({
'sql': 'select * from sqlite_master order by name',
'rows': list([
OrderedDict({
'type': 'index',
'name': 'sqlite_autoindex_v_info_1',
'tbl_name': 'v_info',
'rootpage': 3,
'sql': None,
}),
OrderedDict({
'type': 'index',
'name': 'sqlite_autoindex_v_metadatachunks00_1',
'tbl_name': 'v_metadatachunks00',
'rootpage': 10,
'sql': None,
}),
OrderedDict({
'type': 'index',
'name': 'sqlite_autoindex_v_metadatatext00_1',
'tbl_name': 'v_metadatatext00',
'rootpage': 12,
'sql': None,
}),
OrderedDict({
'type': 'index',
'name': 'sqlite_autoindex_v_vector_chunks00_1',
'tbl_name': 'v_vector_chunks00',
'rootpage': 8,
'sql': None,
}),
OrderedDict({
'type': 'table',
'name': 'sqlite_sequence',
'tbl_name': 'sqlite_sequence',
'rootpage': 5,
'sql': 'CREATE TABLE sqlite_sequence(name,seq)',
}),
OrderedDict({
'type': 'table',
'name': 'v',
'tbl_name': 'v',
'rootpage': 0,
'sql': 'CREATE VIRTUAL TABLE v using vec0(a float[1], partition text partition key, metadata text, +name text, chunk_size=8)',
}),
OrderedDict({
'type': 'table',
'name': 'v_auxiliary',
'tbl_name': 'v_auxiliary',
'rootpage': 13,
'sql': 'CREATE TABLE "v_auxiliary"( rowid integer PRIMARY KEY , value00)',
}),
OrderedDict({
'type': 'table',
'name': 'v_chunks',
'tbl_name': 'v_chunks',
'rootpage': 4,
'sql': 'CREATE TABLE "v_chunks"(chunk_id INTEGER PRIMARY KEY AUTOINCREMENT,size INTEGER NOT NULL,sequence_id integer,partition00,validity BLOB NOT NULL, rowids BLOB NOT NULL)',
}),
OrderedDict({
'type': 'table',
'name': 'v_info',
'tbl_name': 'v_info',
'rootpage': 2,
'sql': 'CREATE TABLE "v_info" (key text primary key, value any)',
}),
OrderedDict({
'type': 'table',
'name': 'v_metadatachunks00',
'tbl_name': 'v_metadatachunks00',
'rootpage': 9,
'sql': 'CREATE TABLE "v_metadatachunks00"(rowid PRIMARY KEY, data BLOB NOT NULL)',
}),
OrderedDict({
'type': 'table',
'name': 'v_metadatatext00',
'tbl_name': 'v_metadatatext00',
'rootpage': 11,
'sql': 'CREATE TABLE "v_metadatatext00"(rowid PRIMARY KEY, data TEXT)',
}),
OrderedDict({
'type': 'table',
'name': 'v_rowids',
'tbl_name': 'v_rowids',
'rootpage': 6,
'sql': 'CREATE TABLE "v_rowids"(rowid INTEGER PRIMARY KEY AUTOINCREMENT,id,chunk_id INTEGER,chunk_offset INTEGER)',
}),
OrderedDict({
'type': 'table',
'name': 'v_vector_chunks00',
'tbl_name': 'v_vector_chunks00',
'rootpage': 7,
'sql': 'CREATE TABLE "v_vector_chunks00"(rowid PRIMARY KEY,vectors BLOB NOT NULL)',
}),
]),
})
# ---
# name: test_shadow.1
OrderedDict({
'sql': "select * from pragma_table_list where type = 'shadow'",
'rows': list([
OrderedDict({
'schema': 'main',
'name': 'v_auxiliary',
'type': 'shadow',
'ncol': 2,
'wr': 0,
'strict': 0,
}),
OrderedDict({
'schema': 'main',
'name': 'v_chunks',
'type': 'shadow',
'ncol': 6,
'wr': 0,
'strict': 0,
}),
OrderedDict({
'schema': 'main',
'name': 'v_info',
'type': 'shadow',
'ncol': 2,
'wr': 0,
'strict': 0,
}),
OrderedDict({
'schema': 'main',
'name': 'v_rowids',
'type': 'shadow',
'ncol': 4,
'wr': 0,
'strict': 0,
}),
OrderedDict({
'schema': 'main',
'name': 'v_metadatachunks00',
'type': 'shadow',
'ncol': 2,
'wr': 0,
'strict': 0,
}),
OrderedDict({
'schema': 'main',
'name': 'v_metadatatext00',
'type': 'shadow',
'ncol': 2,
'wr': 0,
'strict': 0,
}),
]),
})
# ---
# name: test_shadow.2
OrderedDict({
'sql': "select * from pragma_table_list where type = 'shadow'",
'rows': list([
]),
})
# ---

File diff suppressed because it is too large Load diff

1
tests/afbd/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
*.tgz

View file

@ -0,0 +1 @@
3.12

9
tests/afbd/Makefile Normal file
View file

@ -0,0 +1,9 @@
random_ints_1m.tgz:
curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_ints_1m.tgz
random_float_1m.tgz:
curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_float_1m.tgz
random_keywords_1m.tgz:
curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_keywords_1m.tgz
all: random_ints_1m.tgz random_float_1m.tgz random_keywords_1m.tgz

12
tests/afbd/README.md Normal file
View file

@ -0,0 +1,12 @@
# hnm
```
tar -xOzf hnm.tgz ./tests.jsonl > tests.jsonl
solite q "select group_concat(distinct key) from lines_read('tests.jsonl'), json_each(line -> '$.conditions.and[0]')"
```
```
> python test-afbd.py build hnm.tgz --metadata product_group_name,colour_group_name,index_group_name,perceived_colour_value_name,section_name,product_type_name,department_name,graphical_appearance_name,garment_group_name,perceived_colour_master_name
```

231
tests/afbd/test-afbd.py Normal file
View file

@ -0,0 +1,231 @@
import numpy as np
from tqdm import tqdm
from deepdiff import DeepDiff
import tarfile
import json
from io import BytesIO
import sqlite3
from typing import List
from struct import pack
import time
from pathlib import Path
import argparse
def serialize_float32(vector: List[float]) -> bytes:
"""Serializes a list of floats into the "raw bytes" format sqlite-vec expects"""
return pack("%sf" % len(vector), *vector)
def build_command(file_path, metadata_set=None):
if metadata_set:
metadata_set = set(metadata_set.split(","))
file_path = Path(file_path)
print(f"reading {file_path}...")
t0 = time.time()
with tarfile.open(file_path, "r:gz") as archive:
for file in archive:
if file.name == "./payloads.jsonl":
payloads = [
json.loads(line)
for line in archive.extractfile(file.name).readlines()
]
if file.name == "./tests.jsonl":
tests = [
json.loads(line)
for line in archive.extractfile(file.name).readlines()
]
if file.name == "./vectors.npy":
f = BytesIO()
f.write(archive.extractfile(file.name).read())
f.seek(0)
vectors = np.load(f)
assert payloads is not None
assert tests is not None
assert vectors is not None
dimensions = vectors.shape[1]
metadata_columns = sorted(list(payloads[0].keys()))
def col_type(v):
if isinstance(v, int):
return "integer"
if isinstance(v, float):
return "float"
if isinstance(v, str):
return "text"
raise Exception(f"Unknown column type: {v}")
metadata_columns_types = [col_type(payloads[0][col]) for col in metadata_columns]
print(time.time() - t0)
t0 = time.time()
print("seeding...")
db = sqlite3.connect(f"{file_path.stem}.db")
db.execute("PRAGMA page_size = 16384")
db.row_factory = sqlite3.Row
db.enable_load_extension(True)
db.load_extension("../../dist/vec0")
db.enable_load_extension(False)
with db:
db.execute("create table tests(data)")
for test in tests:
db.execute("insert into tests values (?)", [json.dumps(test)])
with db:
create_sql = f"create virtual table v using vec0(vector float[{dimensions}] distance_metric=cosine"
insert_sql = "insert into v(rowid, vector"
for name, type in zip(metadata_columns, metadata_columns_types):
if metadata_set:
if name in metadata_set:
create_sql += f", {name} {type}"
else:
create_sql += f", +{name} {type}"
else:
create_sql += f", {name} {type}"
insert_sql += f", {name}"
create_sql += ")"
insert_sql += ") values (" + ",".join("?" * (2 + len(metadata_columns))) + ")"
print(create_sql)
print(insert_sql)
db.execute(create_sql)
for idx, (payload, vector) in enumerate(
tqdm(zip(payloads, vectors), total=len(payloads))
):
params = [idx, vector]
for c in metadata_columns:
params.append(payload[c])
db.execute(insert_sql, params)
print(time.time() - t0)
def tests_command(file_path):
file_path = Path(file_path)
db = sqlite3.connect(f"{file_path.stem}.db")
db.execute("PRAGMA cache_size = -100000000")
db.row_factory = sqlite3.Row
db.enable_load_extension(True)
db.load_extension("../../dist/vec0")
db.enable_load_extension(False)
tests = [
json.loads(row["data"])
for row in db.execute("select data from tests").fetchall()
]
num_or_skips = 0
num_1off_errors = 0
t0 = time.time()
print("testing...")
for idx, test in enumerate(tqdm(tests)):
query = test["query"]
conditions = test["conditions"]
expected_closest_ids = test["closest_ids"]
expected_closest_scores = test["closest_scores"]
sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
params = [serialize_float32(query), len(expected_closest_ids)]
if "and" in conditions:
for condition in conditions["and"]:
assert len(condition.keys()) == 1
column = list(condition.keys())[0]
assert len(list(condition[column].keys())) == 1
condition_type = list(condition[column].keys())[0]
if condition_type == "match":
value = condition[column]["match"]["value"]
sql += f" and {column} = ?"
params.append(value)
elif condition_type == "range":
sql += f" and {column} between ? and ?"
params.append(condition[column]["range"]["gt"])
params.append(condition[column]["range"]["lt"])
else:
raise Exception(f"Unknown condition type: {condition_type}")
elif "or" in conditions:
column = list(conditions["or"][0].keys())[0]
condition_type = list(conditions["or"][0][column].keys())[0]
assert condition_type == "match"
sql += f" and {column} in ("
for idx, condition in enumerate(conditions["or"]):
if condition_type == "match":
value = condition[column]["match"]["value"]
if idx != 0:
sql += ","
sql += "?"
params.append(value)
elif condition_type == "range":
breakpoint()
else:
raise Exception(f"Unknown condition type: {condition_type}")
sql += ")"
# print(sql, params[1:])
rows = db.execute(sql, params).fetchall()
actual_closest_ids = [row["rowid"] for row in rows]
matches = expected_closest_ids == actual_closest_ids
if not matches:
diff = DeepDiff(
expected_closest_ids, actual_closest_ids, ignore_order=False
)
assert len(list(diff.keys())) == 1
assert "values_changed" in diff.keys()
keys_changed = list(diff["values_changed"].keys())
if len(keys_changed) == 2:
akey, bkey = keys_changed
a = int(akey.lstrip("root[").rstrip("]"))
b = int(bkey.lstrip("root[").rstrip("]"))
assert abs(a - b) == 1
assert (
diff["values_changed"][akey]["new_value"]
== diff["values_changed"][bkey]["old_value"]
)
assert (
diff["values_changed"][akey]["old_value"]
== diff["values_changed"][bkey]["new_value"]
)
elif len(keys_changed) == 1:
v = int(keys_changed[0].lstrip("root[").rstrip("]"))
assert (v + 1) == len(expected_closest_ids)
else:
raise Exception("fuck")
num_1off_errors += 1
# print(closest_scores)
# print([row["similarity"] for row in rows])
# assert closest_scores == [row["similarity"] for row in rows]
print("Number skipped: ", num_or_skips)
print("Num 1 off errors: ", num_1off_errors)
print("1 off error rate: ", num_1off_errors / (len(tests) - num_or_skips))
print(time.time() - t0)
print("done")
def main():
parser = argparse.ArgumentParser(description="CLI tool")
subparsers = parser.add_subparsers(dest="command", required=True)
build_parser = subparsers.add_parser("build")
build_parser.add_argument("file", type=str, help="Path to input file")
build_parser.add_argument("--metadata", type=str, help="Metadata columns")
build_parser.set_defaults(func=lambda args: build_command(args.file, args.metadata))
tests_parser = subparsers.add_parser("test")
tests_parser.add_argument("file", type=str, help="Path to input file")
tests_parser.set_defaults(func=lambda args: tests_command(args.file))
args = parser.parse_args()
args.func(args)
if __name__ == "__main__":
main()

View file

@ -55,7 +55,10 @@ def test_types(db, snapshot):
)
assert exec(db, "select * from v") == snapshot()
# TODO: integrity test transaction failures in shadow tables
db.commit()
# bad types
db.execute("BEGIN")
assert (
exec(db, INSERT, [b"\x11\x11\x11\x11", "not int", 1.2, "text", b"blob"])
== snapshot()
@ -66,6 +69,7 @@ def test_types(db, snapshot):
)
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1.2, 1, b"blob"]) == snapshot()
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1.2, "text", 1]) == snapshot()
db.execute("ROLLBACK")
# NULLs are totally chill
assert exec(db, INSERT, [b"\x11\x11\x11\x11", None, None, None, None]) == snapshot()
@ -151,5 +155,7 @@ def vec0_shadow_table_contents(db, v):
]
o = {}
for shadow_table in shadow_tables:
if shadow_table.endswith("_info"):
continue
o[shadow_table] = exec(db, f"select * from {shadow_table}")
return o

60
tests/test-general.py Normal file
View file

@ -0,0 +1,60 @@
import sqlite3
from collections import OrderedDict
import pytest
@pytest.mark.skipif(
sqlite3.sqlite_version_info[1] < 37,
reason="pragma_table_list was added in SQLite 3.37",
)
def test_shadow(db, snapshot):
db.execute(
"create virtual table v using vec0(a float[1], partition text partition key, metadata text, +name text, chunk_size=8)"
)
assert exec(db, "select * from sqlite_master order by name") == snapshot()
assert (
exec(db, "select * from pragma_table_list where type = 'shadow'") == snapshot()
)
db.execute("drop table v;")
assert (
exec(db, "select * from pragma_table_list where type = 'shadow'") == snapshot()
)
def test_info(db, snapshot):
db.execute("create virtual table v using vec0(a float[1])")
assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot()
def exec(db, sql, parameters=[]):
try:
rows = db.execute(sql, parameters).fetchall()
except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
return {
"error": e.__class__.__name__,
"message": str(e),
}
a = []
for row in rows:
o = OrderedDict()
for k in row.keys():
o[k] = row[k]
a.append(o)
result = OrderedDict()
result["sql"] = sql
result["rows"] = a
return result
def vec0_shadow_table_contents(db, v):
shadow_tables = [
row[0]
for row in db.execute(
"select name from sqlite_master where name like ? order by 1", [f"{v}_%"]
).fetchall()
]
o = {}
for shadow_table in shadow_tables:
o[shadow_table] = exec(db, f"select * from {shadow_table}")
return o

View file

@ -1022,6 +1022,7 @@ def test_vec0_drops():
] == [
"t1",
"t1_chunks",
"t1_info",
"t1_rowids",
"t1_vector_chunks00",
"t1_vector_chunks01",
@ -2216,6 +2217,9 @@ def test_smoke():
{
"name": "vec_xyz_chunks",
},
{
"name": "vec_xyz_info",
},
{
"name": "vec_xyz_rowids",
},

629
tests/test-metadata.py Normal file
View file

@ -0,0 +1,629 @@
import pytest
import sqlite3
from collections import OrderedDict
import json
def test_constructor_limit(db, snapshot):
assert exec(
db,
f"""
create virtual table v using vec0(
{",".join([f"metadata{x} integer" for x in range(17)])}
v float[1]
)
""",
) == snapshot(name="max 16 metadata columns")
def test_normal(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], b boolean, n int, f float, t text, chunk_size=8)"
)
assert exec(
db, "select * from sqlite_master where type = 'table' order by name"
) == snapshot(name="sqlite_master")
assert vec0_shadow_table_contents(db, "v") == snapshot()
INSERT = "insert into v(vector, b, n, f, t) values (?, ?, ?, ?, ?)"
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1, 1.1, "one"]) == snapshot()
assert exec(db, INSERT, [b"\x22\x22\x22\x22", 1, 2, 2.2, "two"]) == snapshot()
assert exec(db, INSERT, [b"\x33\x33\x33\x33", 1, 3, 3.3, "three"]) == snapshot()
assert exec(db, "select * from v") == snapshot()
assert vec0_shadow_table_contents(db, "v") == snapshot()
assert exec(db, "drop table v") == snapshot()
assert exec(db, "select * from sqlite_master") == snapshot()
#
# assert exec(db, "select * from v") == snapshot()
# assert vec0_shadow_table_contents(db, "v") == snapshot()
#
# db.execute("drop table v;")
# assert exec(db, "select * from sqlite_master order by name") == snapshot(
# name="sqlite_master post drop"
# )
def test_text_knn(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], name text, chunk_size=8)"
)
assert vec0_shadow_table_contents(db, "v") == snapshot()
INSERT = "insert into v(vector, name) values (?, ?)"
db.execute(
"""
INSERT INTO v(vector, name) VALUES
('[.11]', 'aaa'),
('[.22]', 'bbb'),
('[.33]', 'ccc'),
('[.44]', 'ddd'),
('[.55]', 'eee'),
('[.66]', 'fff'),
('[.77]', 'ggg'),
('[.88]', 'hhh'),
('[.99]', 'iii');
"""
)
assert exec(db, "select * from v") == snapshot()
assert vec0_shadow_table_contents(db, "v") == snapshot()
assert (
exec(
db,
"select rowid, name, distance from v where vector match '[1]' and k = 5",
)
== snapshot()
)
assert (
exec(
db,
"select rowid, name, distance from v where vector match '[1]' and k = 5 and name < 'ddd'",
)
== snapshot()
)
assert (
exec(
db,
"select rowid, name, distance from v where vector match '[1]' and k = 5 and name <= 'ddd'",
)
== snapshot()
)
assert (
exec(
db,
"select rowid, name, distance from v where vector match '[1]' and k = 5 and name > 'fff'",
)
== snapshot()
)
assert (
exec(
db,
"select rowid, name, distance from v where vector match '[1]' and k = 5 and name >= 'fff'",
)
== snapshot()
)
assert (
exec(
db,
"select rowid, name, distance from v where vector match '[1]' and k = 5 and name = 'aaa'",
)
== snapshot()
)
assert (
exec(
db,
"select rowid, name, distance from v where vector match '[.01]' and k = 5 and name != 'aaa'",
)
== snapshot()
)
def test_long_text_updates(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], name text, chunk_size=8)"
)
assert vec0_shadow_table_contents(db, "v") == snapshot()
INSERT = "insert into v(vector, name) values (?, ?)"
exec(db, INSERT, [b"\x11\x11\x11\x11", "123456789a12"])
exec(db, INSERT, [b"\x11\x11\x11\x11", "123456789a123"])
assert exec(db, "select * from v") == snapshot()
assert vec0_shadow_table_contents(db, "v") == snapshot()
def test_long_text_knn(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], name text, chunk_size=8)"
)
INSERT = "insert into v(vector, name) values (?, ?)"
exec(db, INSERT, ["[1]", "aaaa"])
exec(db, INSERT, ["[2]", "aaaaaaaaaaaa_aaa"])
exec(db, INSERT, ["[3]", "bbbb"])
exec(db, INSERT, ["[4]", "bbbbbbbbbbbb_bbb"])
exec(db, INSERT, ["[5]", "cccc"])
exec(db, INSERT, ["[6]", "cccccccccccc_ccc"])
tests = [
"bbbb",
"bb",
"bbbbbb",
"bbbbbbbbbbbb_bbb",
"bbbbbbbbbbbb_aaa",
"bbbbbbbbbbbb_ccc",
"longlonglonglonglonglonglong",
]
ops = ["=", "!=", "<", "<=", ">", ">="]
op_names = ["eq", "ne", "lt", "le", "gt", "ge"]
for test in tests:
for op, op_name in zip(ops, op_names):
assert exec(
db,
f"select rowid, name, distance from v where vector match '[100]' and k = 5 and name {op} ?",
[test],
) == snapshot(name=f"{op_name}-{test}")
def test_types(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], b boolean, n int, f float, t text, chunk_size=8)"
)
INSERT = "insert into v(vector, b, n, f, t) values (?, ?, ?, ?, ?)"
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1, 1.1, "test"]) == snapshot(
name="legal"
)
# fmt: off
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 'illegal', 1, 1.1, 'test']) == snapshot(name="illegal-type-boolean")
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 'illegal', 1.1, 'test']) == snapshot(name="illegal-type-int")
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1, 'illegal', 'test']) == snapshot(name="illegal-type-float")
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1, 1.1, 420]) == snapshot(name="illegal-type-text")
# fmt: on
assert exec(db, INSERT, [b"\x11\x11\x11\x11", 44, 1, 1.1, "test"]) == snapshot(
name="illegal-boolean"
)
def test_updates(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], b boolean, n int, f float, t text, chunk_size=8)"
)
INSERT = "insert into v(rowid, vector, b, n, f, t) values (?, ?, ?, ?, ?, ?)"
exec(db, INSERT, [1, b"\x11\x11\x11\x11", 1, 1, 1.1, "test1"])
exec(db, INSERT, [2, b"\x22\x22\x22\x22", 1, 2, 2.2, "test2"])
exec(db, INSERT, [3, b"\x33\x33\x33\x33", 1, 3, 3.3, "1234567890123"])
assert exec(db, "select * from v") == snapshot(name="1-init-contents")
assert vec0_shadow_table_contents(db, "v") == snapshot(name="1-init-shadow")
assert exec(
db, "UPDATE v SET b = 0, n = 11, f = 11.11, t = 'newtest1' where rowid = 1"
)
assert exec(db, "select * from v") == snapshot(name="general-update-contents")
assert vec0_shadow_table_contents(db, "v") == snapshot(
name="general-update-shaodnw"
)
# string update #1: long string updated to long string
exec(db, "UPDATE v SET t = '1234567890123-updated' where rowid = 3")
assert exec(db, "select * from v") == snapshot(name="string-update-1-contents")
assert vec0_shadow_table_contents(db, "v") == snapshot(
name="string-update-1-shadow"
)
# string update #2: short string updated to short string
exec(db, "UPDATE v SET t = 'test2-short' where rowid = 2")
assert exec(db, "select * from v") == snapshot(name="string-update-2-contents")
assert vec0_shadow_table_contents(db, "v") == snapshot(
name="string-update-2-shadow"
)
# string update #3: short string updated to long string
exec(db, "UPDATE v SET t = 'test2-long-long-long' where rowid = 2")
assert exec(db, "select * from v") == snapshot(name="string-update-3-contents")
assert vec0_shadow_table_contents(db, "v") == snapshot(
name="string-update-3-shadow"
)
# string update #4: long string updated to short string
exec(db, "UPDATE v SET t = 'test2-shortx' where rowid = 2")
assert exec(db, "select * from v") == snapshot(name="string-update-4-contents")
assert vec0_shadow_table_contents(db, "v") == snapshot(
name="string-update-4-shadow"
)
def test_deletes(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], b boolean, n int, f float, t text, chunk_size=8)"
)
INSERT = "insert into v(rowid, vector, b, n, f, t) values (?, ?, ?, ?, ?, ?)"
assert exec(db, INSERT, [1, b"\x11\x11\x11\x11", 1, 1, 1.1, "test1"]) == snapshot()
assert exec(db, INSERT, [2, b"\x22\x22\x22\x22", 1, 2, 2.2, "test2"]) == snapshot()
assert (
exec(db, INSERT, [3, b"\x33\x33\x33\x33", 1, 3, 3.3, "1234567890123"])
== snapshot()
)
assert exec(db, "select * from v") == snapshot()
assert vec0_shadow_table_contents(db, "v") == snapshot()
assert exec(db, "DELETE FROM v where rowid = 1") == snapshot()
assert exec(db, "select * from v") == snapshot()
assert vec0_shadow_table_contents(db, "v") == snapshot()
assert exec(db, "DELETE FROM v where rowid = 3") == snapshot()
assert exec(db, "select * from v") == snapshot()
assert vec0_shadow_table_contents(db, "v") == snapshot()
def test_knn(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], name text, chunk_size=8)"
)
assert exec(
db, "select * from sqlite_master where type = 'table' order by name"
) == snapshot(name="sqlite_master")
db.executemany(
"insert into v(vector, name) values (?, ?)",
[("[1]", "alex"), ("[2]", "brian"), ("[3]", "craig")],
)
# EVIDENCE-OF: V16511_00582 catches "illegal" constraints on metadata columns
assert (
exec(
db,
"select *, distance from v where vector match '[5]' and k = 3 and name like 'illegal'",
)
== snapshot()
)
SUPPORTS_VTAB_IN = sqlite3.sqlite_version_info[1] >= 38
@pytest.mark.skipif(
not SUPPORTS_VTAB_IN, reason="requires vtab `x in (...)` support in SQLite >=3.38"
)
def test_vtab_in(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], n int, t text, b boolean, f float, chunk_size=8)"
)
db.executemany(
"insert into v(rowid, vector, n, t, b, f) values (?, ?, ?, ?, ?, ?)",
[
(1, "[1]", 999, "aaaa", 0, 1.1),
(2, "[2]", 555, "aaaa", 0, 1.1),
(3, "[3]", 999, "aaaa", 0, 1.1),
(4, "[4]", 555, "aaaa", 0, 1.1),
(5, "[5]", 999, "zzzz", 0, 1.1),
(6, "[6]", 555, "zzzz", 0, 1.1),
(7, "[7]", 999, "zzzz", 0, 1.1),
(8, "[8]", 555, "zzzz", 0, 1.1),
],
)
# EVIDENCE-OF: V15248_32086
assert exec(
db, "select * from v where vector match '[0]' and k = 8 and b in (1, 0)"
) == snapshot(name="block-bool")
assert exec(
db, "select * from v where vector match '[0]' and k = 8 and f in (1.1, 0.0)"
) == snapshot(name="block-float")
assert exec(
db,
"select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)",
) == snapshot(name="allow-int-all")
assert exec(
db,
"select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)",
) == snapshot(name="allow-int-superfluous")
assert exec(
db,
"select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')",
) == snapshot(name="allow-text-all")
assert exec(
db,
"select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')",
) == snapshot(name="allow-text-superfluous")
def test_vtab_in_long_text(db, snapshot):
db.execute(
"create virtual table v using vec0(vector float[1], t text, chunk_size=8)"
)
data = [
(1, "aaaa"),
(2, "aaaaaaaaaaaa_aaa"),
(3, "bbbb"),
(4, "bbbbbbbbbbbb_bbb"),
(5, "cccc"),
(6, "cccccccccccc_ccc"),
]
db.executemany(
"insert into v(rowid, vector, t) values (:rowid, printf('[%d]', :rowid), :vector)",
[{"rowid": row[0], "vector": row[1]} for row in data],
)
for _, lookup in data:
assert exec(
db,
"select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
[lookup],
) == snapshot(name=f"individual-{lookup}")
assert exec(
db,
"select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))",
[json.dumps([row[1] for row in data])],
) == snapshot(name="all")
def test_idxstr(db, snapshot):
db.execute(
"""
create virtual table vec_movies using vec0(
movie_id integer primary key,
synopsis_embedding float[1],
+title text,
is_favorited boolean,
genre text,
num_reviews int,
mean_rating float,
chunk_size=8
);
"""
)
assert (
eqp(
db,
"select * from vec_movies where synopsis_embedding match '' and k = 0 and is_favorited = true",
)
== snapshot()
)
ops = ["<", ">", "<=", ">=", "!="]
for op in ops:
assert eqp(
db,
f"select * from vec_movies where synopsis_embedding match '' and k = 0 and genre {op} NULL",
) == snapshot(name=f"knn-constraint-text {op}")
for op in ops:
assert eqp(
db,
f"select * from vec_movies where synopsis_embedding match '' and k = 0 and num_reviews {op} NULL",
) == snapshot(name=f"knn-constraint-int {op}")
for op in ops:
assert eqp(
db,
f"select * from vec_movies where synopsis_embedding match '' and k = 0 and mean_rating {op} NULL",
) == snapshot(name=f"knn-constraint-float {op}")
# for op in ops:
# assert eqp(
# db,
# f"select * from vec_movies where synopsis_embedding match '' and k = 0 and is_favorited {op} NULL",
# ) == snapshot(name=f"knn-constraint-boolean {op}")
def eqp(db, sql):
o = OrderedDict()
o["sql"] = sql
o["plan"] = [
dict(row) for row in db.execute(f"explain query plan {sql}").fetchall()
]
for p in o["plan"]:
# value is different on macos-aarch64 in github actions, not sure why
del p["notused"]
return o
def test_stress(db, snapshot):
    """End-to-end metadata filtering over a populated vec0 table.

    Inserts 25 movies with boolean, text, int, and float metadata columns
    (chunk_size=8 forces multiple chunks), snapshots the shadow-table
    contents, then snapshots KNN queries exercising each metadata type with
    several operators. Snapshot order/count is significant — unnamed
    snapshot() calls are matched positionally.
    """
    db.execute(
        """
      create virtual table vec_movies using vec0(
        movie_id integer primary key,
        synopsis_embedding float[1],
        +title text,
        is_favorited boolean,
        genre text,
        num_reviews int,
        mean_rating float,
        chunk_size=8
      );
    """
    )
    db.execute(
        """
      INSERT INTO vec_movies(movie_id, synopsis_embedding, is_favorited, genre, title, num_reviews, mean_rating)
      VALUES
        (1, '[1]', 0, 'horror', 'The Conjuring', 153, 4.6),
        (2, '[2]', 0, 'comedy', 'Dumb and Dumber', 382, 2.6),
        (3, '[3]', 0, 'scifi', 'Interstellar', 53, 5.0),
        (4, '[4]', 0, 'fantasy', 'The Lord of the Rings: The Fellowship of the Ring', 210, 4.2),
        (5, '[5]', 1, 'documentary', 'An Inconvenient Truth', 93, 3.4),
        (6, '[6]', 1, 'horror', 'Hereditary', 167, 4.7),
        (7, '[7]', 1, 'comedy', 'Anchorman: The Legend of Ron Burgundy', 482, 2.9),
        (8, '[8]', 0, 'scifi', 'Blade Runner 2049', 301, 5.0),
        (9, '[9]', 1, 'fantasy', 'Harry Potter and the Sorcerer''s Stone', 134, 4.1),
        (10, '[10]', 0, 'documentary', 'Free Solo', 66, 3.2),
        (11, '[11]', 1, 'horror', 'Get Out', 88, 4.9),
        (12, '[12]', 0, 'comedy', 'The Hangover', 59, 2.8),
        (13, '[13]', 1, 'scifi', 'The Matrix', 423, 4.5),
        (14, '[14]', 0, 'fantasy', 'Pan''s Labyrinth', 275, 3.6),
        (15, '[15]', 1, 'documentary', '13th', 191, 4.4),
        (16, '[16]', 0, 'horror', 'It Follows', 314, 4.3),
        (17, '[17]', 1, 'comedy', 'Step Brothers', 74, 3.0),
        (18, '[18]', 1, 'scifi', 'Inception', 201, 5.0),
        (19, '[19]', 1, 'fantasy', 'The Shape of Water', 399, 2.7),
        (20, '[20]', 1, 'documentary', 'Won''t You Be My Neighbor?', 186, 4.8),
        (21, '[21]', 1, 'scifi', 'Gravity', 342, 4.0),
        (22, '[22]', 1, 'scifi', 'Dune', 451, 4.4),
        (23, '[23]', 1, 'scifi', 'The Martian', 522, 4.6),
        (24, '[24]', 1, 'horror', 'A Quiet Place', 271, 4.3),
        (25, '[25]', 1, 'fantasy', 'The Chronicles of Narnia: The Lion, the Witch and the Wardrobe', 310, 3.9);
    """
    )
    # Persisted shadow-table state after the bulk insert.
    assert vec0_shadow_table_contents(db, "vec_movies") == snapshot()
    # All three non-boolean metadata constraint types combined in one KNN query.
    assert (
        exec(
            db,
            """
      select
        movie_id,
        title,
        genre,
        num_reviews,
        mean_rating,
        is_favorited,
        distance
      from vec_movies
      where synopsis_embedding match '[15.5]'
        and genre = 'scifi'
        and num_reviews between 100 and 500
        and mean_rating > 3.5
        and k = 5;
    """,
        )
        == snapshot()
    )
    # Text equality constraints.
    assert (
        exec(
            db,
            "select movie_id, genre, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and genre = 'horror'",
        )
        == snapshot()
    )
    assert (
        exec(
            db,
            "select movie_id, genre, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and genre = 'comedy'",
        )
        == snapshot()
    )
    # Integer range and lower-bound constraints.
    assert (
        exec(
            db,
            "select movie_id, num_reviews, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and num_reviews between 100 and 500",
        )
        == snapshot()
    )
    assert (
        exec(
            db,
            "select movie_id, num_reviews, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and num_reviews >= 500",
        )
        == snapshot()
    )
    # Float upper-bound and range constraints.
    assert (
        exec(
            db,
            "select movie_id, mean_rating, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and mean_rating < 3.0",
        )
        == snapshot()
    )
    assert (
        exec(
            db,
            "select movie_id, mean_rating, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and mean_rating between 4.0 and 5.0",
        )
        == snapshot()
    )
    # Boolean equality/inequality against TRUE/FALSE literals.
    assert exec(
        db,
        "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited = TRUE",
    ) == snapshot(name="bool-eq-true")
    assert exec(
        db,
        "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited != TRUE",
    ) == snapshot(name="bool-ne-true")
    assert exec(
        db,
        "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited = FALSE",
    ) == snapshot(name="bool-eq-false")
    assert exec(
        db,
        "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited != FALSE",
    ) == snapshot(name="bool-ne-false")
    # Non-equality operator on a boolean column.
    # EVIDENCE-OF: V10145_26984
    assert exec(
        db,
        "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited >= 999",
    ) == snapshot(name="bool-other-op")
def test_errors(db, snapshot):
    """Error paths on a vec0 table with a text metadata column."""
    db.execute("create virtual table v using vec0(vector float[1], t text)")
    # 13-character value — presumably exceeds the inline text limit and takes
    # the long-text storage path; confirm against vec0 internals.
    db.execute("insert into v(vector, t) values ('[1]', 'aaaaaaaaaaaax')")
    assert exec(db, "select * from v") == snapshot()
    # EVIDENCE-OF: V15466_32305
    # Deny reads of the text shadow table's data column so the subsequent
    # full scan fails while reading metadata.
    deny_read = authorizer_deny_on(
        sqlite3.SQLITE_READ, "v_metadatatext00", "data"
    )
    db.set_authorizer(deny_read)
    assert exec(db, "select * from v") == snapshot()
def authorizer_deny_on(operation, x1, x2=None):
    """Build a sqlite3 authorizer callback that denies one specific action.

    The returned callable denies exactly the (operation, x1, x2) triple and
    allows everything else; pass it to sqlite3.Connection.set_authorizer().
    """

    def _auth(op, p1, p2, p3, p4):
        denied = op == operation and p1 == x1 and p2 == x2
        return sqlite3.SQLITE_DENY if denied else sqlite3.SQLITE_OK

    return _auth
def exec(db, sql, parameters=()):
    """Execute *sql* and return a snapshot-friendly summary.

    On success returns an OrderedDict {"sql": sql, "rows": [row dicts]};
    on a database error returns {"error": class name, "message": str}.

    NOTE: intentionally shadows the builtin ``exec``; callers in this file
    rely on the name, so it cannot be renamed.
    """
    try:
        rows = db.execute(sql, parameters).fetchall()
    # OperationalError is a DatabaseError subclass, so one class covers the
    # original (OperationalError, DatabaseError) tuple exactly.
    except sqlite3.DatabaseError as e:
        return {
            "error": e.__class__.__name__,
            "message": str(e),
        }
    result = OrderedDict()
    result["sql"] = sql
    # Requires db.row_factory = sqlite3.Row (rows must expose .keys()).
    result["rows"] = [
        OrderedDict((k, row[k]) for k in row.keys()) for row in rows
    ]
    return result
def vec0_shadow_table_contents(db, v):
    """Return {shadow table name: full contents} for the vec0 table *v*.

    Shadow tables are found by the ``{v}_%`` name prefix in sqlite_master;
    the ``*_info`` table is skipped (presumably version/metadata noise that
    would churn the snapshots — confirm against vec0 internals).
    """
    names = [
        row[0]
        for row in db.execute(
            "select name from sqlite_master where name like ? order by 1", [f"{v}_%"]
        ).fetchall()
    ]
    contents = {}
    for name in names:
        if name.endswith("_info"):
            continue
        contents[name] = exec(db, f"select * from {name}")
    return contents

View file

@ -111,5 +111,7 @@ def vec0_shadow_table_contents(db, v):
]
o = {}
for shadow_table in shadow_tables:
if shadow_table.endswith("_info"):
continue
o[shadow_table] = exec(db, f"select * from {shadow_table}")
return o