vec_type(), API references

This commit is contained in:
Alex Garcia 2024-07-22 21:24:44 -07:00
parent cfd8e9a46b
commit ff6cf96e2a
6 changed files with 677 additions and 240 deletions

View file

@ -110,12 +110,12 @@ FUNCTIONS = [
"vec_length",
"vec_normalize",
"vec_quantize_binary",
"vec_quantize_i8",
"vec_quantize_i8",
"vec_quantize_int8",
"vec_slice",
"vec_static_blob_from_raw",
"vec_sub",
"vec_to_json",
"vec_type",
"vec_version",
]
MODULES = [
@ -448,6 +448,20 @@ def test_vec_slice():
vec_slice(b"\xab\xab\xab\xab", 0, 0)
def test_vec_type():
vec_type = lambda *args, a="?": db.execute(f"select vec_type({a})", args).fetchone()[0]
assert vec_type('[1]') == "float32"
assert vec_type(b"\xaa\xbb\xcc\xdd") == "float32"
assert vec_type('[1]', a='vec_f32(?)') == "float32"
assert vec_type('[1]', a='vec_int8(?)') == "int8"
assert vec_type(b"\xaa", a='vec_bit(?)') == "bit"
with _raises("invalid float32 vector"):
vec_type(b"\xaa")
with _raises("found NULL"):
vec_type(None)
def test_vec_add():
vec_add = lambda *args, a="?", b="?": db.execute(
f"select vec_add({a}, {b})", args
@ -517,11 +531,11 @@ def test_vec_to_json():
@pytest.mark.skip(reason="TODO")
def test_vec_quantize_i8():
vec_quantize_i8 = lambda *args: db.execute(
"select vec_quantize_i8()", args
def test_vec_quantize_int8():
vec_quantize_int8 = lambda *args: db.execute(
"select vec_quantize_int8()", args
).fetchone()[0]
assert vec_quantize_i8() == 111
assert vec_quantize_int8() == 111
def test_vec_quantize_binary():
@ -1020,9 +1034,9 @@ def test_vec0_updates():
db.execute(
"""
INSERT INTO t3 VALUES
(1, :x, vec_quantize_i8(:x, 'unit') ,vec_quantize_binary(:x)),
(2, :y, vec_quantize_i8(:y, 'unit') ,vec_quantize_binary(:y)),
(3, :z, vec_quantize_i8(:z, 'unit') ,vec_quantize_binary(:z));
(1, :x, vec_quantize_int8(:x, 'unit') ,vec_quantize_binary(:x)),
(2, :y, vec_quantize_int8(:y, 'unit') ,vec_quantize_binary(:y)),
(3, :z, vec_quantize_int8(:z, 'unit') ,vec_quantize_binary(:z));
""",
{
"x": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]",
@ -1795,7 +1809,7 @@ def test_vec0_knn():
db.executemany(
"""
INSERT INTO v VALUES
(:id, :vector, vec_quantize_i8(:vector, 'unit') ,vec_quantize_binary(:vector));
(:id, :vector, vec_quantize_int8(:vector, 'unit') ,vec_quantize_binary(:vector));
""",
[
{