mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
fmt
This commit is contained in:
parent
7a1b14976a
commit
0f5bc2f254
2 changed files with 258 additions and 238 deletions
|
|
@ -513,12 +513,14 @@ def test_vec_slice():
|
|||
|
||||
|
||||
def test_vec_type():
|
||||
vec_type = lambda *args, a="?": db.execute(f"select vec_type({a})", args).fetchone()[0]
|
||||
assert vec_type('[1]') == "float32"
|
||||
vec_type = lambda *args, a="?": db.execute(
|
||||
f"select vec_type({a})", args
|
||||
).fetchone()[0]
|
||||
assert vec_type("[1]") == "float32"
|
||||
assert vec_type(b"\xaa\xbb\xcc\xdd") == "float32"
|
||||
assert vec_type('[1]', a='vec_f32(?)') == "float32"
|
||||
assert vec_type('[1]', a='vec_int8(?)') == "int8"
|
||||
assert vec_type(b"\xaa", a='vec_bit(?)') == "bit"
|
||||
assert vec_type("[1]", a="vec_f32(?)") == "float32"
|
||||
assert vec_type("[1]", a="vec_int8(?)") == "int8"
|
||||
assert vec_type(b"\xaa", a="vec_bit(?)") == "bit"
|
||||
|
||||
with _raises("invalid float32 vector"):
|
||||
vec_type(b"\xaa")
|
||||
|
|
@ -697,7 +699,10 @@ def test_vec0_inserts():
|
|||
db.commit()
|
||||
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_INSERT, "t1_rowids"))
|
||||
# EVIDENCE-OF: V04679_21517 vec0 INSERT failed on _rowid shadow insert raises error
|
||||
with _raises("Internal sqlite-vec error: could not initialize 'insert rowids' statement", sqlite3.DatabaseError):
|
||||
with _raises(
|
||||
"Internal sqlite-vec error: could not initialize 'insert rowids' statement",
|
||||
sqlite3.DatabaseError,
|
||||
):
|
||||
db.execute("insert into t1 values (2, '[2,2,2,2]')")
|
||||
db.set_authorizer(None)
|
||||
db.rollback()
|
||||
|
|
@ -1798,7 +1803,7 @@ def test_vec0_create_errors():
|
|||
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_READ, "t1_chunks", ""))
|
||||
with _raises(
|
||||
"Internal sqlite-vec error: could not initialize 'latest chunk' statement",
|
||||
sqlite3.DatabaseError
|
||||
sqlite3.DatabaseError,
|
||||
):
|
||||
db.execute("create virtual table t1 using vec0(a float[1])")
|
||||
db.execute("insert into t1(a) values (X'AABBCCDD')")
|
||||
|
|
@ -1808,21 +1813,22 @@ def test_vec0_create_errors():
|
|||
db.execute("BEGIN")
|
||||
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_INSERT, "t1_rowids"))
|
||||
with _raises(
|
||||
"Internal sqlite-vec error: could not initialize 'insert rowids id' statement", sqlite3.DatabaseError
|
||||
"Internal sqlite-vec error: could not initialize 'insert rowids id' statement",
|
||||
sqlite3.DatabaseError,
|
||||
):
|
||||
db.execute("create virtual table t1 using vec0(a float[1])")
|
||||
db.execute("insert into t1(a) values (X'AABBCCDD')")
|
||||
db.set_authorizer(None)
|
||||
db.rollback()
|
||||
|
||||
|
||||
db.commit()
|
||||
db.execute("BEGIN")
|
||||
db.set_authorizer(
|
||||
authorizer_deny_on(sqlite3.SQLITE_UPDATE, "t1_rowids", "chunk_id")
|
||||
)
|
||||
with _raises(
|
||||
"Internal sqlite-vec error: could not initialize 'update rowids position' statement", sqlite3.DatabaseError
|
||||
"Internal sqlite-vec error: could not initialize 'update rowids position' statement",
|
||||
sqlite3.DatabaseError,
|
||||
):
|
||||
db.execute("create virtual table t1 using vec0(a float[1])")
|
||||
db.execute("insert into t1(a) values (X'AABBCCDD')")
|
||||
|
|
@ -1830,16 +1836,16 @@ def test_vec0_create_errors():
|
|||
db.rollback()
|
||||
|
||||
# TODO wut
|
||||
#db.commit()
|
||||
#db.execute("BEGIN")
|
||||
#db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_UPDATE, "t1_rowids", "id"))
|
||||
#with _raises(
|
||||
# db.commit()
|
||||
# db.execute("BEGIN")
|
||||
# db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_UPDATE, "t1_rowids", "id"))
|
||||
# with _raises(
|
||||
# "Internal sqlite-vec error: could not initialize 'rowids get chunk position' statement", sqlite3.DatabaseError
|
||||
#):
|
||||
# ):
|
||||
# db.execute("create virtual table t1 using vec0(a float[1])")
|
||||
# db.execute("insert into t1(a) values (X'AABBCCDD')")
|
||||
#db.set_authorizer(None)
|
||||
#db.rollback()
|
||||
# db.set_authorizer(None)
|
||||
# db.rollback()
|
||||
|
||||
|
||||
def test_vec0_knn():
|
||||
|
|
@ -2247,58 +2253,68 @@ def test_vec0_stress_small_chunks():
|
|||
]
|
||||
)
|
||||
|
||||
|
||||
def test_vec0_distance_metric():
|
||||
base = "('[1, 2]'), ('[3, 4]'), ('[5, 6]')"
|
||||
q = '[-1, -2]'
|
||||
base = "('[1, 2]'), ('[3, 4]'), ('[5, 6]')"
|
||||
q = "[-1, -2]"
|
||||
|
||||
db = connect(EXT_PATH)
|
||||
db.execute("create virtual table v1 using vec0( a float[2])")
|
||||
db.execute(f"insert into v1(a) values {base}")
|
||||
db = connect(EXT_PATH)
|
||||
db.execute("create virtual table v1 using vec0( a float[2])")
|
||||
db.execute(f"insert into v1(a) values {base}")
|
||||
|
||||
db.execute("create virtual table v2 using vec0( a float[2] distance_metric=l2)")
|
||||
db.execute(f"insert into v2(a) values {base}")
|
||||
db.execute("create virtual table v2 using vec0( a float[2] distance_metric=l2)")
|
||||
db.execute(f"insert into v2(a) values {base}")
|
||||
|
||||
db.execute("create virtual table v3 using vec0( a float[2] distance_metric=l1)")
|
||||
db.execute(f"insert into v3(a) values {base}")
|
||||
db.execute("create virtual table v3 using vec0( a float[2] distance_metric=l1)")
|
||||
db.execute(f"insert into v3(a) values {base}")
|
||||
|
||||
db.execute("create virtual table v4 using vec0( a float[2] distance_metric=cosine)")
|
||||
db.execute(f"insert into v4(a) values {base}")
|
||||
db.execute("create virtual table v4 using vec0( a float[2] distance_metric=cosine)")
|
||||
db.execute(f"insert into v4(a) values {base}")
|
||||
|
||||
# default (L2)
|
||||
assert execute_all(db, "select rowid, distance from v1 where a match ? and k = 3", [q]) == [
|
||||
{"rowid": 1, "distance": 4.4721360206604},
|
||||
{"rowid": 2, "distance": 7.211102485656738},
|
||||
{"rowid": 3, "distance": 10.0},
|
||||
]
|
||||
# default (L2)
|
||||
assert execute_all(
|
||||
db, "select rowid, distance from v1 where a match ? and k = 3", [q]
|
||||
) == [
|
||||
{"rowid": 1, "distance": 4.4721360206604},
|
||||
{"rowid": 2, "distance": 7.211102485656738},
|
||||
{"rowid": 3, "distance": 10.0},
|
||||
]
|
||||
|
||||
# l2
|
||||
assert execute_all(db, "select rowid, distance from v2 where a match ? and k = 3", [q]) == [
|
||||
{"rowid": 1, "distance": 4.4721360206604},
|
||||
{"rowid": 2, "distance": 7.211102485656738},
|
||||
{"rowid": 3, "distance": 10.0},
|
||||
]
|
||||
# l1
|
||||
assert execute_all(db, "select rowid, distance from v3 where a match ? and k = 3", [q]) == [
|
||||
{"rowid": 1, "distance": 6},
|
||||
{"rowid": 2, "distance": 10},
|
||||
{"rowid": 3, "distance": 14},
|
||||
]
|
||||
# consine
|
||||
assert execute_all(db, "select rowid, distance from v4 where a match ? and k = 3", [q]) == [
|
||||
{"rowid": 3, "distance": 1.9734171628952026},
|
||||
{"rowid": 2, "distance": 1.9838699102401733},
|
||||
{"rowid": 1, "distance": 2},
|
||||
]
|
||||
# l2
|
||||
assert execute_all(
|
||||
db, "select rowid, distance from v2 where a match ? and k = 3", [q]
|
||||
) == [
|
||||
{"rowid": 1, "distance": 4.4721360206604},
|
||||
{"rowid": 2, "distance": 7.211102485656738},
|
||||
{"rowid": 3, "distance": 10.0},
|
||||
]
|
||||
# l1
|
||||
assert execute_all(
|
||||
db, "select rowid, distance from v3 where a match ? and k = 3", [q]
|
||||
) == [
|
||||
{"rowid": 1, "distance": 6},
|
||||
{"rowid": 2, "distance": 10},
|
||||
{"rowid": 3, "distance": 14},
|
||||
]
|
||||
# consine
|
||||
assert execute_all(
|
||||
db, "select rowid, distance from v4 where a match ? and k = 3", [q]
|
||||
) == [
|
||||
{"rowid": 3, "distance": 1.9734171628952026},
|
||||
{"rowid": 2, "distance": 1.9838699102401733},
|
||||
{"rowid": 1, "distance": 2},
|
||||
]
|
||||
|
||||
|
||||
def test_vec0_vacuum():
|
||||
db = connect(EXT_PATH)
|
||||
db.execute('create virtual table vec_t using vec0(a float[1]);')
|
||||
db.execute("create virtual table vec_t using vec0(a float[1]);")
|
||||
db.execute("begin")
|
||||
db.execute("insert into vec_t(a) values (X'AABBCCDD')")
|
||||
db.commit()
|
||||
db.execute("vacuum")
|
||||
|
||||
|
||||
def rowids_value(buffer: bytearray) -> List[int]:
|
||||
assert (len(buffer) % 8) == 0
|
||||
n = int(len(buffer) / 8)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue