add l1 distance to vec0 tables

This commit is contained in:
Alex Garcia 2024-07-23 14:04:17 -07:00
parent 79491542e5
commit 633db6e9cc
2 changed files with 56 additions and 1 deletions

View file

@ -2247,6 +2247,48 @@ def test_vec0_stress_small_chunks():
]
)
def test_vec0_distance_metric():
base = "('[1, 2]'), ('[3, 4]'), ('[5, 6]')"
q = '[-1, -2]'
db = connect(EXT_PATH)
db.execute("create virtual table v1 using vec0( a float[2])")
db.execute(f"insert into v1(a) values {base}")
db.execute("create virtual table v2 using vec0( a float[2] distance_metric=l2)")
db.execute(f"insert into v2(a) values {base}")
db.execute("create virtual table v3 using vec0( a float[2] distance_metric=l1)")
db.execute(f"insert into v3(a) values {base}")
db.execute("create virtual table v4 using vec0( a float[2] distance_metric=cosine)")
db.execute(f"insert into v4(a) values {base}")
# default (L2)
assert execute_all(db, "select rowid, distance from v1 where a match ? and k = 3", [q]) == [
{"rowid": 1, "distance": 4.4721360206604},
{"rowid": 2, "distance": 7.211102485656738},
{"rowid": 3, "distance": 10.0},
]
# l2
assert execute_all(db, "select rowid, distance from v2 where a match ? and k = 3", [q]) == [
{"rowid": 1, "distance": 4.4721360206604},
{"rowid": 2, "distance": 7.211102485656738},
{"rowid": 3, "distance": 10.0},
]
# l1
assert execute_all(db, "select rowid, distance from v3 where a match ? and k = 3", [q]) == [
{"rowid": 1, "distance": 6},
{"rowid": 2, "distance": 10},
{"rowid": 3, "distance": 14},
]
# consine
assert execute_all(db, "select rowid, distance from v4 where a match ? and k = 3", [q]) == [
{"rowid": 3, "distance": 1.9734171628952026},
{"rowid": 2, "distance": 1.9838699102401733},
{"rowid": 1, "distance": 2},
]
def rowids_value(buffer: bytearray) -> List[int]:
assert (len(buffer) % 8) == 0