Merge branch 'main' of github.com:asg017/sqlite-vec into main

This commit is contained in:
Alex Garcia 2024-07-23 12:27:37 -07:00
commit 79491542e5
3 changed files with 267 additions and 15 deletions

View file

@ -104,6 +104,7 @@ FUNCTIONS = [
"vec_debug",
"vec_distance_cosine",
"vec_distance_hamming",
"vec_distance_l1",
"vec_distance_l2",
"vec_f32",
"vec_int8",
@ -309,6 +310,68 @@ def test_vec_distance_hamming():
db.execute("select vec_distance_hamming(vec_int8(X'FF'), vec_int8(X'FF'))")
def test_vec_distance_l1():
vec_distance_l1 = lambda *args, a="?", b="?": db.execute(
f"select vec_distance_l1({a}, {b})", args
).fetchone()[0]
def check(a, b, dtype=np.float32):
if dtype == np.float32:
transform = "?"
elif dtype == np.int8:
transform = "vec_int8(?)"
a_sql_t = np.array(a, dtype=dtype)
b_sql_t = np.array(b, dtype=dtype)
x = vec_distance_l1(a_sql_t, b_sql_t, a=transform, b=transform)
# dont use dtype here bc overflow
y = np.sum(np.abs(np.array(a) - np.array(b)))
assert isclose(x, y, abs_tol=1e-6)
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
# check overflow
check([127] * 20, [-128] * 20, dtype=np.int8)
check([-128, 127], [127, -128], dtype=np.int8)
check(
[1, 2, 3, 4, 5, 6, 7, 8, 1, 1, 2, 3, 4, 5, 6, 7, 8, 1],
[1, 20, 38, 23, 29, 4, 10, 9, 3, 1, 20, 38, 23, 29, 4, 10, 9, 3],
dtype=np.int8,
)
check([0] * 20, [0] * 20, dtype=np.int8)
check(
[5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20],
[5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20],
dtype=np.int8,
)
check([100] * 20, [-100] * 20, dtype=np.int8)
check([127] * 1000000, [-128] * 1000000, dtype=np.int8)
check(
[1.2, 0.1, 0.5, 0.9, 1.4, 4.5],
[0.4, -0.4, 0.1, 0.1, 0.5, 0.9],
dtype=np.float32,
)
check([1.0, 2.0, 3.0], [-1.0, -2.0, -3.0], dtype=np.float32)
check(
[1e10, 2e10, np.finfo(np.float32).max],
[-1e10, -2e10, np.finfo(np.float32).min],
dtype=np.float32,
)
# overflow in leftover elements
check(
[1e10, 2e10, 1e10, 2e10, np.finfo(np.float32).max],
[-1e10, -2e10, -1e10, -2e10, np.finfo(np.float32).min],
dtype=np.float32,
)
# overflow in neon elements
check(
[np.finfo(np.float32).max, 1e10, 2e10, 1e10, 2e10],
[np.finfo(np.float32).min, -1e10, -2e10, -1e10, -2e10],
dtype=np.float32,
)
def test_vec_distance_l2():
vec_distance_l2 = lambda *args, a="?", b="?": db.execute(
f"select vec_distance_l2({a}, {b})", args
@ -319,11 +382,12 @@ def test_vec_distance_l2():
transform = "?"
elif dtype == np.int8:
transform = "vec_int8(?)"
a = np.array(a, dtype=dtype)
b = np.array(b, dtype=dtype)
x = vec_distance_l2(a, b, a=transform, b=transform)
y = npy_l2(a, b)
a_sql_t = np.array(a, dtype=dtype)
b_sql_t = np.array(b, dtype=dtype)
x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform)
y = npy_l2(np.array(a), np.array(b))
assert isclose(x, y, abs_tol=1e-6)
check([1.2, 0.1], [0.4, -0.4])