mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
ann-filtering-benchmark directory
This commit is contained in:
parent
052ba4b089
commit
f55e14cce8
7 changed files with 259 additions and 18 deletions
215
tests/afbd/test-afbd.py
Normal file
215
tests/afbd/test-afbd.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
import tarfile
|
||||
import json
|
||||
from io import BytesIO
|
||||
import sqlite3
|
||||
from typing import List
|
||||
from struct import pack
|
||||
import time
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
|
||||
def serialize_float32(vector: List[float]) -> bytes:
    """Pack a list of floats into the "raw bytes" blob format sqlite-vec expects.

    Every element is written as a 32-bit float (``struct`` code ``f``), one
    after another, so the result is ``4 * len(vector)`` bytes long.
    """
    float_count = len(vector)
    layout = f"{float_count}f"
    return pack(layout, *vector)
|
||||
|
||||
|
||||
def build_command(file_path, metadata_set=None):
    """Build the benchmark SQLite database from a gzipped dataset tarball.

    Reads ``./payloads.jsonl``, ``./tests.jsonl`` and ``./vectors.npy`` from
    the archive, then creates ``<stem of file_path>.db`` in the current
    working directory with:

    * a ``tests`` table holding one JSON-encoded test case per row, and
    * a ``vec0`` virtual table ``v`` holding every vector (cosine distance)
      together with one column per payload key, using the row's position in
      the dataset as its ``rowid``.

    Args:
        file_path: Path to the ``.tar.gz`` dataset archive.
        metadata_set: Optional comma-separated payload column names. When
            given, only the listed columns are declared as plain columns and
            every other payload column is declared with a ``+`` prefix
            (sqlite-vec's auxiliary-column syntax). When omitted or empty,
            all payload columns are declared as plain columns.

    Raises:
        ValueError: If the archive lacks one of the three expected members
            or contains no payload rows.
        Exception: If a value in the first payload row is not an int, float
            or str.
    """
    if metadata_set:
        metadata_set = set(metadata_set.split(","))

    file_path = Path(file_path)
    print(f"reading {file_path}...")
    t0 = time.time()

    # Start from None so that a missing archive member is reported explicitly
    # below instead of surfacing as an UnboundLocalError.
    payloads = None
    tests = None
    vectors = None
    with tarfile.open(file_path, "r:gz") as archive:
        for member in archive:
            if member.name == "./payloads.jsonl":
                payloads = [
                    json.loads(line)
                    for line in archive.extractfile(member)
                    if line.strip()  # tolerate blank / trailing lines
                ]
            elif member.name == "./tests.jsonl":
                tests = [
                    json.loads(line)
                    for line in archive.extractfile(member)
                    if line.strip()
                ]
            elif member.name == "./vectors.npy":
                # Buffer the member in memory before handing it to np.load.
                vectors = np.load(BytesIO(archive.extractfile(member).read()))

    missing = [
        member_name
        for member_name, loaded in (
            ("./payloads.jsonl", payloads),
            ("./tests.jsonl", tests),
            ("./vectors.npy", vectors),
        )
        if loaded is None
    ]
    if missing:
        raise ValueError(
            f"{file_path} is missing archive member(s): {', '.join(missing)}"
        )
    if not payloads:
        raise ValueError(f"{file_path} contains no payload rows")

    dimensions = vectors.shape[1]
    # The schema is inferred from the first payload row only.
    metadata_columns = sorted(payloads[0].keys())

    def col_type(v):
        """Map a Python payload value to the column type declared for it."""
        if isinstance(v, int):
            return "integer"
        if isinstance(v, float):
            return "float"
        if isinstance(v, str):
            return "text"
        raise Exception(f"Unknown column type: {v}")

    metadata_columns_types = [col_type(payloads[0][col]) for col in metadata_columns]

    print(time.time() - t0)
    t0 = time.time()
    print("seeding...")

    db = sqlite3.connect(f"{file_path.stem}.db")
    try:
        db.execute("PRAGMA page_size = 16384")
        db.row_factory = sqlite3.Row
        db.enable_load_extension(True)
        db.load_extension("../../dist/vec0")
        db.enable_load_extension(False)

        with db:
            db.execute("create table tests(data)")
            db.executemany(
                "insert into tests values (?)",
                ((json.dumps(test),) for test in tests),
            )

        with db:
            column_defs = [f"vector float[{dimensions}] distance_metric=cosine"]
            for name, column_type in zip(metadata_columns, metadata_columns_types):
                # With an explicit metadata list, unlisted columns get "+".
                prefix = "+" if metadata_set and name not in metadata_set else ""
                column_defs.append(f"{prefix}{name} {column_type}")
            create_sql = f"create virtual table v using vec0({', '.join(column_defs)})"

            insert_columns = ["rowid", "vector", *metadata_columns]
            placeholders = ",".join("?" * len(insert_columns))
            insert_sql = (
                f"insert into v({', '.join(insert_columns)}) values ({placeholders})"
            )
            print(create_sql)
            print(insert_sql)

            db.execute(create_sql)

            for idx, (payload, vector) in enumerate(
                tqdm(zip(payloads, vectors), total=len(payloads))
            ):
                # The numpy row is bound as-is through the buffer protocol.
                # NOTE(review): presumably vectors.npy is float32 so the raw
                # bytes match the float[N] column - confirm.
                db.execute(
                    insert_sql,
                    [idx, vector, *(payload[c] for c in metadata_columns)],
                )

        print(time.time() - t0)
    finally:
        # Release the database handle even when seeding fails part-way.
        db.close()
|
||||
|
||||
|
||||
def _diff_index(path):
    """Return the list index encoded in a DeepDiff path such as ``root[3]``."""
    return int(path[len("root["):-1])


def _knn_query(query, conditions, k):
    """Build the KNN SQL text and parameters for one test's ``and`` conditions.

    Each condition must be a single-key dict ``{column: {kind: spec}}`` where
    ``kind`` is ``"match"`` (equality on ``spec["value"]``) or ``"range"``
    (bounds ``spec["gt"]`` and ``spec["lt"]``).

    Raises:
        Exception: If a condition kind other than match/range is found.
    """
    sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
    params = [serialize_float32(query), k]

    for condition in conditions["and"]:
        assert len(condition) == 1, f"expected one column per condition: {condition}"
        ((column, clause),) = condition.items()
        assert len(clause) == 1, f"expected one clause per column: {clause}"
        ((condition_type, spec),) = clause.items()

        # Column names cannot be bound as parameters, so they are spliced into
        # the SQL text; they come from the benchmark dataset stored in the db.
        if condition_type == "match":
            sql += f" and {column} = ?"
            params.append(spec["value"])
        elif condition_type == "range":
            # NOTE(review): BETWEEN is inclusive while the keys are named
            # "gt"/"lt"; confirm whether strict bounds are intended.
            sql += f" and {column} between ? and ?"
            params.append(spec["gt"])
            params.append(spec["lt"])
        else:
            raise Exception(f"Unknown condition type: {condition_type}")

    return sql, params


def _assert_one_off(expected_ids, actual_ids):
    """Assert that two differing id lists disagree only by a "1 off" error.

    Tolerated differences are: two adjacent positions whose ids are swapped,
    or a single differing id in the final position.

    Raises:
        AssertionError: If the difference has any other shape.
        Exception: If more than two (or zero) positions differ.
    """
    diff = DeepDiff(expected_ids, actual_ids, ignore_order=False)
    assert len(list(diff.keys())) == 1, f"unexpected diff kinds: {list(diff.keys())}"
    assert "values_changed" in diff.keys()
    changed = diff["values_changed"]
    keys_changed = list(changed.keys())

    if len(keys_changed) == 2:
        akey, bkey = keys_changed
        assert abs(_diff_index(akey) - _diff_index(bkey)) == 1
        # A swap: each side's new value is the other side's old value.
        assert changed[akey]["new_value"] == changed[bkey]["old_value"]
        assert changed[akey]["old_value"] == changed[bkey]["new_value"]
    elif len(keys_changed) == 1:
        (key,) = keys_changed
        # Diff paths are 0-based, so the final position is len - 1.
        assert _diff_index(key) == len(expected_ids) - 1
    else:
        raise Exception(
            f"Expected 1 or 2 differing positions, got {len(keys_changed)}: {keys_changed}"
        )


