mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
add l1 distance to vec0 tables
This commit is contained in:
parent
79491542e5
commit
633db6e9cc
2 changed files with 56 additions and 1 deletions
15
sqlite-vec.c
15
sqlite-vec.c
|
|
@ -1964,6 +1964,7 @@ int parse_primary_key_definition(const char *source, int source_length,
|
|||
// Distance metrics that a vec0 vector column can be declared with via the
// `distance_metric=` column option. The KNN chunk scan switches on this
// value to choose the distance function for f32 and int8 vectors.
enum Vec0DistanceMetrics {
|
||||
// Chosen when the option value is "l2".
VEC0_DISTANCE_METRIC_L2 = 1,
|
||||
// Chosen when the option value is "cosine".
VEC0_DISTANCE_METRIC_COSINE = 2,
|
||||
// Chosen when the option value is "l1"; dispatches to distance_l1_f32 /
// distance_l1_int8 in the KNN scan.
VEC0_DISTANCE_METRIC_L1 = 3,
|
||||
};
|
||||
|
||||
struct VectorColumnDefinition {
|
||||
|
|
@ -2108,6 +2109,8 @@ int parse_vector_column(const char *source, int source_length,
|
|||
int valueLength = token.end - token.start;
|
||||
if (sqlite3_strnicmp(value, "l2", valueLength) == 0) {
|
||||
distanceMetric = VEC0_DISTANCE_METRIC_L2;
|
||||
}else if (sqlite3_strnicmp(value, "l1", valueLength) == 0) {
|
||||
distanceMetric = VEC0_DISTANCE_METRIC_L1;
|
||||
} else if (sqlite3_strnicmp(value, "cosine", valueLength) == 0) {
|
||||
distanceMetric = VEC0_DISTANCE_METRIC_COSINE;
|
||||
} else {
|
||||
|
|
@ -4733,6 +4736,11 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
|
|||
&vector_column->dimensions);
|
||||
break;
|
||||
}
|
||||
case VEC0_DISTANCE_METRIC_L1: {
|
||||
result = distance_l1_f32(base_i, (f32 *)queryVector,
|
||||
&vector_column->dimensions);
|
||||
break;
|
||||
}
|
||||
case VEC0_DISTANCE_METRIC_COSINE: {
|
||||
result = distance_cosine_float(base_i, (f32 *)queryVector,
|
||||
&vector_column->dimensions);
|
||||
|
|
@ -4750,6 +4758,11 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
|
|||
&vector_column->dimensions);
|
||||
break;
|
||||
}
|
||||
case VEC0_DISTANCE_METRIC_L1: {
|
||||
result = distance_l1_int8(base_i, (i8 *)queryVector,
|
||||
&vector_column->dimensions);
|
||||
break;
|
||||
}
|
||||
case VEC0_DISTANCE_METRIC_COSINE: {
|
||||
result = distance_cosine_int8(base_i, (i8 *)queryVector,
|
||||
&vector_column->dimensions);
|
||||
|
|
@ -6887,7 +6900,7 @@ __declspec(dllexport)
|
|||
//{"vec_version", _static_text_func, 0, DEFAULT_FLAGS, (void *) SQLITE_VEC_VERSION },
|
||||
//{"vec_debug", _static_text_func, 0, DEFAULT_FLAGS, (void *) SQLITE_VEC_DEBUG_STRING },
|
||||
{"vec_distance_l2", vec_distance_l2, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
|
||||
{"vec_distance_l1", vec_distance_l1, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, NULL },
|
||||
{"vec_distance_l1", vec_distance_l1, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
|
||||
{"vec_distance_hamming",vec_distance_hamming, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
|
||||
{"vec_distance_cosine", vec_distance_cosine, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
|
||||
{"vec_length", vec_length, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
|
||||
|
|
|
|||
|
|
@ -2247,6 +2247,48 @@ def test_vec0_stress_small_chunks():
|
|||
]
|
||||
)
|
||||
|
||||
def test_vec0_distance_metric():
    """Check KNN distances returned by vec0 for each ``distance_metric`` option.

    Four tables are filled with the same three 2-d vectors: one declared with
    no explicit metric, and one each with ``l2``, ``l1`` and ``cosine``. Every
    table is queried with the same vector and ``k = 3``; the rowids and
    distances that come back are compared against fixed expected rows.
    """
    base = "('[1, 2]'), ('[3, 4]'), ('[5, 6]')"
    q = '[-1, -2]'

    db = connect(EXT_PATH)

    # (table name, DDL) pairs, created and populated in this order.
    tables = (
        ("v1", "create virtual table v1 using vec0( a float[2])"),
        ("v2", "create virtual table v2 using vec0( a float[2] distance_metric=l2)"),
        ("v3", "create virtual table v3 using vec0( a float[2] distance_metric=l1)"),
        ("v4", "create virtual table v4 using vec0( a float[2] distance_metric=cosine)"),
    )
    for table, ddl in tables:
        db.execute(ddl)
        db.execute(f"insert into {table}(a) values {base}")

    def knn(table):
        # Same KNN query for every table; only the table name varies.
        return execute_all(
            db,
            f"select rowid, distance from {table} where a match ? and k = 3",
            [q],
        )

    l2_rows = [
        {"rowid": 1, "distance": 4.4721360206604},
        {"rowid": 2, "distance": 7.211102485656738},
        {"rowid": 3, "distance": 10.0},
    ]
    l1_rows = [
        {"rowid": 1, "distance": 6},
        {"rowid": 2, "distance": 10},
        {"rowid": 3, "distance": 14},
    ]
    cosine_rows = [
        {"rowid": 3, "distance": 1.9734171628952026},
        {"rowid": 2, "distance": 1.9838699102401733},
        {"rowid": 1, "distance": 2},
    ]

    # default (no distance_metric given): expected to match the explicit l2 table
    assert knn("v1") == l2_rows

    # l2
    assert knn("v2") == l2_rows

    # l1
    assert knn("v3") == l1_rows

    # cosine: the nearest-first ordering is reversed relative to l2/l1
    assert knn("v4") == cosine_rows
|
||||
|
||||
def rowids_value(buffer: bytearray) -> List[int]:
|
||||
assert (len(buffer) % 8) == 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue