Mirror of https://github.com/asg017/sqlite-vec.git, synced 2026-04-25 08:46:49 +02:00
* initial pass at PARTITION KEY support
* initial pass, allow auxiliary columns on vec0 virtual tables
* update TODO
* initial pass at metadata filtering
* unit tests
* gha this PR branch
* fixup tests
* doc internal
* fix tests, KNN/rowids in
* define SQLITE_INDEX_CONSTRAINT_OFFSET
* whoops
* update tests, syrupy, use uv
* un-ignore pyproject.toml
* dot
* tests/
* type error?
* win: .exe, update error name
* try fix macos python, paren around expr?
* win bash?
* dbg :(
* explicit error
* op
* dbg win
* win ./tests/.venv/Scripts/python.exe
* block UPDATEs on partition key values for now
* test this branch
* accidentally removed "partition key type mismatch" block during merge
* typo ugh
* bruv
* start aux snapshots
* drop aux shadow table on destroy
* enforce column types
* block WHERE constraints on auxiliary columns in KNN queries
* support DELETE
* support UPDATE on auxiliary columns
* test this PR
* don't inline that
* test-metadata.py
* memzero text buffer
* stress test
* more snapshot tests
* rm double/int32, just float/int64
* finish type checking
* long text support
* DELETE support
* UPDATE support
* fix snapshot names
* drop not-used in eqp
* small fixes
* boolean comparison handling
* ensure error is raised on long string constraint
* new version string for beta builds
* typo whoops
* ann-filtering-benchmark directory
* test-case
* updates
* fix aux column error when using non-default rowid values, needs test
* refactor some text KNN filtering
* rowids blob read-only on text metadata filters
* refactor
* add failing test cases for non-eq text KNN
* text KNN NE
* test cases diff
* GT
* text KNN GT/GE fixes
* text KNN LT/LE
* clean
* vtab_in handling
* unblock aux failures for now
* guard sqlite3_vtab_in
* else in guard?
* fixes and tests
* add broken shadow table test
* rename _metadata_chunksNN shadow table to _metadatachunksNN, for proper shadowName detection
* _metadata_text_NN shadow tables to _metadatatextNN
* SQLITE_VEC_VERSION_MAJOR, SQLITE_VEC_VERSION_MINOR, and SQLITE_VEC_VERSION_PATCH in sqlite-vec.h
* _info shadow table
* forgot to update aux snapshot?
* fix aux tests
231 lines
8.2 KiB
Python
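The commits above add filterable metadata columns, auxiliary columns, and partition keys to vec0 virtual tables. A minimal sketch of the first two features, with a hypothetical table, hypothetical column names, and an assumed extension path (the benchmark script below builds its schema the same way):

import sqlite3
from struct import pack

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
db.load_extension("./dist/vec0")  # assumed path to the built extension
db.enable_load_extension(False)

# `category` is a metadata column (filterable in KNN queries); `+note` is an
# auxiliary column (stored and selectable, but not usable as a KNN filter).
db.execute(
    "create virtual table demo using vec0("
    "vector float[4] distance_metric=cosine, category text, +note text)"
)
db.execute(
    "insert into demo(rowid, vector, category, note) values (?, ?, ?, ?)",
    [1, pack("4f", 0.1, 0.2, 0.3, 0.4), "news", "an example row"],
)
rows = db.execute(
    "select rowid, note, distance from demo "
    "where vector match ? and k = ? and category = ?",
    [pack("4f", 0.1, 0.2, 0.3, 0.4), 5, "news"],
).fetchall()

The benchmark script itself follows.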
import argparse
import json
import sqlite3
import tarfile
import time
from io import BytesIO
from pathlib import Path
from struct import pack
from typing import List

import numpy as np
from deepdiff import DeepDiff
from tqdm import tqdm


def serialize_float32(vector: List[float]) -> bytes:
    """Serializes a list of floats into the "raw bytes" format sqlite-vec expects"""
    return pack("%sf" % len(vector), *vector)


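# Example: serialize_float32([0.1, 0.2, 0.3, 0.4]) packs the list into 16 bytes
# of float32, suitable for binding to a `vector match ?` parameter
# (hypothetical table `v` and connection `db`):
#   db.execute(
#       "select rowid, distance from v where vector match ? and k = ?",
#       [serialize_float32([0.1, 0.2, 0.3, 0.4]), 10],
#   )

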
def build_command(file_path, metadata_set=None):
    # --metadata is a comma-separated list of columns to index as metadata
    # columns; all other payload columns become auxiliary (+) columns.
    if metadata_set:
        metadata_set = set(metadata_set.split(","))

    file_path = Path(file_path)
    print(f"reading {file_path}...")
    t0 = time.time()

    # The .tgz archive is expected to contain payloads.jsonl, tests.jsonl,
    # and vectors.npy.
    payloads = tests = vectors = None
    with tarfile.open(file_path, "r:gz") as archive:
        for file in archive:
            if file.name == "./payloads.jsonl":
                payloads = [
                    json.loads(line)
                    for line in archive.extractfile(file.name).readlines()
                ]
            if file.name == "./tests.jsonl":
                tests = [
                    json.loads(line)
                    for line in archive.extractfile(file.name).readlines()
                ]
            if file.name == "./vectors.npy":
                f = BytesIO()
                f.write(archive.extractfile(file.name).read())
                f.seek(0)
                vectors = np.load(f)

    assert payloads is not None
    assert tests is not None
    assert vectors is not None
    dimensions = vectors.shape[1]
    metadata_columns = sorted(list(payloads[0].keys()))

    def col_type(v):
        # Map a sample payload value to a vec0 column type.
        if isinstance(v, int):
            return "integer"
        if isinstance(v, float):
            return "float"
        if isinstance(v, str):
            return "text"
        raise Exception(f"Unknown column type: {v}")

    metadata_columns_types = [col_type(payloads[0][col]) for col in metadata_columns]

    print(f"read in {time.time() - t0:.2f}s")
    t0 = time.time()
    print("seeding...")

    db = sqlite3.connect(f"{file_path.stem}.db")
    db.execute("PRAGMA page_size = 16384")
    db.row_factory = sqlite3.Row
    db.enable_load_extension(True)
    db.load_extension("../../dist/vec0")
    db.enable_load_extension(False)

    # Store the test cases alongside the vectors so the `test` subcommand can
    # run against the same database file.
    with db:
        db.execute("create table tests(data)")

        for test in tests:
            db.execute("insert into tests values (?)", [json.dumps(test)])
    with db:
        create_sql = f"create virtual table v using vec0(vector float[{dimensions}] distance_metric=cosine"
        insert_sql = "insert into v(rowid, vector"
        for name, column_type in zip(metadata_columns, metadata_columns_types):
            if metadata_set:
                # Columns named in --metadata become filterable metadata
                # columns; the rest become auxiliary (+) columns.
                if name in metadata_set:
                    create_sql += f", {name} {column_type}"
                else:
                    create_sql += f", +{name} {column_type}"
            else:
                create_sql += f", {name} {column_type}"

            insert_sql += f", {name}"
        create_sql += ")"
        insert_sql += ") values (" + ",".join("?" * (2 + len(metadata_columns))) + ")"
        print(create_sql)
        print(insert_sql)

        db.execute(create_sql)

        for idx, (payload, vector) in enumerate(
            tqdm(zip(payloads, vectors), total=len(payloads))
        ):
            params = [idx, vector]
            for c in metadata_columns:
                params.append(payload[c])
            db.execute(insert_sql, params)

    print(f"seeded in {time.time() - t0:.2f}s")


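# For illustration, with payload columns `a` (integer) and `b` (text) and
# --metadata "a", build_command prints SQL shaped like this (the dimension
# count is hypothetical); `+b` marks the auxiliary column:
#   create virtual table v using vec0(vector float[384] distance_metric=cosine, a integer, +b text)
#   insert into v(rowid, vector, a, b) values (?,?,?,?)

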
def tests_command(file_path):
    file_path = Path(file_path)
    db = sqlite3.connect(f"{file_path.stem}.db")
    # A negative cache_size is in KiB, so this requests a very large page cache.
    db.execute("PRAGMA cache_size = -100000000")
    db.row_factory = sqlite3.Row
    db.enable_load_extension(True)
    db.load_extension("../../dist/vec0")
    db.enable_load_extension(False)

    tests = [
        json.loads(row["data"])
        for row in db.execute("select data from tests").fetchall()
    ]

    num_or_skips = 0  # kept for the summary below; nothing increments it in this version
    num_1off_errors = 0

    t0 = time.time()
    print("testing...")
    for test in tqdm(tests):
        query = test["query"]
        conditions = test["conditions"]
        expected_closest_ids = test["closest_ids"]
        expected_closest_scores = test["closest_scores"]

        # With distance_metric=cosine, similarity = 1 - cosine distance.
        sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
        params = [serialize_float32(query), len(expected_closest_ids)]
if "and" in conditions:
|
|
for condition in conditions["and"]:
|
|
assert len(condition.keys()) == 1
|
|
column = list(condition.keys())[0]
|
|
assert len(list(condition[column].keys())) == 1
|
|
condition_type = list(condition[column].keys())[0]
|
|
if condition_type == "match":
|
|
value = condition[column]["match"]["value"]
|
|
sql += f" and {column} = ?"
|
|
params.append(value)
|
|
elif condition_type == "range":
|
|
sql += f" and {column} between ? and ?"
|
|
params.append(condition[column]["range"]["gt"])
|
|
params.append(condition[column]["range"]["lt"])
|
|
else:
|
|
raise Exception(f"Unknown condition type: {condition_type}")
|
|
elif "or" in conditions:
|
|
column = list(conditions["or"][0].keys())[0]
|
|
condition_type = list(conditions["or"][0][column].keys())[0]
|
|
assert condition_type == "match"
|
|
sql += f" and {column} in ("
|
|
for idx, condition in enumerate(conditions["or"]):
|
|
if condition_type == "match":
|
|
value = condition[column]["match"]["value"]
|
|
if idx != 0:
|
|
sql += ","
|
|
sql += "?"
|
|
params.append(value)
|
|
elif condition_type == "range":
|
|
breakpoint()
|
|
else:
|
|
raise Exception(f"Unknown condition type: {condition_type}")
|
|
sql += ")"
|
|
|
|
        # print(sql, params[1:])
        rows = db.execute(sql, params).fetchall()
        actual_closest_ids = [row["rowid"] for row in rows]
        matches = expected_closest_ids == actual_closest_ids
        if not matches:
            # Tolerate "1-off" errors: a single adjacent swap in the result
            # order, or a different final element. Anything else is fatal.
            diff = DeepDiff(
                expected_closest_ids, actual_closest_ids, ignore_order=False
            )
            assert len(list(diff.keys())) == 1
            assert "values_changed" in diff.keys()
            keys_changed = list(diff["values_changed"].keys())
            if len(keys_changed) == 2:
                # Two adjacent positions swapped with each other.
                akey, bkey = keys_changed
                a = int(akey.removeprefix("root[").removesuffix("]"))
                b = int(bkey.removeprefix("root[").removesuffix("]"))
                assert abs(a - b) == 1
                assert (
                    diff["values_changed"][akey]["new_value"]
                    == diff["values_changed"][bkey]["old_value"]
                )
                assert (
                    diff["values_changed"][akey]["old_value"]
                    == diff["values_changed"][bkey]["new_value"]
                )
            elif len(keys_changed) == 1:
                # Only the final element differs.
                v = int(keys_changed[0].removeprefix("root[").removesuffix("]"))
                assert (v + 1) == len(expected_closest_ids)
            else:
                raise Exception(f"unexpected diff in closest ids: {keys_changed}")
            num_1off_errors += 1
        # Score comparison is left disabled:
        # print(expected_closest_scores)
        # print([row["similarity"] for row in rows])
        # assert expected_closest_scores == [row["similarity"] for row in rows]

    print("Number skipped: ", num_or_skips)
    print("Num 1-off errors: ", num_1off_errors)
    print("1-off error rate: ", num_1off_errors / (len(tests) - num_or_skips))
    print(f"tested in {time.time() - t0:.2f}s")
    print("done")


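# For illustration, a stored test like
#   {"conditions": {"and": [{"a": {"match": {"value": 3}}},
#                           {"b": {"range": {"gt": 1.0, "lt": 2.0}}}]}}
# (hypothetical columns a and b) yields:
#   select rowid, 1 - distance as similarity from v
#       where vector match ? and k = ? and a = ? and b between ? and ?
# and an OR of match conditions on one column yields `... and a in (?,?,...)`.

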
def main():
    parser = argparse.ArgumentParser(description="sqlite-vec ANN filtering benchmark")
    subparsers = parser.add_subparsers(dest="command", required=True)

    build_parser = subparsers.add_parser("build")
    build_parser.add_argument("file", type=str, help="Path to input file")
    build_parser.add_argument(
        "--metadata",
        type=str,
        help="Comma-separated metadata columns to index (others become auxiliary)",
    )
    build_parser.set_defaults(func=lambda args: build_command(args.file, args.metadata))

    tests_parser = subparsers.add_parser("test")
    tests_parser.add_argument("file", type=str, help="Path to input file")
    tests_parser.set_defaults(func=lambda args: tests_command(args.file))

    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
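
# Example usage (script and archive names are illustrative):
#   python bench.py build data/benchmark.tgz --metadata a,b
#   python bench.py test data/benchmark.tgz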