mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
49 lines
1.1 KiB
Python
49 lines
1.1 KiB
Python
import sqlite3
|
|
import json
|
|
|
|
db = sqlite3.connect("test2.db")
|
|
db.enable_load_extension(True)
|
|
db.load_extension("dist/vec0")
|
|
db.enable_load_extension(False)
|
|
db.row_factory = sqlite3.Row
|
|
db.execute('attach database "sift1m-base.db" as sift1m')
|
|
|
|
|
|
#def test_sift1m():
|
|
rows = db.execute(
|
|
'''
|
|
with q as (
|
|
select rowid, vector, k100 from sift1m.sift1m_query limit 10
|
|
),
|
|
results as (
|
|
select
|
|
q.rowid as query_rowid,
|
|
vec_sift1m.rowid as vec_rowid,
|
|
distance,
|
|
k100 as k100_groundtruth
|
|
from q
|
|
join vec_sift1m
|
|
where
|
|
vec_sift1m.vector match q.vector
|
|
and k = 100
|
|
order by distance
|
|
)
|
|
select
|
|
query_rowid,
|
|
json_group_array(vec_rowid order by distance) as topk,
|
|
k100_groundtruth,
|
|
json_group_array(vec_rowid order by distance) == k100_groundtruth
|
|
from results
|
|
group by 1;
|
|
''').fetchall()
|
|
|
|
results = []
|
|
for row in rows:
|
|
actual = json.loads(row["topk"])
|
|
expected = json.loads(row["k100_groundtruth"])
|
|
|
|
ncorrect = sum([x in expected for x in actual])
|
|
results.append(ncorrect / 100.0)
|
|
|
|
from statistics import mean
|
|
print(mean(results))
|