def tests_command(file_path):
    """Replay stored test cases against the built database and report results.

    Opens ``<stem of file_path>.db`` in the current working directory, loads
    at most 2000 rows from its ``tests`` table, and for each one runs a KNN
    query on ``v`` with the test's filters, comparing the returned rowids to
    the test's ``closest_ids``. Tests whose conditions contain an ``"or"`` key
    are skipped. A mismatch is accepted (and counted) only when it is a
    "1 off" error; anything else aborts the run. The returned similarity
    column is not compared against the test's ``closest_scores``.

    Args:
        file_path: Path of the dataset archive the database was built from;
            only its stem is used to locate the ``.db`` file.

    Raises:
        AssertionError: If a mismatch is not a tolerated "1 off" error.
        Exception: On an unknown condition type or an unexpected diff shape.
    """
    file_path = Path(file_path)
    db = sqlite3.connect(f"{file_path.stem}.db")
    try:
        # A negative cache_size is a size in KiB (see SQLite PRAGMA docs), so
        # this asks for a very large page cache.
        db.execute("PRAGMA cache_size = -100000000")
        db.row_factory = sqlite3.Row
        db.enable_load_extension(True)
        db.load_extension("../../dist/vec0")
        db.enable_load_extension(False)

        tests = [
            json.loads(row["data"])
            for row in db.execute("select data from tests limit 2000").fetchall()
        ]

        num_or_skips = 0
        num_1off_errors = 0

        t0 = time.time()
        print("testing...")
        for test in tqdm(tests):
            conditions = test["conditions"]
            expected_closest_ids = test["closest_ids"]
            if "or" in conditions:
                num_or_skips += 1
                continue

            sql, params = _knn_query(
                test["query"], conditions, len(expected_closest_ids)
            )
            rows = db.execute(sql, params).fetchall()
            actual_closest_ids = [row["rowid"] for row in rows]

            if expected_closest_ids != actual_closest_ids:
                _assert_one_off(expected_closest_ids, actual_closest_ids)
                num_1off_errors += 1
    finally:
        # Release the database handle even when a test aborts the run.
        db.close()

    num_run = len(tests) - num_or_skips
    print("Number skipped: ", num_or_skips)
    print("Num 1 off errors: ", num_1off_errors)
    if num_run:
        print("1 off error rate: ", num_1off_errors / num_run)
    else:
        # Avoid dividing by zero when every test was skipped or none exist.
        print("1 off error rate: ", "n/a (no tests were run)")
    print(time.time() - t0)
    print("done")
|
||||
|
||||
|
||||
def main():
    """Parse the command line and run the selected subcommand.

    Two subcommands are defined: ``build`` (positional ``file`` plus optional
    ``--metadata``) and ``test`` (positional ``file``). A subcommand is
    required.
    """

    def run_build(ns):
        return build_command(ns.file, ns.metadata)

    def run_tests(ns):
        return tests_command(ns.file)

    cli = argparse.ArgumentParser(description="CLI tool")
    commands = cli.add_subparsers(dest="command", required=True)

    # "build": create the database from a dataset archive.
    build_cmd = commands.add_parser("build")
    build_cmd.add_argument("file", type=str, help="Path to input file")
    build_cmd.add_argument("--metadata", type=str, help="Metadata columns")
    build_cmd.set_defaults(func=run_build)

    # "test": replay the stored test cases against that database.
    test_cmd = commands.add_parser("test")
    test_cmd.add_argument("file", type=str, help="Path to input file")
    test_cmd.set_defaults(func=run_tests)

    namespace = cli.parse_args()
    namespace.func(namespace)
|
||||
|
||||
|
||||
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue