From 633db6e9cc6caff312866b282d41a5e96cd68053 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Tue, 23 Jul 2024 14:04:17 -0700 Subject: [PATCH] add l1 distance to vec0 tables --- sqlite-vec.c | 15 ++++++++++++++- tests/test-loadable.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/sqlite-vec.c b/sqlite-vec.c index b5aacd0..0966c3a 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -1964,6 +1964,7 @@ int parse_primary_key_definition(const char *source, int source_length, enum Vec0DistanceMetrics { VEC0_DISTANCE_METRIC_L2 = 1, VEC0_DISTANCE_METRIC_COSINE = 2, + VEC0_DISTANCE_METRIC_L1 = 3, }; struct VectorColumnDefinition { @@ -2108,6 +2109,8 @@ int parse_vector_column(const char *source, int source_length, int valueLength = token.end - token.start; if (sqlite3_strnicmp(value, "l2", valueLength) == 0) { distanceMetric = VEC0_DISTANCE_METRIC_L2; + }else if (sqlite3_strnicmp(value, "l1", valueLength) == 0) { + distanceMetric = VEC0_DISTANCE_METRIC_L1; } else if (sqlite3_strnicmp(value, "cosine", valueLength) == 0) { distanceMetric = VEC0_DISTANCE_METRIC_COSINE; } else { @@ -4733,6 +4736,11 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, &vector_column->dimensions); break; } + case VEC0_DISTANCE_METRIC_L1: { + result = distance_l1_f32(base_i, (f32 *)queryVector, + &vector_column->dimensions); + break; + } case VEC0_DISTANCE_METRIC_COSINE: { result = distance_cosine_float(base_i, (f32 *)queryVector, &vector_column->dimensions); @@ -4750,6 +4758,11 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, &vector_column->dimensions); break; } + case VEC0_DISTANCE_METRIC_L1: { + result = distance_l1_int8(base_i, (i8 *)queryVector, + &vector_column->dimensions); + break; + } case VEC0_DISTANCE_METRIC_COSINE: { result = distance_cosine_int8(base_i, (i8 *)queryVector, &vector_column->dimensions); @@ -6887,7 +6900,7 @@ __declspec(dllexport) //{"vec_version", 
_static_text_func, 0, DEFAULT_FLAGS, (void *) SQLITE_VEC_VERSION }, //{"vec_debug", _static_text_func, 0, DEFAULT_FLAGS, (void *) SQLITE_VEC_DEBUG_STRING }, {"vec_distance_l2", vec_distance_l2, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, - {"vec_distance_l1", vec_distance_l1, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, NULL }, + {"vec_distance_l1", vec_distance_l1, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, {"vec_distance_hamming",vec_distance_hamming, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, {"vec_distance_cosine", vec_distance_cosine, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, {"vec_length", vec_length, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 75ab0f7..5eb0bed 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -2247,6 +2247,48 @@ def test_vec0_stress_small_chunks(): ] ) +def test_vec0_distance_metric(): + base = "('[1, 2]'), ('[3, 4]'), ('[5, 6]')" + q = '[-1, -2]' + + db = connect(EXT_PATH) + db.execute("create virtual table v1 using vec0( a float[2])") + db.execute(f"insert into v1(a) values {base}") + + db.execute("create virtual table v2 using vec0( a float[2] distance_metric=l2)") + db.execute(f"insert into v2(a) values {base}") + + db.execute("create virtual table v3 using vec0( a float[2] distance_metric=l1)") + db.execute(f"insert into v3(a) values {base}") + + db.execute("create virtual table v4 using vec0( a float[2] distance_metric=cosine)") + db.execute(f"insert into v4(a) values {base}") + + # default (L2) + assert execute_all(db, "select rowid, distance from v1 where a match ? and k = 3", [q]) == [ + {"rowid": 1, "distance": 4.4721360206604}, + {"rowid": 2, "distance": 7.211102485656738}, + {"rowid": 3, "distance": 10.0}, + ] + + # l2 + assert execute_all(db, "select rowid, distance from v2 where a match ? 
and k = 3", [q]) == [ + {"rowid": 1, "distance": 4.4721360206604}, + {"rowid": 2, "distance": 7.211102485656738}, + {"rowid": 3, "distance": 10.0}, + ] + # l1 + assert execute_all(db, "select rowid, distance from v3 where a match ? and k = 3", [q]) == [ + {"rowid": 1, "distance": 6}, + {"rowid": 2, "distance": 10}, + {"rowid": 3, "distance": 14}, + ] + # cosine + assert execute_all(db, "select rowid, distance from v4 where a match ? and k = 3", [q]) == [ + {"rowid": 3, "distance": 1.9734171628952026}, + {"rowid": 2, "distance": 1.9838699102401733}, + {"rowid": 1, "distance": 2}, + ] def rowids_value(buffer: bytearray) -> List[int]: assert (len(buffer) % 8) == 0