static updates

This commit is contained in:
Alex Garcia 2024-07-31 12:56:09 -07:00
parent e91ccf38ff
commit a0bc9404ce
3 changed files with 323 additions and 231 deletions

View file

@ -80,8 +80,6 @@ def connect(ext, path=":memory:", extra_entrypoint=None):
db = connect(EXT_PATH)
# db.load_extension(EXT_PATH, entrypoint="trace_debug")
def explain_query_plan(sql):
return db.execute("explain query plan " + sql).fetchone()["detail"]
@ -113,7 +111,6 @@ FUNCTIONS = [
"vec_quantize_binary",
"vec_quantize_int8",
"vec_slice",
"vec_static_blob_from_raw",
"vec_sub",
"vec_to_json",
"vec_type",
@ -123,24 +120,150 @@ MODULES = [
"vec0",
"vec_each",
"vec_npy_each",
"vec_static_blob_entries",
"vec_static_blobs",
#"vec_static_blob_entries",
#"vec_static_blobs",
]
@pytest.mark.skip(reason="TODO")
def test_vec_static_blob_from_raw():
pass
def register_numpy(db, name: str, array):
ptr = array.__array_interface__["data"][0]
nvectors, dimensions = array.__array_interface__["shape"]
element_type = array.__array_interface__["typestr"]
assert element_type == "<f4"
name_escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]
db.execute(
"""
insert into temp.vec_static_blobs(name, data)
select ?, vec_static_blob_from_raw(?, ?, ?, ?)
""",
[name, ptr, element_type, dimensions, nvectors],
)
db.execute(
f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
)
@pytest.mark.skip(reason="TODO")
def test_vec_static_blobs():
pass
@pytest.mark.skip(reason="TODO")
def test_vec_static_blob_entries():
pass
db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_static_blobs_init")
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)
y = np.array([[0.2, 0.3], [0.9, 0.8], [0.6, 0.5]], dtype=np.float32)
z = np.array(
[
[0.1, 0.1, 0.1, 0.1],
[0.2, 0.2, 0.2, 0.2],
[0.3, 0.3, 0.3, 0.3],
[0.4, 0.4, 0.4, 0.4],
[0.5, 0.5, 0.5, 0.5],
],
dtype=np.float32,
)
register_numpy(db, "x", x)
register_numpy(db, "y", y)
register_numpy(db, "z", z)
assert execute_all(
db, "select *, dimensions, count from temp.vec_static_blobs;"
) == [
{
"count": 2,
"data": None,
"dimensions": 4,
"name": "x",
},
{
"count": 3,
"data": None,
"dimensions": 2,
"name": "y",
},
{
"count": 5,
"data": None,
"dimensions": 4,
"name": "z",
},
]
assert execute_all(db, "select vec_to_json(vector) from x;") == [
{
"vec_to_json(vector)": "[0.100000,0.200000,0.300000,0.400000]",
},
{
"vec_to_json(vector)": "[0.900000,0.800000,0.700000,0.600000]",
},
]
assert execute_all(db, "select (vector) from y limit 2;") == [
{
"vector": b"\xcd\xccL>\x9a\x99\x99>",
},
{
"vector": b"fff?\xcd\xccL?",
},
]
assert execute_all(db, "select rowid, (vector) from z") == [
{
"rowid": 0,
"vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=",
},
{
"rowid": 1,
"vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>",
},
{
"rowid": 2,
"vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>",
},
{
"rowid": 3,
"vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>",
},
{
"rowid": 4,
"vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?",
},
]
assert execute_all(
db,
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
[np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)],
) == [
{
"rowid": 2,
"v": "[0.300000,0.300000,0.300000,0.300000]",
},
{
"rowid": 3,
"v": "[0.400000,0.400000,0.400000,0.400000]",
},
{
"rowid": 1,
"v": "[0.200000,0.200000,0.200000,0.200000]",
},
]
assert execute_all(
db,
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
[np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)],
) == [
{
"rowid": 4,
"v": "[0.500000,0.500000,0.500000,0.500000]",
},
{
"rowid": 3,
"v": "[0.400000,0.400000,0.400000,0.400000]",
},
{
"rowid": 2,
"v": "[0.300000,0.300000,0.300000,0.300000]",
},
]
def test_funcs():
@ -1872,7 +1995,7 @@ def test_vec0_knn():
db.execute("select * from v where aaa match vec_bit(X'AA') and k = 10")
with _raises(
'Dimension mismatch for inserted vector for the "aaa" column. Expected 8 dimensions but received 1.'
'Dimension mismatch for query vector for the "aaa" column. Expected 8 dimensions but received 1.'
):
db.execute("select * from v where aaa match vec_f32('[.1]') and k = 10")
@ -2120,36 +2243,42 @@ def test_smoke():
db.execute("insert into vec_xyz(rowid, a) select 2, X'0000000000000040'")
chunk = db.execute("select * from vec_xyz_chunks").fetchone()
assert chunk[
"rowids"
] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + bytearray(
int(1024 * 8) - 8 * 2
assert (
chunk["rowids"]
== b"\x01\x00\x00\x00\x00\x00\x00\x00"
+ b"\x02\x00\x00\x00\x00\x00\x00\x00"
+ bytearray(int(1024 * 8) - 8 * 2)
)
assert chunk["chunk_id"] == 1
assert chunk["validity"] == b"\x03" + bytearray(int(1024 / 8) - 1)
vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone()
assert vchunk["rowid"] == 1
assert vchunk[
"vectors"
] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + bytearray(
int(1024 * 4 * 2) - (2 * 4 * 2)
assert (
vchunk["vectors"]
== b"\x00\x00\x00\x00\x00\x00\x80\x3f"
+ b"\x00\x00\x00\x00\x00\x00\x00\x40"
+ bytearray(int(1024 * 4 * 2) - (2 * 4 * 2))
)
db.execute("insert into vec_xyz(rowid, a) select 3, X'00000000000080bf'")
chunk = db.execute("select * from vec_xyz_chunks").fetchone()
assert chunk["chunk_id"] == 1
assert chunk["validity"] == b"\x07" + bytearray(int(1024 / 8) - 1)
assert chunk[
"rowids"
] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + b"\x03\x00\x00\x00\x00\x00\x00\x00" + bytearray(
int(1024 * 8) - 8 * 3
assert (
chunk["rowids"]
== b"\x01\x00\x00\x00\x00\x00\x00\x00"
+ b"\x02\x00\x00\x00\x00\x00\x00\x00"
+ b"\x03\x00\x00\x00\x00\x00\x00\x00"
+ bytearray(int(1024 * 8) - 8 * 3)
)
vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone()
assert vchunk["rowid"] == 1
assert vchunk[
"vectors"
] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + b"\x00\x00\x00\x00\x00\x00\x80\xbf" + bytearray(
int(1024 * 4 * 2) - (2 * 4 * 3)
assert (
vchunk["vectors"]
== b"\x00\x00\x00\x00\x00\x00\x80\x3f"
+ b"\x00\x00\x00\x00\x00\x00\x00\x40"
+ b"\x00\x00\x00\x00\x00\x00\x80\xbf"
+ bytearray(int(1024 * 4 * 2) - (2 * 4 * 3))
)
# db.execute("select * from vec_xyz")
@ -2192,66 +2321,63 @@ def test_vec0_stress_small_chunks():
{"rowid": 994, "a": _f32([99.4] * 8)},
{"rowid": 993, "a": _f32([99.3] * 8)},
]
assert (
execute_all(
db,
"""
assert execute_all(
db,
"""
select rowid, a, distance
from vec_small
where a match ?
and k = 9
order by distance
""",
[_f32([50.0] * 8)],
)
== [
{
"a": _f32([500 * 0.1] * 8),
"distance": 0.0,
"rowid": 500,
},
{
"a": _f32([501 * 0.1] * 8),
"distance": 0.2828384041786194,
"rowid": 501,
},
{
"a": _f32([499 * 0.1] * 8),
"distance": 0.2828384041786194,
"rowid": 499,
},
{
"a": _f32([502 * 0.1] * 8),
"distance": 0.5656875967979431,
"rowid": 502,
},
{
"a": _f32([498 * 0.1] * 8),
"distance": 0.5656875967979431,
"rowid": 498,
},
{
"a": _f32([503 * 0.1] * 8),
"distance": 0.8485260009765625,
"rowid": 503,
},
{
"a": _f32([497 * 0.1] * 8),
"distance": 0.8485260009765625,
"rowid": 497,
},
{
"a": _f32([496 * 0.1] * 8),
"distance": 1.1313751935958862,
"rowid": 496,
},
{
"a": _f32([504 * 0.1] * 8),
"distance": 1.1313751935958862,
"rowid": 504,
},
]
)
[_f32([50.0] * 8)],
) == [
{
"a": _f32([500 * 0.1] * 8),
"distance": 0.0,
"rowid": 500,
},
{
"a": _f32([501 * 0.1] * 8),
"distance": 0.2828384041786194,
"rowid": 501,
},
{
"a": _f32([499 * 0.1] * 8),
"distance": 0.2828384041786194,
"rowid": 499,
},
{
"a": _f32([502 * 0.1] * 8),
"distance": 0.5656875967979431,
"rowid": 502,
},
{
"a": _f32([498 * 0.1] * 8),
"distance": 0.5656875967979431,
"rowid": 498,
},
{
"a": _f32([503 * 0.1] * 8),
"distance": 0.8485260009765625,
"rowid": 503,
},
{
"a": _f32([497 * 0.1] * 8),
"distance": 0.8485260009765625,
"rowid": 497,
},
{
"a": _f32([496 * 0.1] * 8),
"distance": 1.1313751935958862,
"rowid": 496,
},
{
"a": _f32([504 * 0.1] * 8),
"distance": 1.1313751935958862,
"rowid": 504,
},
]
def test_vec0_distance_metric():