mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
static updates
This commit is contained in:
parent
e91ccf38ff
commit
a0bc9404ce
3 changed files with 323 additions and 231 deletions
|
|
@ -80,8 +80,6 @@ def connect(ext, path=":memory:", extra_entrypoint=None):
|
|||
|
||||
db = connect(EXT_PATH)
|
||||
|
||||
# db.load_extension(EXT_PATH, entrypoint="trace_debug")
|
||||
|
||||
|
||||
def explain_query_plan(sql):
|
||||
return db.execute("explain query plan " + sql).fetchone()["detail"]
|
||||
|
|
@ -113,7 +111,6 @@ FUNCTIONS = [
|
|||
"vec_quantize_binary",
|
||||
"vec_quantize_int8",
|
||||
"vec_slice",
|
||||
"vec_static_blob_from_raw",
|
||||
"vec_sub",
|
||||
"vec_to_json",
|
||||
"vec_type",
|
||||
|
|
@ -123,24 +120,150 @@ MODULES = [
|
|||
"vec0",
|
||||
"vec_each",
|
||||
"vec_npy_each",
|
||||
"vec_static_blob_entries",
|
||||
"vec_static_blobs",
|
||||
#"vec_static_blob_entries",
|
||||
#"vec_static_blobs",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO")
|
||||
def test_vec_static_blob_from_raw():
|
||||
pass
|
||||
|
||||
def register_numpy(db, name: str, array):
|
||||
ptr = array.__array_interface__["data"][0]
|
||||
nvectors, dimensions = array.__array_interface__["shape"]
|
||||
element_type = array.__array_interface__["typestr"]
|
||||
|
||||
assert element_type == "<f4"
|
||||
|
||||
name_escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]
|
||||
|
||||
db.execute(
|
||||
"""
|
||||
insert into temp.vec_static_blobs(name, data)
|
||||
select ?, vec_static_blob_from_raw(?, ?, ?, ?)
|
||||
""",
|
||||
[name, ptr, element_type, dimensions, nvectors],
|
||||
)
|
||||
|
||||
db.execute(
|
||||
f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO")
|
||||
def test_vec_static_blobs():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO")
|
||||
def test_vec_static_blob_entries():
|
||||
pass
|
||||
db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_static_blobs_init")
|
||||
|
||||
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)
|
||||
y = np.array([[0.2, 0.3], [0.9, 0.8], [0.6, 0.5]], dtype=np.float32)
|
||||
z = np.array(
|
||||
[
|
||||
[0.1, 0.1, 0.1, 0.1],
|
||||
[0.2, 0.2, 0.2, 0.2],
|
||||
[0.3, 0.3, 0.3, 0.3],
|
||||
[0.4, 0.4, 0.4, 0.4],
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
register_numpy(db, "x", x)
|
||||
register_numpy(db, "y", y)
|
||||
register_numpy(db, "z", z)
|
||||
assert execute_all(
|
||||
db, "select *, dimensions, count from temp.vec_static_blobs;"
|
||||
) == [
|
||||
{
|
||||
"count": 2,
|
||||
"data": None,
|
||||
"dimensions": 4,
|
||||
"name": "x",
|
||||
},
|
||||
{
|
||||
"count": 3,
|
||||
"data": None,
|
||||
"dimensions": 2,
|
||||
"name": "y",
|
||||
},
|
||||
{
|
||||
"count": 5,
|
||||
"data": None,
|
||||
"dimensions": 4,
|
||||
"name": "z",
|
||||
},
|
||||
]
|
||||
|
||||
assert execute_all(db, "select vec_to_json(vector) from x;") == [
|
||||
{
|
||||
"vec_to_json(vector)": "[0.100000,0.200000,0.300000,0.400000]",
|
||||
},
|
||||
{
|
||||
"vec_to_json(vector)": "[0.900000,0.800000,0.700000,0.600000]",
|
||||
},
|
||||
]
|
||||
assert execute_all(db, "select (vector) from y limit 2;") == [
|
||||
{
|
||||
"vector": b"\xcd\xccL>\x9a\x99\x99>",
|
||||
},
|
||||
{
|
||||
"vector": b"fff?\xcd\xccL?",
|
||||
},
|
||||
]
|
||||
assert execute_all(db, "select rowid, (vector) from z") == [
|
||||
{
|
||||
"rowid": 0,
|
||||
"vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=",
|
||||
},
|
||||
{
|
||||
"rowid": 1,
|
||||
"vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>",
|
||||
},
|
||||
{
|
||||
"rowid": 2,
|
||||
"vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>",
|
||||
},
|
||||
{
|
||||
"rowid": 3,
|
||||
"vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>",
|
||||
},
|
||||
{
|
||||
"rowid": 4,
|
||||
"vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?",
|
||||
},
|
||||
]
|
||||
assert execute_all(
|
||||
db,
|
||||
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
|
||||
[np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)],
|
||||
) == [
|
||||
{
|
||||
"rowid": 2,
|
||||
"v": "[0.300000,0.300000,0.300000,0.300000]",
|
||||
},
|
||||
{
|
||||
"rowid": 3,
|
||||
"v": "[0.400000,0.400000,0.400000,0.400000]",
|
||||
},
|
||||
{
|
||||
"rowid": 1,
|
||||
"v": "[0.200000,0.200000,0.200000,0.200000]",
|
||||
},
|
||||
]
|
||||
assert execute_all(
|
||||
db,
|
||||
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
|
||||
[np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)],
|
||||
) == [
|
||||
{
|
||||
"rowid": 4,
|
||||
"v": "[0.500000,0.500000,0.500000,0.500000]",
|
||||
},
|
||||
{
|
||||
"rowid": 3,
|
||||
"v": "[0.400000,0.400000,0.400000,0.400000]",
|
||||
},
|
||||
{
|
||||
"rowid": 2,
|
||||
"v": "[0.300000,0.300000,0.300000,0.300000]",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_funcs():
|
||||
|
|
@ -1872,7 +1995,7 @@ def test_vec0_knn():
|
|||
db.execute("select * from v where aaa match vec_bit(X'AA') and k = 10")
|
||||
|
||||
with _raises(
|
||||
'Dimension mismatch for inserted vector for the "aaa" column. Expected 8 dimensions but received 1.'
|
||||
'Dimension mismatch for query vector for the "aaa" column. Expected 8 dimensions but received 1.'
|
||||
):
|
||||
db.execute("select * from v where aaa match vec_f32('[.1]') and k = 10")
|
||||
|
||||
|
|
@ -2120,36 +2243,42 @@ def test_smoke():
|
|||
|
||||
db.execute("insert into vec_xyz(rowid, a) select 2, X'0000000000000040'")
|
||||
chunk = db.execute("select * from vec_xyz_chunks").fetchone()
|
||||
assert chunk[
|
||||
"rowids"
|
||||
] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + bytearray(
|
||||
int(1024 * 8) - 8 * 2
|
||||
assert (
|
||||
chunk["rowids"]
|
||||
== b"\x01\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ b"\x02\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ bytearray(int(1024 * 8) - 8 * 2)
|
||||
)
|
||||
assert chunk["chunk_id"] == 1
|
||||
assert chunk["validity"] == b"\x03" + bytearray(int(1024 / 8) - 1)
|
||||
vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone()
|
||||
assert vchunk["rowid"] == 1
|
||||
assert vchunk[
|
||||
"vectors"
|
||||
] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + bytearray(
|
||||
int(1024 * 4 * 2) - (2 * 4 * 2)
|
||||
assert (
|
||||
vchunk["vectors"]
|
||||
== b"\x00\x00\x00\x00\x00\x00\x80\x3f"
|
||||
+ b"\x00\x00\x00\x00\x00\x00\x00\x40"
|
||||
+ bytearray(int(1024 * 4 * 2) - (2 * 4 * 2))
|
||||
)
|
||||
|
||||
db.execute("insert into vec_xyz(rowid, a) select 3, X'00000000000080bf'")
|
||||
chunk = db.execute("select * from vec_xyz_chunks").fetchone()
|
||||
assert chunk["chunk_id"] == 1
|
||||
assert chunk["validity"] == b"\x07" + bytearray(int(1024 / 8) - 1)
|
||||
assert chunk[
|
||||
"rowids"
|
||||
] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + b"\x03\x00\x00\x00\x00\x00\x00\x00" + bytearray(
|
||||
int(1024 * 8) - 8 * 3
|
||||
assert (
|
||||
chunk["rowids"]
|
||||
== b"\x01\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ b"\x02\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ b"\x03\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ bytearray(int(1024 * 8) - 8 * 3)
|
||||
)
|
||||
vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone()
|
||||
assert vchunk["rowid"] == 1
|
||||
assert vchunk[
|
||||
"vectors"
|
||||
] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + b"\x00\x00\x00\x00\x00\x00\x80\xbf" + bytearray(
|
||||
int(1024 * 4 * 2) - (2 * 4 * 3)
|
||||
assert (
|
||||
vchunk["vectors"]
|
||||
== b"\x00\x00\x00\x00\x00\x00\x80\x3f"
|
||||
+ b"\x00\x00\x00\x00\x00\x00\x00\x40"
|
||||
+ b"\x00\x00\x00\x00\x00\x00\x80\xbf"
|
||||
+ bytearray(int(1024 * 4 * 2) - (2 * 4 * 3))
|
||||
)
|
||||
|
||||
# db.execute("select * from vec_xyz")
|
||||
|
|
@ -2192,66 +2321,63 @@ def test_vec0_stress_small_chunks():
|
|||
{"rowid": 994, "a": _f32([99.4] * 8)},
|
||||
{"rowid": 993, "a": _f32([99.3] * 8)},
|
||||
]
|
||||
assert (
|
||||
execute_all(
|
||||
db,
|
||||
"""
|
||||
assert execute_all(
|
||||
db,
|
||||
"""
|
||||
select rowid, a, distance
|
||||
from vec_small
|
||||
where a match ?
|
||||
and k = 9
|
||||
order by distance
|
||||
""",
|
||||
[_f32([50.0] * 8)],
|
||||
)
|
||||
== [
|
||||
{
|
||||
"a": _f32([500 * 0.1] * 8),
|
||||
"distance": 0.0,
|
||||
"rowid": 500,
|
||||
},
|
||||
{
|
||||
"a": _f32([501 * 0.1] * 8),
|
||||
"distance": 0.2828384041786194,
|
||||
"rowid": 501,
|
||||
},
|
||||
{
|
||||
"a": _f32([499 * 0.1] * 8),
|
||||
"distance": 0.2828384041786194,
|
||||
"rowid": 499,
|
||||
},
|
||||
{
|
||||
"a": _f32([502 * 0.1] * 8),
|
||||
"distance": 0.5656875967979431,
|
||||
"rowid": 502,
|
||||
},
|
||||
{
|
||||
"a": _f32([498 * 0.1] * 8),
|
||||
"distance": 0.5656875967979431,
|
||||
"rowid": 498,
|
||||
},
|
||||
{
|
||||
"a": _f32([503 * 0.1] * 8),
|
||||
"distance": 0.8485260009765625,
|
||||
"rowid": 503,
|
||||
},
|
||||
{
|
||||
"a": _f32([497 * 0.1] * 8),
|
||||
"distance": 0.8485260009765625,
|
||||
"rowid": 497,
|
||||
},
|
||||
{
|
||||
"a": _f32([496 * 0.1] * 8),
|
||||
"distance": 1.1313751935958862,
|
||||
"rowid": 496,
|
||||
},
|
||||
{
|
||||
"a": _f32([504 * 0.1] * 8),
|
||||
"distance": 1.1313751935958862,
|
||||
"rowid": 504,
|
||||
},
|
||||
]
|
||||
)
|
||||
[_f32([50.0] * 8)],
|
||||
) == [
|
||||
{
|
||||
"a": _f32([500 * 0.1] * 8),
|
||||
"distance": 0.0,
|
||||
"rowid": 500,
|
||||
},
|
||||
{
|
||||
"a": _f32([501 * 0.1] * 8),
|
||||
"distance": 0.2828384041786194,
|
||||
"rowid": 501,
|
||||
},
|
||||
{
|
||||
"a": _f32([499 * 0.1] * 8),
|
||||
"distance": 0.2828384041786194,
|
||||
"rowid": 499,
|
||||
},
|
||||
{
|
||||
"a": _f32([502 * 0.1] * 8),
|
||||
"distance": 0.5656875967979431,
|
||||
"rowid": 502,
|
||||
},
|
||||
{
|
||||
"a": _f32([498 * 0.1] * 8),
|
||||
"distance": 0.5656875967979431,
|
||||
"rowid": 498,
|
||||
},
|
||||
{
|
||||
"a": _f32([503 * 0.1] * 8),
|
||||
"distance": 0.8485260009765625,
|
||||
"rowid": 503,
|
||||
},
|
||||
{
|
||||
"a": _f32([497 * 0.1] * 8),
|
||||
"distance": 0.8485260009765625,
|
||||
"rowid": 497,
|
||||
},
|
||||
{
|
||||
"a": _f32([496 * 0.1] * 8),
|
||||
"distance": 1.1313751935958862,
|
||||
"rowid": 496,
|
||||
},
|
||||
{
|
||||
"a": _f32([504 * 0.1] * 8),
|
||||
"distance": 1.1313751935958862,
|
||||
"rowid": 504,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_vec0_distance_metric():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue