mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-27 09:46:27 +02:00
ann-filtering-benchmark directory
This commit is contained in:
parent
052ba4b089
commit
f55e14cce8
7 changed files with 259 additions and 18 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -26,3 +26,5 @@ sqlite-vec.h
|
||||||
tmp/
|
tmp/
|
||||||
|
|
||||||
poetry.lock
|
poetry.lock
|
||||||
|
|
||||||
|
*.jsonl
|
||||||
|
|
|
||||||
37
sqlite-vec.c
37
sqlite-vec.c
|
|
@ -5972,6 +5972,15 @@ int vec0_set_metadata_filter_bitmap(
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case VEC0_METADATA_OPERATOR_NE: {
|
||||||
|
for(int i = 0; i < size; i++) {
|
||||||
|
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
||||||
|
int n = ((int*) view)[0];
|
||||||
|
char * s = (char *) &view[4];
|
||||||
|
bitmap_set(b, i, strncmp(s, target, n) != 0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
case VEC0_METADATA_OPERATOR_GT: {
|
case VEC0_METADATA_OPERATOR_GT: {
|
||||||
for(int i = 0; i < size; i++) {
|
for(int i = 0; i < size; i++) {
|
||||||
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
||||||
|
|
@ -5981,6 +5990,15 @@ int vec0_set_metadata_filter_bitmap(
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case VEC0_METADATA_OPERATOR_GE: {
|
||||||
|
for(int i = 0; i < size; i++) {
|
||||||
|
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
||||||
|
int n = ((int*) view)[0];
|
||||||
|
char * s = (char *) &view[4];
|
||||||
|
bitmap_set(b, i, strncmp(s, target, n) >= 0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
case VEC0_METADATA_OPERATOR_LE: {
|
case VEC0_METADATA_OPERATOR_LE: {
|
||||||
for(int i = 0; i < size; i++) {
|
for(int i = 0; i < size; i++) {
|
||||||
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
||||||
|
|
@ -5999,24 +6017,7 @@ int vec0_set_metadata_filter_bitmap(
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case VEC0_METADATA_OPERATOR_GE: {
|
|
||||||
for(int i = 0; i < size; i++) {
|
|
||||||
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
|
||||||
int n = ((int*) view)[0];
|
|
||||||
char * s = (char *) &view[4];
|
|
||||||
bitmap_set(b, i, strncmp(s, target, n) >= 0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case VEC0_METADATA_OPERATOR_NE: {
|
|
||||||
for(int i = 0; i < size; i++) {
|
|
||||||
u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
|
|
||||||
int n = ((int*) view)[0];
|
|
||||||
char * s = (char *) &view[4];
|
|
||||||
bitmap_set(b, i, strncmp(s, target, n) != 0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
1
tests/afbd/.gitignore
vendored
Normal file
1
tests/afbd/.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
*.tgz
|
||||||
1
tests/afbd/.python-version
Normal file
1
tests/afbd/.python-version
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
3.12
|
||||||
9
tests/afbd/Makefile
Normal file
9
tests/afbd/Makefile
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
random_ints_1m.tgz:
|
||||||
|
curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_ints_1m.tgz
|
||||||
|
|
||||||
|
random_float_1m.tgz:
|
||||||
|
curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_float_1m.tgz
|
||||||
|
|
||||||
|
random_keywords_1m.tgz:
|
||||||
|
curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_keywords_1m.tgz
|
||||||
|
all: random_ints_1m.tgz random_float_1m.tgz random_keywords_1m.tgz
|
||||||
12
tests/afbd/README.md
Normal file
12
tests/afbd/README.md
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
|
||||||
|
# hnm
|
||||||
|
|
||||||
|
```
|
||||||
|
tar -xOzf hnm.tgz ./tests.jsonl > tests.jsonl
|
||||||
|
solite q "select group_concat(distinct key) from lines_read('tests.jsonl'), json_each(line -> '$.conditions.and[0]')"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
> python test-afbd.py build hnm.tgz --metadata product_group_name,colour_group_name,index_group_name,perceived_colour_value_name,section_name,product_type_name,department_name,graphical_appearance_name,garment_group_name,perceived_colour_master_name
|
||||||
|
```
|
||||||
215
tests/afbd/test-afbd.py
Normal file
215
tests/afbd/test-afbd.py
Normal file
|
|
@ -0,0 +1,215 @@
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from deepdiff import DeepDiff
|
||||||
|
|
||||||
|
import tarfile
|
||||||
|
import json
|
||||||
|
from io import BytesIO
|
||||||
|
import sqlite3
|
||||||
|
from typing import List
|
||||||
|
from struct import pack
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_float32(vector: List[float]) -> bytes:
    """Serialize *vector* into the raw float32 blob format sqlite-vec expects."""
    return pack(f"{len(vector)}f", *vector)
|
||||||
|
|
||||||
|
|
||||||
|
def build_command(file_path, metadata_set=None):
    """Build `{stem}.db` from an ann-filtering-benchmark `.tgz` archive.

    The archive is expected to contain `./payloads.jsonl` (one JSON metadata
    object per row), `./tests.jsonl` (recorded queries), and `./vectors.npy`
    (a 2-D float matrix, one vector per row).

    `metadata_set` is an optional comma-separated string of column names that
    should become indexed vec0 metadata columns; every other payload column is
    stored as an auxiliary (`+`-prefixed) column instead. When omitted, all
    payload columns become metadata columns.
    """
    if metadata_set:
        metadata_set = set(metadata_set.split(","))

    file_path = Path(file_path)
    print(f"reading {file_path}...")
    t0 = time.time()

    # Initialize up front so a malformed archive trips the asserts below
    # instead of raising a confusing NameError.
    payloads = None
    tests = None
    vectors = None
    with tarfile.open(file_path, "r:gz") as archive:
        for member in archive:
            if member.name == "./payloads.jsonl":
                payloads = [
                    json.loads(line)
                    for line in archive.extractfile(member.name).readlines()
                ]
            if member.name == "./tests.jsonl":
                tests = [
                    json.loads(line)
                    for line in archive.extractfile(member.name).readlines()
                ]
            if member.name == "./vectors.npy":
                # np.load wants a seekable file object, so buffer the member.
                f = BytesIO()
                f.write(archive.extractfile(member.name).read())
                f.seek(0)
                vectors = np.load(f)

    assert payloads is not None
    assert tests is not None
    assert vectors is not None
    dimensions = vectors.shape[1]
    metadata_columns = sorted(list(payloads[0].keys()))

    def col_type(v):
        # Map a sample payload value onto a vec0 column type.
        if isinstance(v, int):
            return "integer"
        if isinstance(v, float):
            return "float"
        if isinstance(v, str):
            return "text"
        raise Exception(f"Unknown column type: {v}")

    metadata_columns_types = [col_type(payloads[0][col]) for col in metadata_columns]

    print(time.time() - t0)
    t0 = time.time()
    print("seeding...")

    db = sqlite3.connect(f"{file_path.stem}.db")
    db.execute("PRAGMA page_size = 16384")
    db.row_factory = sqlite3.Row
    db.enable_load_extension(True)
    db.load_extension("../../dist/vec0")
    db.enable_load_extension(False)

    with db:
        db.execute("create table tests(data)")
        for test in tests:
            db.execute("insert into tests values (?)", [json.dumps(test)])

    with db:
        create_sql = f"create virtual table v using vec0(vector float[{dimensions}] distance_metric=cosine"
        insert_sql = "insert into v(rowid, vector"
        for name, ctype in zip(metadata_columns, metadata_columns_types):
            if metadata_set:
                if name in metadata_set:
                    create_sql += f", {name} {ctype}"
                else:
                    # '+' marks an auxiliary (unindexed) vec0 column.
                    create_sql += f", +{name} {ctype}"
            else:
                create_sql += f", {name} {ctype}"

            insert_sql += f", {name}"
        create_sql += ")"
        insert_sql += ") values (" + ",".join("?" * (2 + len(metadata_columns))) + ")"
        print(create_sql)
        print(insert_sql)

        db.execute(create_sql)

        for idx, (payload, vector) in enumerate(
            tqdm(zip(payloads, vectors), total=len(payloads))
        ):
            # NOTE(review): vector is a numpy row bound directly as a
            # parameter — presumably relying on buffer-protocol BLOB binding;
            # confirm, or wrap with serialize_float32.
            params = [idx, vector]
            for c in metadata_columns:
                params.append(payload[c])
            db.execute(insert_sql, params)

    print(time.time() - t0)
|
||||||
|
|
||||||
|
|
||||||
|
def tests_command(file_path):
    """Replay the recorded kNN tests stored in `{stem}.db` (built by
    build_command) against the vec0 virtual table and report how often the
    result order is "1-off" from the recorded ground truth.
    """
    file_path = Path(file_path)
    db = sqlite3.connect(f"{file_path.stem}.db")
    db.execute("PRAGMA cache_size = -100000000")
    db.row_factory = sqlite3.Row
    db.enable_load_extension(True)
    db.load_extension("../../dist/vec0")
    db.enable_load_extension(False)

    tests = [
        json.loads(row["data"])
        for row in db.execute("select data from tests limit 2000").fetchall()
    ]

    num_or_skips = 0
    num_1off_errors = 0

    t0 = time.time()
    print("testing...")
    for idx, test in enumerate(tqdm(tests)):
        query = test["query"]
        conditions = test["conditions"]
        expected_closest_ids = test["closest_ids"]
        expected_closest_scores = test["closest_scores"]
        # OR conditions are not expressible in a single vec0 WHERE clause, so
        # skip those tests and count them.
        if "or" in conditions:
            num_or_skips += 1
            continue

        sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
        params = [serialize_float32(query), len(expected_closest_ids)]

        for condition in conditions["and"]:
            assert len(condition.keys()) == 1
            column = list(condition.keys())[0]
            assert len(list(condition[column].keys())) == 1
            condition_type = list(condition[column].keys())[0]
            if condition_type == "match":
                value = condition[column]["match"]["value"]
                sql += f" and {column} = ?"
                params.append(value)
            elif condition_type == "range":
                # NOTE(review): "gt"/"lt" suggest exclusive bounds but BETWEEN
                # is inclusive — confirm this matches the recorded results.
                sql += f" and {column} between ? and ?"
                params.append(condition[column]["range"]["gt"])
                params.append(condition[column]["range"]["lt"])
            else:
                raise Exception(f"Unknown condition type: {condition_type}")

        rows = db.execute(sql, params).fetchall()
        actual_closest_ids = [row["rowid"] for row in rows]
        matches = expected_closest_ids == actual_closest_ids
        if not matches:
            # Only tolerate "1-off" mismatches: two adjacent ids swapped, or a
            # single changed position (a tie at the cutoff).
            diff = DeepDiff(
                expected_closest_ids, actual_closest_ids, ignore_order=False
            )
            assert len(list(diff.keys())) == 1
            assert "values_changed" in diff.keys()
            keys_changed = list(diff["values_changed"].keys())

            def _index(key):
                # DeepDiff keys look like "root[12]"; extract the integer index.
                return int(key.removeprefix("root[").removesuffix("]"))

            if len(keys_changed) == 2:
                akey, bkey = keys_changed
                a = _index(akey)
                b = _index(bkey)
                assert abs(a - b) == 1
                assert (
                    diff["values_changed"][akey]["new_value"]
                    == diff["values_changed"][bkey]["old_value"]
                )
                assert (
                    diff["values_changed"][akey]["old_value"]
                    == diff["values_changed"][bkey]["new_value"]
                )
            elif len(keys_changed) == 1:
                # BUG FIX: this branch previously read `akey`, which is only
                # bound in the two-key branch above (guaranteed NameError).
                # Use the single changed key instead.
                v = _index(keys_changed[0])
                assert v == len(expected_closest_ids)
            else:
                raise Exception(
                    f"unexpected diff between closest ids: {keys_changed}"
                )
            num_1off_errors += 1
        # print(closest_scores)
        # print([row["similarity"] for row in rows])
        # assert closest_scores == [row["similarity"] for row in rows]

    print("Number skipped: ", num_or_skips)
    print("Num 1 off errors: ", num_1off_errors)
    # Guard against dividing by zero when every test used an OR condition.
    tested = len(tests) - num_or_skips
    print("1 off error rate: ", num_1off_errors / tested if tested else 0.0)
    print(time.time() - t0)
    print("done")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI arguments and dispatch to the build/test subcommands."""
    parser = argparse.ArgumentParser(description="CLI tool")
    subcommands = parser.add_subparsers(dest="command", required=True)

    build = subcommands.add_parser("build")
    build.add_argument("file", type=str, help="Path to input file")
    build.add_argument("--metadata", type=str, help="Metadata columns")
    build.set_defaults(func=lambda args: build_command(args.file, args.metadata))

    test = subcommands.add_parser("test")
    test.add_argument("file", type=str, help="Path to input file")
    test.set_defaults(func=lambda args: tests_command(args.file))

    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue