sqlite-vec/tests/afbd/test-afbd.py

232 lines
8.2 KiB
Python
Raw Normal View History

Metadata filtering (#124) * initial pass at PARTITION KEY support. * Initial pass, allow auxiliary columns on vec0 virtual tables * update TODO * Initial pass at metadata filtering * unit tests * gha this PR branch * fixup tests * doc internal * fix tests, KNN/rowids in * define SQLITE_INDEX_CONSTRAINT_OFFSET * whoops * update tests, syrupy, use uv * un-ignore pyproject.toml * dot * tests/ * type error? * win: .exe, update error name * try fix macos python, paren around expr? * win bash? * dbg :( * explicit error * op * dbg win * win ./tests/.venv/Scripts/python.exe * block UPDATEs on partition key values for now * test this branch * accidentally removed "partition key type mistmatch" block during merge * typo ugh * bruv * start aux snapshots * drop aux shadow table on destroy * enforce column types * block WHERE constraints on auxiliary columns in KNN queries * support delete * support UPDATE on auxiliary columns * test this PR * don't inline that * test-metadata.py * memzero text buffer * stress test * more snapshot tests * rm double/int32, just float/int64 * finish type checking * long text support * DELETE support * UPDATE support * fix snapshot names * drop not-used in eqp * small fixes * boolean comparison handling * ensure error is raised when long string constraint * new version string for beta builds * typo whoops * ann-filtering-benchmark directory * test-case * updates * fix aux column error when using non-default rowid values, needs test * refactor some text knn filtering * rowids blob read only on text metadata filters * refactor * add failing test cases for non eq text knn * text knn NE * test cases diff * GT * text knn GT/GE fixes * text knn LT/LE * clean * vtab_in handling * unblock aux failures for now * guard sqlite3_vtab_in * else in guard?
* fixes and tests * add broken shadow table test * rename _metadata_chunksNN shadow table to _metadatachunksNN, for proper shadowName detection * _metadata_text_NN shadow tables to _metadatatextNN * SQLITE_VEC_VERSION_MAJOR SQLITE_VEC_VERSION_MINOR and SQLITE_VEC_VERSION_PATCH in sqlite-vec.h * _info shadow table * forgot to update aux snapshot? * fix aux tests
2024-11-20 00:59:34 -08:00
import numpy as np
from tqdm import tqdm
from deepdiff import DeepDiff
import tarfile
import json
from io import BytesIO
import sqlite3
from typing import List
from struct import pack
import time
from pathlib import Path
import argparse
def serialize_float32(vector: List[float]) -> bytes:
    """Encode *vector* as the raw float32 blob that sqlite-vec accepts.

    Each element is packed back to back as a C ``float`` using ``struct``'s
    native byte order; an empty vector yields ``b""``.
    """
    layout = f"{len(vector)}f"
    return pack(layout, *vector)
def _load_archive(file_path):
    """Load the payloads, test cases and vectors stored in a ``.tar.gz`` archive.

    Reads ``./payloads.jsonl`` and ``./tests.jsonl`` (one JSON document per
    line) and ``./vectors.npy``; any other archive member is ignored.

    Returns:
        A ``(payloads, tests, vectors)`` tuple: two lists of decoded JSON
        values and the array loaded from ``vectors.npy``.

    Raises:
        ValueError: if any of the three members is missing from the archive.
    """
    # Initialised up front so a missing member is reported below instead of
    # surfacing as an UnboundLocalError.
    payloads = tests = vectors = None
    with tarfile.open(file_path, "r:gz") as archive:
        for member in archive:
            if member.name == "./payloads.jsonl":
                with archive.extractfile(member) as fh:
                    payloads = [json.loads(line) for line in fh]
            elif member.name == "./tests.jsonl":
                with archive.extractfile(member) as fh:
                    tests = [json.loads(line) for line in fh]
            elif member.name == "./vectors.npy":
                # Buffer the whole member in memory and let np.load read the copy.
                with archive.extractfile(member) as fh:
                    vectors = np.load(BytesIO(fh.read()))
    missing = [
        name
        for name, value in (
            ("payloads.jsonl", payloads),
            ("tests.jsonl", tests),
            ("vectors.npy", vectors),
        )
        if value is None
    ]
    if missing:
        raise ValueError(f"{file_path} is missing {', '.join(missing)}")
    return payloads, tests, vectors


def _column_type(value):
    """Return the vec0 column type used for a payload value of this Python type."""
    # NOTE(review): bool is a subclass of int, so JSON true/false map to
    # "integer" here -- confirm that is intended.
    if isinstance(value, int):
        return "integer"
    if isinstance(value, float):
        return "float"
    if isinstance(value, str):
        return "text"
    raise ValueError(f"Unknown column type: {value}")


def _vec0_statements(dimensions, columns, column_types, metadata_set):
    """Build the CREATE VIRTUAL TABLE and INSERT statements for the ``v`` table.

    Every payload field becomes a column of ``v``. When *metadata_set* is
    non-empty, fields outside it are declared auxiliary (``+name``) instead of
    metadata columns. Auxiliary columns cannot be used in KNN WHERE
    constraints.

    Returns:
        ``(create_sql, insert_sql)``; the INSERT binds rowid, vector, then
        each column in *columns* order.
    """
    create_parts = [
        f"create virtual table v using vec0(vector float[{dimensions}] distance_metric=cosine"
    ]
    for name, column_type in zip(columns, column_types):
        prefix = "+" if metadata_set and name not in metadata_set else ""
        create_parts.append(f", {prefix}{name} {column_type}")
    create_parts.append(")")

    placeholders = ",".join(["?"] * (2 + len(columns)))
    insert_sql = (
        "insert into v(rowid, vector"
        + "".join(f", {name}" for name in columns)
        + f") values ({placeholders})"
    )
    return "".join(create_parts), insert_sql


def build_command(file_path, metadata_set=None):
    """Create ``<archive stem>.db`` from a benchmark dataset archive.

    The database gets a ``tests`` table holding each test case as JSON text,
    and a vec0 virtual table ``v``. ``v`` holds one cosine-distance
    ``float[N]`` vector per payload, with rowid equal to the payload index,
    plus one column per payload field.

    Args:
        file_path: path to the ``.tar.gz`` dataset archive.
        metadata_set: optional comma-separated field names. When given, only
            these fields become metadata columns and every other field becomes
            an auxiliary (``+``) column; when omitted, every field is metadata.

    Raises:
        ValueError: if the archive is incomplete, has no payloads, or holds a
            different number of payloads and vectors.
    """
    if metadata_set:
        # Tolerate "a, b" as well as "a,b".
        metadata_set = {name.strip() for name in metadata_set.split(",")}
    file_path = Path(file_path)

    print(f"reading {file_path}...")
    t0 = time.perf_counter()
    payloads, tests, vectors = _load_archive(file_path)
    if not payloads:
        raise ValueError(f"{file_path} contains no payloads")
    if len(payloads) != len(vectors):
        raise ValueError(
            f"{file_path} has {len(payloads)} payloads but {len(vectors)} vectors"
        )
    # Each row is bound directly as an INSERT parameter, so its raw buffer is
    # what vec0 receives. Force contiguous float32 rows; this is a no-op when
    # the array already is.
    vectors = np.ascontiguousarray(vectors, dtype=np.float32)
    dimensions = vectors.shape[1]
    # Column names and types are inferred from the first payload only.
    metadata_columns = sorted(payloads[0])
    metadata_column_types = [_column_type(payloads[0][name]) for name in metadata_columns]
    if metadata_set:
        unknown = sorted(metadata_set.difference(metadata_columns))
        if unknown:
            print(f"warning: --metadata names not found in payloads: {unknown}")
    print(time.perf_counter() - t0)

    t0 = time.perf_counter()
    print("seeding...")
    db = sqlite3.connect(f"{file_path.stem}.db")
    try:
        # Issued before any table exists (see SQLite's PRAGMA page_size docs).
        db.execute("PRAGMA page_size = 16384")
        db.row_factory = sqlite3.Row
        db.enable_load_extension(True)
        db.load_extension("../../dist/vec0")
        db.enable_load_extension(False)

        with db:
            db.execute("create table tests(data)")
            db.executemany(
                "insert into tests values (?)",
                ((json.dumps(test),) for test in tests),
            )

        create_sql, insert_sql = _vec0_statements(
            dimensions, metadata_columns, metadata_column_types, metadata_set
        )
        with db:
            print(create_sql)
            print(insert_sql)
            db.execute(create_sql)
            for idx, (payload, vector) in enumerate(
                tqdm(zip(payloads, vectors), total=len(payloads))
            ):
                db.execute(
                    insert_sql,
                    [idx, vector, *(payload[name] for name in metadata_columns)],
                )
    finally:
        db.close()
    print(time.perf_counter() - t0)
def _only_item(mapping, what):
    """Return the single ``(key, value)`` pair of *mapping*.

    Raises:
        ValueError: if *mapping* does not have exactly one entry.
    """
    if len(mapping) != 1:
        raise ValueError(f"expected exactly one {what}, got {list(mapping)!r}")
    return next(iter(mapping.items()))


def _range_clause(column, bounds):
    """Translate a ``{"gt"|"gte"|"lt"|"lte": value}`` range into SQL terms.

    ``gt``/``lt`` are exclusive and ``gte``/``lte`` inclusive. Bounds that are
    absent or ``None`` leave that side unconstrained.

    Returns:
        ``(sql, params)``, where *sql* is a run of ``" and <column> <op> ?"``
        terms and *params* holds the matching bound values.

    Raises:
        ValueError: on an unrecognised bound key.
    """
    operators = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}
    unknown = set(bounds) - operators.keys()
    if unknown:
        raise ValueError(f"Unknown range bound(s) for {column}: {sorted(unknown)}")
    sql, params = "", []
    for key, op in operators.items():
        if bounds.get(key) is not None:
            sql += f" and {column} {op} ?"
            params.append(bounds[key])
    return sql, params


def _build_knn_query(test):
    """Translate one test case into a vec0 KNN query.

    The query asks for ``len(test["closest_ids"])`` neighbours of
    ``test["query"]``. It ANDs in the filters from ``test["conditions"]``:

    * ``and``: single-column conditions, each ``{"match": {"value": v}}``
      (equality) or ``{"range": {...}}`` (see ``_range_clause``).
    * ``or``: ``match`` conditions on one column, written as
      ``column in (...)``.

    Returns:
        ``(sql, params)``, or ``None`` when the ``or`` group cannot be written
        as one ``IN`` list (mixed columns, non-``match`` conditions, or empty).

    Raises:
        Exception: on an unknown condition type inside an ``and`` group.
    """
    sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
    params = [serialize_float32(test["query"]), len(test["closest_ids"])]
    conditions = test.get("conditions") or {}

    for condition in conditions.get("and") or []:
        column, spec = _only_item(condition, "column per condition")
        condition_type, body = _only_item(spec, "condition type")
        if condition_type == "match":
            sql += f" and {column} = ?"
            params.append(body["value"])
        elif condition_type == "range":
            range_sql, range_params = _range_clause(column, body)
            sql += range_sql
            params.extend(range_params)
        else:
            raise Exception(f"Unknown condition type: {condition_type}")

    if "or" in conditions:
        column, values = None, []
        for condition in conditions["or"] or []:
            condition_column, spec = _only_item(condition, "column per condition")
            condition_type, body = _only_item(spec, "condition type")
            if condition_type != "match" or (
                column is not None and condition_column != column
            ):
                return None
            column = condition_column
            values.append(body["value"])
        if not values:
            return None
        sql += f" and {column} in ({','.join('?' * len(values))})"
        params.extend(values)
    return sql, params


def _diff_index(path):
    """Return the list index encoded in a DeepDiff path such as ``root[12]``."""
    prefix = "root["
    if not (path.startswith(prefix) and path.endswith("]")):
        raise ValueError(f"unexpected DeepDiff path: {path!r}")
    return int(path[len(prefix):-1])


def _assert_one_off(expected_ids, actual_ids):
    """Assert *actual_ids* differs from *expected_ids* by one "1-off" error only.

    Two differences are accepted: exactly two adjacent positions swapped, or
    only the final (k-th) id differs. Anything else raises AssertionError.
    """
    diff = DeepDiff(expected_ids, actual_ids, ignore_order=False)
    assert list(diff.keys()) == ["values_changed"], f"unexpected result diff: {diff}"
    changed = diff["values_changed"]
    paths = list(changed.keys())
    if len(paths) == 2:
        a_path, b_path = paths
        assert abs(_diff_index(a_path) - _diff_index(b_path)) == 1, diff
        assert changed[a_path]["new_value"] == changed[b_path]["old_value"], diff
        assert changed[a_path]["old_value"] == changed[b_path]["new_value"], diff
    elif len(paths) == 1:
        assert _diff_index(paths[0]) + 1 == len(expected_ids), diff
    else:
        raise AssertionError(
            f"{len(paths)} result positions differ; expected at most a 1-off error: {diff}"
        )


def tests_command(file_path):
    """Run every stored test case against ``<archive stem>.db`` and report results.

    Each test's KNN query must return exactly the expected ids, or differ by a
    single "1-off" error (see ``_assert_one_off``), which is counted. Tests
    whose ``or`` filter cannot be expressed are skipped and counted.

    Raises:
        FileNotFoundError: if the database has not been built yet.
        AssertionError: if a result differs by more than a 1-off error.
    """
    file_path = Path(file_path)
    db_path = f"{file_path.stem}.db"
    # sqlite3.connect would silently create an empty database; fail clearly instead.
    if not Path(db_path).exists():
        raise FileNotFoundError(f"{db_path} not found; run the build command first")
    db = sqlite3.connect(db_path)
    try:
        db.execute("PRAGMA cache_size = -100000000")
        db.row_factory = sqlite3.Row
        db.enable_load_extension(True)
        db.load_extension("../../dist/vec0")
        db.enable_load_extension(False)
        tests = [json.loads(row["data"]) for row in db.execute("select data from tests")]

        num_or_skips = 0
        num_1off_errors = 0
        t0 = time.perf_counter()
        print("testing...")
        for test in tqdm(tests):
            query = _build_knn_query(test)
            if query is None:
                num_or_skips += 1
                continue
            sql, params = query
            expected_ids = test["closest_ids"]
            # NOTE: only ids are compared; the selected similarity and the
            # expected test["closest_scores"] are not checked.
            actual_ids = [row["rowid"] for row in db.execute(sql, params).fetchall()]
            if actual_ids != expected_ids:
                _assert_one_off(expected_ids, actual_ids)
                num_1off_errors += 1

        num_checked = len(tests) - num_or_skips
        print("Number skipped: ", num_or_skips)
        print("Num 1 off errors: ", num_1off_errors)
        print(
            "1 off error rate: ",
            num_1off_errors / num_checked if num_checked else "n/a",
        )
        print(time.perf_counter() - t0)
        print("done")
    finally:
        db.close()
def main():
    """Parse the command line and dispatch to the chosen sub-command.

    Sub-commands:
        build FILE [--metadata COLS]   calls build_command(FILE, COLS)
        test FILE                      calls tests_command(FILE)
    """
    cli = argparse.ArgumentParser(description="CLI tool")
    commands = cli.add_subparsers(dest="command", required=True)

    build_cmd = commands.add_parser("build")
    build_cmd.add_argument("file", type=str, help="Path to input file")
    build_cmd.add_argument("--metadata", type=str, help="Metadata columns")
    build_cmd.set_defaults(func=lambda ns: build_command(ns.file, ns.metadata))

    test_cmd = commands.add_parser("test")
    test_cmd.add_argument("file", type=str, help="Path to input file")
    test_cmd.set_defaults(func=lambda ns: tests_command(ns.file))

    namespace = cli.parse_args()
    namespace.func(namespace)


# Only run the CLI when executed as a script, not on import.
if __name__ == "__main__":
    main()