sqlite-vec/tests/skip.test-correctness.py
2024-11-13 11:55:18 -08:00

49 lines
1.1 KiB
Python

import sqlite3
import json
# Open the main database; the unqualified vec0 table `vec_sift1m` used by the
# query below is presumably stored here (SQLite also searches attached schemas
# for unqualified names) — TODO confirm.
db = sqlite3.connect("test2.db")
# Allow extension loading only long enough to load sqlite-vec's vec0 module,
# and switch it back off even if the load itself fails.
db.enable_load_extension(True)
try:
    db.load_extension("dist/vec0")
finally:
    db.enable_load_extension(False)
# sqlite3.Row lets result rows be indexed by column name, e.g. row["topk"].
db.row_factory = sqlite3.Row
# Attach the SIFT1M dataset (query vectors and their k100 ground truth) under
# the schema name `sift1m`. The file name is a standard single-quoted SQL
# string literal rather than a double-quoted token.
db.execute("attach database 'sift1m-base.db' as sift1m")
# KNN benchmark query. For the first 10 query vectors in sift1m.sift1m_query,
# run a vec0 MATCH search (k = 100) against vec_sift1m, then per query:
#   topk              - JSON array of neighbour rowids, ordered by distance
#   k100_groundtruth  - the stored ground-truth column for that query
#   (last column)     - diagnostic flag: topk JSON text equals the ground truth
SIFT1M_KNN_SQL = '''
with q as (
select rowid, vector, k100 from sift1m.sift1m_query limit 10
),
results as (
select
q.rowid as query_rowid,
vec_sift1m.rowid as vec_rowid,
distance,
k100 as k100_groundtruth
from q
join vec_sift1m
where
vec_sift1m.vector match q.vector
and k = 100
order by distance
)
select
query_rowid,
json_group_array(vec_rowid order by distance) as topk,
k100_groundtruth,
json_group_array(vec_rowid order by distance) == k100_groundtruth
from results
group by 1;
'''
cursor = db.execute(SIFT1M_KNN_SQL)
# Materialize every per-query row up front.
rows = cursor.fetchall()
from statistics import mean

# Neighbour count per query; must match the "k = 100" in the KNN query above
# this block. Recall@K = |returned top-K ∩ ground-truth| / K.
K = 100

results = []
for row in rows:
    actual = json.loads(row["topk"])
    # Presumably a JSON array of rowids (it is compared against topk in the
    # query). A set makes each membership test O(1) instead of a list scan.
    expected = set(json.loads(row["k100_groundtruth"]))
    ncorrect = sum(x in expected for x in actual)
    results.append(ncorrect / K)

# statistics.mean raises an opaque StatisticsError on empty input; fail with a
# message that points at the real cause instead.
if not results:
    raise SystemExit("no query rows returned; cannot compute recall")
print(mean(results))