mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-26 01:06:27 +02:00
knn cleanups and tests
This commit is contained in:
parent
b1e7a93a11
commit
f217cbf2bd
8 changed files with 1328 additions and 398 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -3,6 +3,7 @@
|
||||||
sift/
|
sift/
|
||||||
*.tar.gz
|
*.tar.gz
|
||||||
*.db
|
*.db
|
||||||
|
*.npy
|
||||||
*.bin
|
*.bin
|
||||||
*.out
|
*.out
|
||||||
venv/
|
venv/
|
||||||
|
|
|
||||||
949
sqlite-vec.c
949
sqlite-vec.c
File diff suppressed because it is too large
Load diff
16
tests/correctness/build.py
Normal file
16
tests/correctness/build.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Export one shard of dbpedia OpenAI embeddings to a local ``.npy`` file."""
import numpy as np
import duckdb

# Only the embedding column is selected; the other columns are left
# commented out inside the query text.
EMBEDDINGS_QUERY = """
    select
      -- _id,
      -- title,
      -- text as contents,
      embedding::float[] as embeddings
    from "hf://datasets/Supabase/dbpedia-openai-3-large-1M/dbpedia_openai_3_large_00.parquet"
  """
OUTPUT_PATH = "dbpedia_openai_3_large_00.npy"

db = duckdb.connect(":memory:")

# fetchnumpy() is indexed by column name; keep only the "embeddings" column.
columns = db.execute(EMBEDDINGS_QUERY).fetchnumpy()
result = columns['embeddings']

# Stack the per-row embeddings into one 2-D array before saving.
np.save(OUTPUT_PATH, np.vstack(result))
|
||||||
124
tests/correctness/test-correctness.py
Normal file
124
tests/correctness/test-correctness.py
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
import time
|
||||||
|
import tqdm
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
def cosine_similarity(
    vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32], do_norm: bool = True
) -> npt.NDArray[np.float32]:
    """Return the similarity between ``vec`` and every row of ``mat``.

    Args:
        vec: Query vector; it is multiplied against ``mat.T``.
        mat: Matrix whose rows are the candidate vectors.
        do_norm: When true, each dot product is divided by the product of
            the two vector norms (cosine similarity). When false, the raw
            dot products are returned.

    Returns:
        One similarity value per row of ``mat``.

    Note:
        A zero-norm query or row divides by zero; NumPy yields NaN/inf for
        those entries rather than raising.
    """
    sim = vec @ mat.T
    if do_norm:
        # Divide out-of-place: an in-place ``/=`` raises a casting error when
        # the dot products are integers. float32 inputs still give float32.
        sim = sim / (np.linalg.norm(vec) * np.linalg.norm(mat, axis=1))
    return sim
|
||||||
|
|
||||||
|
def distance_l2(
    vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32]
) -> npt.NDArray[np.float32]:
    """Return the Euclidean (L2) distance from ``vec`` to every row of ``mat``.

    ``vec`` is broadcast against the rows of ``mat``; the squared differences
    are summed along axis 1 and square-rooted.
    """
    return np.sqrt(np.sum((mat - vec) ** 2, axis=1))


def topk(
    vec: npt.NDArray[np.float32],
    mat: npt.NDArray[np.float32],
    k: int = 5,
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.float32]]:
    """Return the ``k`` rows of ``mat`` nearest to ``vec`` by L2 distance.

    Args:
        vec: Query vector.
        mat: Matrix whose rows are the candidates.
        k: Number of neighbours wanted. Values larger than the number of
            rows are clamped, so every row is returned in sorted order.

    Returns:
        ``(indices, distances)``: row indices into ``mat`` ordered from
        nearest to farthest, and the matching distances.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError("k must be greater than or equal to 0")

    distances = distance_l2(vec, mat)
    total = distances.shape[0]
    k = min(k, total)

    if k == 0:
        return np.empty(0, dtype=np.intp), distances[:0]

    if k == total:
        # argpartition requires kth < len(distances); when every row is
        # wanted, a plain argsort is the whole answer.
        top_indices = np.argsort(distances)
    else:
        # Rather than sorting all distances and taking the top K, it's
        # faster to argpartition and then sort only the top K.
        # The difference is O(N logN) vs O(N + k logk)
        indices = np.argpartition(distances, kth=k)[:k]
        top_indices = indices[np.argsort(distances[indices])]
    return top_indices, distances[top_indices]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# --- Sanity check 1: brute-force top-k on a tiny hand-written matrix -------
vec = np.array([1.0, 2.0, 3.0], dtype=np.float32)
mat = np.array(
    [
        [4.0, 5.0, 6.0],
        [1.0, 2.0, 1.0],
        [7.0, 8.0, 9.0],
    ],
    dtype=np.float32,
)
indices, distances = topk(vec, mat, k=2)
print(indices)
print(distances)

import sqlite3
import json

# --- Connection with the vec0 extension loaded ------------------------------
db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
db.load_extension("../../dist/vec0")
# Second load goes through SQL so a non-default entry point can be named.
db.execute("select load_extension('../../dist/vec0', 'sqlite3_vec_fs_read_init')")
db.enable_load_extension(False)

# --- Sanity check 2: the same tiny query through vec_distance_l2 ------------
tiny_params = {
    'base': json.dumps(mat.tolist()),
    'q': '[1.0, 2.0, 3.0]'
}
results = db.execute(
    '''
    select
      key,
      --value,
      vec_distance_l2(:q, value) as distance
    from json_each(:base)
    order by distance
    limit 2
  ''',
    tiny_params,
).fetchall()
a = [record[0] for record in results]
b = [record[1] for record in results]
print(a)
print(b)

# Uncomment to stop after the sanity checks:
#import sys; sys.exit()

db.execute('PRAGMA page_size=16384')

# --- Load the full dataset into a vec0 table, timing the insert -------------
print("Loading into sqlite-vec vec0 table...")
t0 = time.time()
db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)")
db.execute('insert into v select rowid, vector from vec_npy_each(vec_npy_file("dbpedia_openai_3_large_00.npy"))')
print(time.time() - t0)

# --- Load the same vectors into NumPy, timing the read ----------------------
print("loading numpy array...")
t0 = time.time()
base = np.load('dbpedia_openai_3_large_00.npy')
print(time.time() - t0)

# Fixed seed so the same 20 rows are picked as queries on every run.
np.random.seed(1)
queries = base[np.random.choice(base.shape[0], 20, replace=False), :]

# Per-query wall-clock durations, appended to by test_all().
np_durations = []
vec_durations = []
from random import randrange
|
||||||
|
|
||||||
|
def test_all():
    """Check vec0 KNN distances against brute-force NumPy for every query.

    For each row of the module-level ``queries`` the ``k`` nearest rows of
    ``base`` are computed with ``topk`` and with a ``vec0`` KNN query on
    table ``v``. The two distance lists must be exactly equal. Row ids are
    not compared. Elapsed time of each side is appended to ``np_durations``
    and ``vec_durations``, and their means are printed at the end.
    """
    k = 10

    # Iterate the array directly so tqdm can read its length and show a
    # total; wrapping enumerate() hides the length from tqdm.
    for query in tqdm.tqdm(queries):
        # perf_counter() is monotonic, so wall-clock adjustments cannot skew
        # the measured durations.
        started = time.perf_counter()
        _, np_distances = topk(query, base, k=k)
        np_durations.append(time.perf_counter() - started)

        started = time.perf_counter()
        rows = db.execute(
            'select rowid, distance from v where a match ? and k = ?', [query, k]
        ).fetchall()
        vec_durations.append(time.perf_counter() - started)

        vec_distances = [row[1] for row in rows]
        assert vec_distances == np_distances.tolist()

    print('final', 'np' ,np.mean(np_durations), 'vec', np.mean(vec_durations))
|
||||||
16
tests/leak-fixtures/each.sql
Normal file
16
tests/leak-fixtures/each.sql
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
-- Leak fixture: exercises the vec_each table function.
.load dist/vec0
.mode box
.header on
.eqp on
.echo on

-- Record which SQLite and sqlite-vec builds produced this run.
select sqlite_version(), vec_version();

-- vec_each over a single JSON vector literal.
select * from vec_each('[1,2,3]');

-- vec_each joined against each element of a JSON array of vectors.
select *
from json_each('[
  [1,2,3,4],
  [1,2,3,4]
]')
join vec_each(json_each.value);
|
||||||
61
tests/leak-fixtures/knn.sql
Normal file
61
tests/leak-fixtures/knn.sql
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
-- Leak fixture: exercises vec0 KNN queries on a small one-dimensional table.
.load dist/vec0
.mode box
.header on
.eqp on
.echo on

-- Record which SQLite and sqlite-vec builds produced this run.
select sqlite_version(), vec_version();

create virtual table v using vec0(a float[1], chunk_size=8);

-- Rows 1..100, each holding the single-element vector [rowid / 100.0].
insert into v
  select value, format('[%f]', value / 100.0)
  from generate_series(1, 100);

-- Plain KNN query.
select
  rowid,
  vec_to_json(a)
from v
where a match '[.3]'
  and k = 2;

-- KNN query with k = 0.
select
  rowid,
  vec_to_json(a)
from v
where a match '[.3]'
  and k = 0;

-- KNN query additionally constrained by a rowid IN list.
select
  rowid,
  vec_to_json(a)
from v
where a match '[2.0]'
  and k = 2
  and rowid in (1,2,3,4,5);

-- KNN query evaluated once per row of a CTE of query vectors.
with queries as (
  select
    rowid as query_id,
    json_array(value / 100.0) as value
  from generate_series(24, 39)
)
select
  query_id,
  rowid,
  distance,
  vec_to_json(a)
from queries, v
where a match queries.value
  and k =5;

-- Non-KNN lookup by rowid.
select *
from v
where rowid in (1,2,3,4);

drop table v;
|
||||||
|
|
||||||
|
|
@ -113,11 +113,33 @@ FUNCTIONS = [
|
||||||
"vec_quantize_i8",
|
"vec_quantize_i8",
|
||||||
"vec_quantize_i8",
|
"vec_quantize_i8",
|
||||||
"vec_slice",
|
"vec_slice",
|
||||||
|
"vec_static_blob_from_raw",
|
||||||
"vec_sub",
|
"vec_sub",
|
||||||
"vec_to_json",
|
"vec_to_json",
|
||||||
"vec_version",
|
"vec_version",
|
||||||
]
|
]
|
||||||
MODULES = ["vec0", "vec_each", "vec_npy_each"]
|
MODULES = [
|
||||||
|
"vec0",
|
||||||
|
"vec_each",
|
||||||
|
"vec_npy_each",
|
||||||
|
"vec_static_blob_entries",
|
||||||
|
"vec_static_blobs",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO")
def test_vec_static_blob_from_raw():
    """Placeholder for ``vec_static_blob_from_raw`` coverage; skipped until written."""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO")
def test_vec_static_blobs():
    """Placeholder for ``vec_static_blobs`` coverage; skipped until written."""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="TODO")
def test_vec_static_blob_entries():
    """Placeholder for ``vec_static_blob_entries`` coverage; skipped until written."""
|
||||||
|
|
||||||
|
|
||||||
def test_funcs():
|
def test_funcs():
|
||||||
|
|
@ -420,6 +442,11 @@ def test_vec_slice():
|
||||||
):
|
):
|
||||||
vec_slice(b"\xab\xab\xab\xab", 1, 0)
|
vec_slice(b"\xab\xab\xab\xab", 1, 0)
|
||||||
|
|
||||||
|
with _raises(
|
||||||
|
"slice 'start' index is equal to the 'end' index, vectors must have non-zero length"
|
||||||
|
):
|
||||||
|
vec_slice(b"\xab\xab\xab\xab", 0, 0)
|
||||||
|
|
||||||
|
|
||||||
def test_vec_add():
|
def test_vec_add():
|
||||||
vec_add = lambda *args, a="?", b="?": db.execute(
|
vec_add = lambda *args, a="?", b="?": db.execute(
|
||||||
|
|
@ -775,6 +802,7 @@ def test_vec0_drops():
|
||||||
"t1_vector_chunks00",
|
"t1_vector_chunks00",
|
||||||
"t1_vector_chunks01",
|
"t1_vector_chunks01",
|
||||||
]
|
]
|
||||||
|
|
||||||
db.execute("drop table t1")
|
db.execute("drop table t1")
|
||||||
assert [
|
assert [
|
||||||
row["name"]
|
row["name"]
|
||||||
|
|
@ -1175,7 +1203,8 @@ def test_vec0_text_pk():
|
||||||
create virtual table t using vec0(
|
create virtual table t using vec0(
|
||||||
t_id text primary key,
|
t_id text primary key,
|
||||||
aaa float[1],
|
aaa float[1],
|
||||||
bbb float8[1]
|
bbb float8[1],
|
||||||
|
chunk_size=8
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
@ -1226,6 +1255,39 @@ def test_vec0_text_pk():
|
||||||
db.execute("select * from t")
|
db.execute("select * from t")
|
||||||
db.set_authorizer(None)
|
db.set_authorizer(None)
|
||||||
|
|
||||||
|
assert execute_all(
|
||||||
|
db, "select t_id, distance from t where aaa match ? and k = 3", ["[.01]"]
|
||||||
|
) == [
|
||||||
|
{
|
||||||
|
"t_id": "t_1",
|
||||||
|
"distance": 0.09000000357627869,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"t_id": "t_2",
|
||||||
|
"distance": 0.1899999976158142,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"t_id": "t_3",
|
||||||
|
"distance": 0.2900000214576721,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
if SUPPORTS_VTAB_IN:
|
||||||
|
assert execute_all(
|
||||||
|
db,
|
||||||
|
"select t_id, distance from t where aaa match ? and k = 3 and t_id in ('t_2', 't_3')",
|
||||||
|
["[.01]"],
|
||||||
|
) == [
|
||||||
|
{
|
||||||
|
"t_id": "t_2",
|
||||||
|
"distance": 0.1899999976158142,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"t_id": "t_3",
|
||||||
|
"distance": 0.2900000214576721,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_vec0_best_index():
|
def test_vec0_best_index():
|
||||||
db = connect(EXT_PATH)
|
db = connect(EXT_PATH)
|
||||||
|
|
@ -1679,6 +1741,214 @@ def test_vec0_create_errors():
|
||||||
db.set_authorizer(None)
|
db.set_authorizer(None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vec0_knn():
    """Exercise vec0 KNN queries end to end.

    Covers rejection of malformed query vectors and negative ``k``, empty
    results for ``k = 0`` and for an empty table, an authorizer denial while
    preparing the chunk statement, result ordering before and after deletes,
    and the size validation of the validity, rowids and vector blobs.
    """
    db = connect(EXT_PATH)
    db.execute(
        """
          create virtual table v using vec0(
            aaa float[8],
            bbb int8[8],
            ccc bit[8],
            chunk_size=8
          );
        """
    )

    # Malformed query vectors: each statement must fail with its message.
    invalid_queries = [
        (
            'Query vector on the "aaa" column is invalid: Input must have type BLOB (compact format) or TEXT (JSON), found NULL',
            "select * from v where aaa match NULL and k = 10",
        ),
        (
            'Query vector for the "aaa" column is expected to be of type float32, but a bit vector was provided.',
            "select * from v where aaa match vec_bit(X'AA') and k = 10",
        ),
        (
            'Dimension mismatch for inserted vector for the "aaa" column. Expected 8 dimensions but received 1.',
            "select * from v where aaa match vec_f32('[.1]') and k = 10",
        ),
    ]
    for message, statement in invalid_queries:
        with _raises(message):
            db.execute(statement)

    qaaa = json.dumps([0.01] * 8)
    with _raises("k value in knn queries must be greater than or equal to 0."):
        db.execute("select * from v where aaa match vec_f32(?) and k = -1", [qaaa])

    k0_rows = execute_all(
        db, "select * from v where aaa match vec_f32(?) and k = 0", [qaaa]
    )
    assert k0_rows == []

    # EVIDENCE-OF: V06942_23781
    db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_READ, "v_chunks", "chunk_id"))
    with _raises(
        "Error preparing stmtChunk: access to v_chunks.chunk_id is prohibited",
        sqlite3.DatabaseError,
    ):
        db.execute("select * from v where aaa match vec_f32(?) and k = 5", [qaaa])
    db.set_authorizer(None)

    # Nothing has been inserted yet, so a valid query returns no rows.
    empty_rows = execute_all(
        db, "select * from v where aaa match vec_f32(?) and k = 5", [qaaa]
    )
    assert empty_rows == []

    # Rows 0..23; row i stores the vector [i * 0.01] * 8 in all three columns.
    db.executemany(
        """
          INSERT INTO v VALUES
            (:id, :vector, vec_quantize_i8(:vector, 'unit') ,vec_quantize_binary(:vector));
        """,
        [{"id": i, "vector": json.dumps([i * 0.01] * 8)} for i in range(24)],
    )

    knn9 = "select rowid from v where aaa match vec_f32(?) and k = 9"

    # ordering of 2 and 0 here depends on if min_idx uses < or <=
    expected_before_delete = [1, 2, 0, 3, 4, 5, 6, 7, 8]
    assert execute_all(db, knn9, [qaaa]) == [
        {"rowid": rowid} for rowid in expected_before_delete
    ]

    # TODO separate test, DELETE FROM WHERE rowid in (...) is fullscan that calls vec0Rowid. try on text PKs
    db.execute("delete from v where rowid in (1, 0, 8, 9)")
    expected_after_delete = [2, 3, 4, 5, 6, 7, 10, 11, 12]
    assert execute_all(db, knn9, [qaaa]) == [
        {"rowid": rowid} for rowid in expected_after_delete
    ]

    def expect_corruption_error(corrupting_update, message):
        # Corrupt a shadow table inside a transaction, check that the KNN
        # query reports it, then roll back so the next check starts clean.
        db.commit()
        db.execute("BEGIN")
        db.execute(corrupting_update)
        with _raises(message):
            db.execute("select * from v where aaa match ? and k = 2", [qaaa])
        db.rollback()

    # EVIDENCE-OF: V05271_22109 vec0 knn validates chunk size
    expect_corruption_error(
        "update v_chunks set validity = zeroblob(100)",
        "chunk validity size doesn't match - expected 1, found 100",
    )

    # EVIDENCE-OF: V02796_19635 vec0 knn validates rowids size
    expect_corruption_error(
        "update v_chunks set rowids = zeroblob(100)",
        "chunk rowids size doesn't match - expected 64, found 100",
    )

    # EVIDENCE-OF: V16465_00535 vec0 knn validates vector chunk size
    expect_corruption_error(
        "update v_vector_chunks00 set vectors = zeroblob(100)",
        "vectors blob size doesn't match - expected 256, found 100",
    )
|
||||||
|
|
||||||
|
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
|
||||||
|
def np_distance_l2(
    vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32]
) -> npt.NDArray[np.float32]:
    """Return the Euclidean (L2) distance from ``vec`` to every row of ``mat``.

    ``vec`` is broadcast against the rows of ``mat``; the squared differences
    are summed along axis 1 and square-rooted.
    """
    return np.sqrt(np.sum((mat - vec) ** 2, axis=1))


def np_topk(
    vec: npt.NDArray[np.float32],
    mat: npt.NDArray[np.float32],
    k: int = 5,
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.float32]]:
    """Return the ``k`` rows of ``mat`` nearest to ``vec`` by L2 distance.

    Args:
        vec: Query vector.
        mat: Matrix whose rows are the candidates.
        k: Number of neighbours wanted. Values larger than the number of
            rows are clamped, so every row is returned in sorted order.

    Returns:
        ``(indices, distances)``: row indices into ``mat`` ordered from
        nearest to farthest, and the matching distances.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError("k must be greater than or equal to 0")

    distances = np_distance_l2(vec, mat)
    total = distances.shape[0]
    k = min(k, total)

    if k == 0:
        return np.empty(0, dtype=np.intp), distances[:0]

    if k == total:
        # argpartition requires kth < len(distances); when every row is
        # wanted, a plain argsort is the whole answer.
        top_indices = np.argsort(distances)
    else:
        # Rather than sorting all distances and taking the top K, it's
        # faster to argpartition and then sort only the top K.
        # The difference is O(N logN) vs O(N + k logk)
        indices = np.argpartition(distances, kth=k)[:k]
        top_indices = indices[np.argsort(distances[indices])]
    return top_indices, distances[top_indices]
|
||||||
|
|
||||||
|
|
||||||
|
# import faiss
|
||||||
|
@pytest.mark.skip(reason="TODO")
def test_correctness_npy():
    """Cross-check vec0 KNN against scalar SQL functions, faiss and NumPy.

    The same random matrix is loaded into a ``vec0`` virtual table, a plain
    table queried with ``vec_distance_l2``, and a ``faiss.IndexFlatL2``.
    Every query must return identical row ids and distances from all three.
    The NumPy comparison is computed but its assertions remain disabled.
    """
    # faiss is optional; without it the test is skipped instead of failing
    # with a NameError once the skip marker above is removed.
    faiss = pytest.importorskip("faiss")

    db = connect(EXT_PATH)
    np.random.seed(420 + 1 + 2)
    mat = np.random.uniform(low=-1.0, high=1.0, size=(10000, 24)).astype(np.float32)
    queries = np.random.uniform(low=-1.0, high=1.0, size=(1000, 24)).astype(np.float32)

    # sqlite-vec with vec0
    db.execute("create virtual table v using vec0(a float[24], chunk_size=8)")
    db.executemany("insert into v(a) values (?)", ([row] for row in mat))

    # sqlite-vec with scalar functions
    db.execute("create table t(a float[24])")
    db.executemany("insert into t(a) values (?)", ([row] for row in mat))

    faiss_index = faiss.IndexFlatL2(24)
    faiss_index.add(mat)

    # Ask for all but one row so nearly the whole ordering is compared.
    k = mat.shape[0] - 1
    for q in queries:
        result = execute_all(
            db,
            "select rowid - 1 as idx, distance from v where a match ? and k = ?",
            [q, k],
        )
        vec_vtab_rowids = [row["idx"] for row in result]
        vec_vtab_distances = [row["distance"] for row in result]

        result = execute_all(
            db,
            "select rowid - 1 as idx, vec_distance_l2(a, ?) as distance from t order by 2 limit ?",
            [q, k],
        )
        vec_scalar_rowids = [row["idx"] for row in result]
        vec_scalar_distances = [row["distance"] for row in result]
        assert vec_scalar_rowids == vec_vtab_rowids
        assert vec_scalar_distances == vec_vtab_distances

        faiss_distances, faiss_rowids = faiss_index.search(np.array([q]), k)
        # The search result is square-rooted here so it is comparable with
        # the L2 distances returned by the SQL queries.
        faiss_distances = np.sqrt(faiss_distances)
        assert faiss_rowids[0].tolist() == vec_scalar_rowids
        assert faiss_distances[0].tolist() == vec_scalar_distances

        assert faiss_distances[0].tolist() == vec_vtab_distances
        assert faiss_rowids[0].tolist() == vec_vtab_rowids

        # np_topk takes (vec, mat): the query first, then the matrix.
        np_rowids, np_distances = np_topk(q, mat, k=k)
        # TODO: enable once the NumPy results are expected to match exactly.
        # assert vec_vtab_rowids == np_rowids.tolist()
        # assert vec_vtab_distances == np_distances.tolist()
|
||||||
|
|
||||||
|
|
||||||
def test_smoke():
|
def test_smoke():
|
||||||
db.execute("create virtual table vec_xyz using vec0( a float[2] )")
|
db.execute("create virtual table vec_xyz using vec0( a float[2] )")
|
||||||
assert execute_all(
|
assert execute_all(
|
||||||
|
|
@ -1833,20 +2103,15 @@ def test_vec0_stress_small_chunks():
|
||||||
"distance": 0.0,
|
"distance": 0.0,
|
||||||
"rowid": 500,
|
"rowid": 500,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"a": _f32([499 * 0.1] * 8),
|
|
||||||
"distance": 0.2828384041786194,
|
|
||||||
"rowid": 499,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"a": _f32([501 * 0.1] * 8),
|
"a": _f32([501 * 0.1] * 8),
|
||||||
"distance": 0.2828384041786194,
|
"distance": 0.2828384041786194,
|
||||||
"rowid": 501,
|
"rowid": 501,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"a": _f32([498 * 0.1] * 8),
|
"a": _f32([499 * 0.1] * 8),
|
||||||
"distance": 0.5656875967979431,
|
"distance": 0.2828384041786194,
|
||||||
"rowid": 498,
|
"rowid": 499,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"a": _f32([502 * 0.1] * 8),
|
"a": _f32([502 * 0.1] * 8),
|
||||||
|
|
@ -1854,15 +2119,20 @@ def test_vec0_stress_small_chunks():
|
||||||
"rowid": 502,
|
"rowid": 502,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"a": _f32([497 * 0.1] * 8),
|
"a": _f32([498 * 0.1] * 8),
|
||||||
"distance": 0.8485260009765625,
|
"distance": 0.5656875967979431,
|
||||||
"rowid": 497,
|
"rowid": 498,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"a": _f32([503 * 0.1] * 8),
|
"a": _f32([503 * 0.1] * 8),
|
||||||
"distance": 0.8485260009765625,
|
"distance": 0.8485260009765625,
|
||||||
"rowid": 503,
|
"rowid": 503,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"a": _f32([497 * 0.1] * 8),
|
||||||
|
"distance": 0.8485260009765625,
|
||||||
|
"rowid": 497,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"a": _f32([496 * 0.1] * 8),
|
"a": _f32([496 * 0.1] * 8),
|
||||||
"distance": 1.1313751935958862,
|
"distance": 1.1313751935958862,
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,55 @@ fn _min_idx(distances: Vec<f32>, k: i32) -> Vec<i32> {
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Safe test wrapper around the C `merge_sorted_lists` routine.
///
/// Passes list `a` (distances plus rowids), list `b` (distances, rowids and
/// `b_top_idx`) and two output buffers with room for `n` entries to the C
/// function, then returns `(rowids, distances)` truncated to the number of
/// entries the C side reports through `out_used`.
///
/// NOTE(review): `a.len()` and `b.len()` are passed as the only lengths, so
/// `a_rowids`, `b_rowids` and `b_top_idx` are presumably expected to be at
/// least that long. Confirm against the C implementation.
///
/// # Panics
///
/// Panics if the C function reports an `out_used` outside `0..=n`.
fn _merge_sorted_lists(
    a: &Vec<f32>,
    a_rowids: &Vec<i64>,
    b: &Vec<f32>,
    b_rowids: &Vec<i64>,
    b_top_idx: &Vec<i32>,
    n: usize,
) -> (Vec<i64>, Vec<f32>) {
    let mut out_used: i64 = 0;
    let mut out: Vec<f32> = Vec::with_capacity(n);
    let mut out_rowids: Vec<i64> = Vec::with_capacity(n);
    unsafe {
        merge_sorted_lists(
            a.as_ptr(),
            a_rowids.as_ptr(),
            a.len() as i64,
            b.as_ptr(),
            b_rowids.as_ptr(),
            b_top_idx.as_ptr(),
            b.len() as i64,
            // The C side writes into these buffers, so the pointers must
            // come from `as_mut_ptr`; writing through `as_ptr` is undefined
            // behaviour.
            out.as_mut_ptr(),
            out_rowids.as_mut_ptr(),
            n as i64,
            &mut out_used,
        );
        // `set_len` requires the new length to fit the allocated capacity.
        assert!(
            out_used >= 0 && out_used as usize <= n,
            "merge_sorted_lists reported out_used = {} for capacity {}",
            out_used,
            n
        );
        out.set_len(out_used as usize);
        out_rowids.set_len(out_used as usize);
    }

    (out_rowids, out)
}

#[link(name = "sqlite-vec-internal")]
extern "C" {
    fn min_idx(distances: *const f32, n: i32, out: *mut i32, k: i32) -> i32;

    // `out` and `out_rowids` are output buffers, hence `*mut`.
    fn merge_sorted_lists(
        a: *const f32,
        a_rowids: *const i64,
        a_length: i64,
        b: *const f32,
        b_rowids: *const i64,
        b_top_idx: *const i32,
        b_length: i64,
        out: *mut f32,
        out_rowids: *mut i64,
        out_length: i64,
        out_used: *mut i64,
    );
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -34,4 +80,85 @@ mod tests {
|
||||||
assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 2), vec![0, 1]);
|
assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 2), vec![0, 1]);
|
||||||
assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 2), vec![2, 1]);
|
assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 2), vec![2, 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn test_merge_sorted_lists() {
    // `a` is already sorted; `b` is unsorted and `b_top_idx` lists its
    // positions in ascending distance order.
    let a: Vec<f32> = vec![0.01, 0.02, 0.03];
    let a_rowids: Vec<i64> = vec![1, 2, 3];

    let b: Vec<f32> = vec![0.4, 0.2, 0.3, 0.1];
    let b_rowids: Vec<i64> = vec![7, 5, 6, 4];
    let b_top_idx: Vec<i32> = vec![3, 1, 2, 0];

    // The complete merged output; asking for `n` entries must yield its
    // first `n` elements, capped at the seven that exist.
    let merged_rowids: Vec<i64> = vec![1, 2, 3, 4, 5, 6, 7];
    let merged_distances: Vec<f32> = vec![0.01, 0.02, 0.03, 0.1, 0.2, 0.3, 0.4];

    for n in 0..=8usize {
        let take = n.min(merged_rowids.len());
        assert_eq!(
            _merge_sorted_lists(&a, &a_rowids, &b, &b_rowids, &b_top_idx, n),
            (
                merged_rowids[..take].to_vec(),
                merged_distances[..take].to_vec()
            ),
            "n = {}",
            n
        );
    }
}
|
||||||
|
/*
|
||||||
|
#[test]
|
||||||
|
fn test_merge_sorted_lists_empty() {
|
||||||
|
let x = vec![0.1, 0.2, 0.3];
|
||||||
|
let x_rowids = vec![666, 888, 777];
|
||||||
|
assert_eq!(
|
||||||
|
_merge_sorted_lists(&x, &x_rowids, &vec![], &vec![], 3),
|
||||||
|
(vec![666, 888, 777], vec![0.1, 0.2, 0.3])
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
_merge_sorted_lists(&vec![], &vec![], &x, &x_rowids, 3),
|
||||||
|
(vec![666, 888, 777], vec![0.1, 0.2, 0.3])
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
_merge_sorted_lists(&vec![], &vec![], &x, &x_rowids, 4),
|
||||||
|
(vec![666, 888, 777], vec![0.1, 0.2, 0.3])
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
_merge_sorted_lists(&vec![], &vec![], &x, &x_rowids, 2),
|
||||||
|
(vec![666, 888], vec![0.1, 0.2])
|
||||||
|
);
|
||||||
|
}*/
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue