# sqlite-vec/tests/afbd/test-afbd.py
import numpy as np
from tqdm import tqdm
from deepdiff import DeepDiff
import tarfile
import json
from io import BytesIO
import sqlite3
from typing import List
from struct import pack
import time
from pathlib import Path
import argparse
def serialize_float32(vector: List[float]) -> bytes:
    """Serializes a list of floats into the "raw bytes" format sqlite-vec expects

    Each element is packed as a native-endian 32-bit float; an empty list
    yields b"".
    """
    fmt = "%sf" % len(vector)
    return pack(fmt, *vector)
def build_command(file_path, metadata_set=None):
    """Build `<stem>.db` from a .tar.gz benchmark archive.

    The archive must contain ./payloads.jsonl (one JSON object per line),
    ./tests.jsonl (one JSON test case per line) and ./vectors.npy (a 2-D
    float array, one row per payload -- assumed; TODO confirm).  A vec0
    virtual table `v` is created with one column per payload key; columns
    named in `metadata_set` (a comma-separated string) become filterable
    metadata columns, all others become auxiliary (`+name`) columns.  The
    raw test cases are stored in a plain `tests` table for tests_command.

    Args:
        file_path: path to the .tar.gz archive.
        metadata_set: optional comma-separated column names, or None.
    """
    if metadata_set:
        metadata_set = set(metadata_set.split(","))
    file_path = Path(file_path)
    print(f"reading {file_path}...")
    t0 = time.time()
    # Pre-declare so a missing archive member fails the asserts below with a
    # clear AssertionError instead of a NameError.
    payloads = None
    tests = None
    vectors = None
    with tarfile.open(file_path, "r:gz") as archive:
        for member in archive:
            if member.name == "./payloads.jsonl":
                payloads = [
                    json.loads(line)
                    for line in archive.extractfile(member.name).readlines()
                ]
            if member.name == "./tests.jsonl":
                tests = [
                    json.loads(line)
                    for line in archive.extractfile(member.name).readlines()
                ]
            if member.name == "./vectors.npy":
                # np.load needs a seekable file object, so buffer the member.
                vectors = np.load(BytesIO(archive.extractfile(member.name).read()))
    assert payloads is not None
    assert tests is not None
    assert vectors is not None
    dimensions = vectors.shape[1]
    # Stable column order: sorted keys of the first payload (all payloads are
    # assumed to share the same schema -- TODO confirm).
    metadata_columns = sorted(payloads[0].keys())

    def col_type(v):
        """Map a sample JSON value to the vec0 column type for its column."""
        # NOTE: bool is a subclass of int, so booleans map to "integer".
        if isinstance(v, int):
            return "integer"
        if isinstance(v, float):
            return "float"
        if isinstance(v, str):
            return "text"
        raise Exception(f"Unknown column type: {v}")

    metadata_columns_types = [col_type(payloads[0][col]) for col in metadata_columns]
    print(time.time() - t0)
    t0 = time.time()
    print("seeding...")
    db = sqlite3.connect(f"{file_path.stem}.db")
    db.execute("PRAGMA page_size = 16384")
    db.row_factory = sqlite3.Row
    db.enable_load_extension(True)
    db.load_extension("../../dist/vec0")
    db.enable_load_extension(False)
    with db:
        db.execute("create table tests(data)")
        for test in tests:
            db.execute("insert into tests values (?)", [json.dumps(test)])
    with db:
        create_sql = f"create virtual table v using vec0(vector float[{dimensions}] distance_metric=cosine"
        insert_sql = "insert into v(rowid, vector"
        for name, ctype in zip(metadata_columns, metadata_columns_types):
            # Columns listed in metadata_set are filterable; everything else
            # becomes an auxiliary ("+name") column.  With no metadata_set at
            # all, every column is filterable.
            if metadata_set and name not in metadata_set:
                create_sql += f", +{name} {ctype}"
            else:
                create_sql += f", {name} {ctype}"
            insert_sql += f", {name}"
        create_sql += ")"
        # 2 extra placeholders: rowid and vector.
        insert_sql += ") values (" + ",".join("?" * (2 + len(metadata_columns))) + ")"
        print(create_sql)
        print(insert_sql)
        db.execute(create_sql)
        for idx, (payload, vector) in enumerate(
            tqdm(zip(payloads, vectors), total=len(payloads))
        ):
            # numpy rows bind as BLOBs via the buffer protocol.
            params = [idx, vector]
            for c in metadata_columns:
                params.append(payload[c])
            db.execute(insert_sql, params)
    db.close()
    print(time.time() - t0)
def tests_command(file_path):
file_path = Path(file_path)
db = sqlite3.connect(f"{file_path.stem}.db")
db.execute("PRAGMA cache_size = -100000000")
db.row_factory = sqlite3.Row
db.enable_load_extension(True)
db.load_extension("../../dist/vec0")
db.enable_load_extension(False)
tests = [
json.loads(row["data"])
2024-11-17 14:56:31 -08:00
for row in db.execute("select data from tests").fetchall()
2024-11-17 08:38:57 -08:00
]
num_or_skips = 0
num_1off_errors = 0
t0 = time.time()
print("testing...")
for idx, test in enumerate(tqdm(tests)):
query = test["query"]
conditions = test["conditions"]
expected_closest_ids = test["closest_ids"]
expected_closest_scores = test["closest_scores"]
sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
params = [serialize_float32(query), len(expected_closest_ids)]
2024-11-19 21:46:50 -08:00
if "and" in conditions:
for condition in conditions["and"]:
assert len(condition.keys()) == 1
column = list(condition.keys())[0]
assert len(list(condition[column].keys())) == 1
condition_type = list(condition[column].keys())[0]
if condition_type == "match":
value = condition[column]["match"]["value"]
sql += f" and {column} = ?"
params.append(value)
elif condition_type == "range":
sql += f" and {column} between ? and ?"
params.append(condition[column]["range"]["gt"])
params.append(condition[column]["range"]["lt"])
else:
raise Exception(f"Unknown condition type: {condition_type}")
elif "or" in conditions:
column = list(conditions["or"][0].keys())[0]
condition_type = list(conditions["or"][0][column].keys())[0]
assert condition_type == "match"
sql += f" and {column} in ("
for idx, condition in enumerate(conditions["or"]):
if condition_type == "match":
value = condition[column]["match"]["value"]
if idx != 0:
sql += ","
sql += "?"
params.append(value)
elif condition_type == "range":
breakpoint()
else:
raise Exception(f"Unknown condition type: {condition_type}")
sql += ")"
2024-11-17 08:38:57 -08:00
2024-11-19 21:46:50 -08:00
# print(sql, params[1:])
2024-11-17 08:38:57 -08:00
rows = db.execute(sql, params).fetchall()
actual_closest_ids = [row["rowid"] for row in rows]
matches = expected_closest_ids == actual_closest_ids
if not matches:
diff = DeepDiff(
expected_closest_ids, actual_closest_ids, ignore_order=False
)
assert len(list(diff.keys())) == 1
assert "values_changed" in diff.keys()
keys_changed = list(diff["values_changed"].keys())
if len(keys_changed) == 2:
akey, bkey = keys_changed
a = int(akey.lstrip("root[").rstrip("]"))
b = int(bkey.lstrip("root[").rstrip("]"))
assert abs(a - b) == 1
assert (
diff["values_changed"][akey]["new_value"]
== diff["values_changed"][bkey]["old_value"]
)
assert (
diff["values_changed"][akey]["old_value"]
== diff["values_changed"][bkey]["new_value"]
)
elif len(keys_changed) == 1:
2024-11-17 14:56:31 -08:00
v = int(keys_changed[0].lstrip("root[").rstrip("]"))
assert (v + 1) == len(expected_closest_ids)
2024-11-17 08:38:57 -08:00
else:
raise Exception("fuck")
num_1off_errors += 1
# print(closest_scores)
# print([row["similarity"] for row in rows])
# assert closest_scores == [row["similarity"] for row in rows]
print("Number skipped: ", num_or_skips)
print("Num 1 off errors: ", num_1off_errors)
print("1 off error rate: ", num_1off_errors / (len(tests) - num_or_skips))
print(time.time() - t0)
print("done")
def main():
    """CLI entry point: dispatch the `build` and `test` subcommands."""
    parser = argparse.ArgumentParser(description="CLI tool")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # `build`: create <stem>.db from an archive, optionally restricting
    # which payload keys become filterable metadata columns.
    build = subparsers.add_parser("build")
    build.add_argument("file", type=str, help="Path to input file")
    build.add_argument("--metadata", type=str, help="Metadata columns")
    build.set_defaults(func=lambda args: build_command(args.file, args.metadata))

    # `test`: run the stored test cases against the built database.
    test = subparsers.add_parser("test")
    test.add_argument("file", type=str, help="Path to input file")
    test.set_defaults(func=lambda args: tests_command(args.file))

    parsed = parser.parse_args()
    parsed.func(parsed)


if __name__ == "__main__":
    main()