From 10a2216845077802e56f2b87d0d41993814df68e Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Mon, 18 Nov 2024 11:21:49 -0800 Subject: [PATCH] refactor some text knn filtering --- sqlite-vec.c | 188 +++++++++++++------------ tests/__snapshots__/test-metadata.ambr | 8 +- tests/test-metadata.py | 10 -- 3 files changed, 101 insertions(+), 105 deletions(-) diff --git a/sqlite-vec.c b/sqlite-vec.c index db8e756..ae7417d 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -5852,6 +5852,104 @@ int vec0_chunks_iter(vec0_vtab * p, const char * idxStr, int argc, sqlite3_value return rc; } +int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * buffer, int size, vec0_metadata_operator op, u8* b, int metadata_idx, i64 *rowids) { + int rc; + sqlite3_stmt * stmt = NULL; + const char * target = (const char *) sqlite3_value_text(value); + int targetn = sqlite3_value_bytes(value); + + switch(op) { + case VEC0_METADATA_OPERATOR_EQ: { + for(int i = 0; i < size; i++) { + u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + int n = ((int*) view)[0]; + char * s = (char *) &view[4]; + if(n != targetn) { + bitmap_set(b, i, 0); + continue; + } + int prefix_cmp = strncmp(s, target, min(n, 12)); + if(n <= 12) { + bitmap_set(b, i, prefix_cmp == 0); + } + // if the prefix doesnt match, the rest of the string wont match + else if(prefix_cmp) { + bitmap_set(b, i, 0); + } + // need to consult + else { + char *slong; + int slongn; + rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &slongn, &slong); + if(rc != SQLITE_OK) { + goto done; + } + assert(n == slongn); + bitmap_set(b, i, strncmp(slong, target, n) == 0); + } + } + break; + } + case VEC0_METADATA_OPERATOR_NE: { + for(int i = 0; i < size; i++) { + u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + int n = ((int*) view)[0]; + char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ + bitmap_set(b, i, strncmp(s, target, n) != 0); + } + break; + } + case VEC0_METADATA_OPERATOR_GT: { + for(int i = 0; i < size; i++) { + u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + int n = ((int*) view)[0]; + char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ + bitmap_set(b, i, strncmp(s, target, n) > 0); + } + break; + } + case VEC0_METADATA_OPERATOR_GE: { + for(int i = 0; i < size; i++) { + u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + int n = ((int*) view)[0]; + char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ + bitmap_set(b, i, strncmp(s, target, n) >= 0); + } + break; + } + case VEC0_METADATA_OPERATOR_LE: { + for(int i = 0; i < size; i++) { + u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + int n = ((int*) view)[0]; + char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ + bitmap_set(b, i, strncmp(s, target, n) <= 0); + } + break; + } + case VEC0_METADATA_OPERATOR_LT: { + for(int i = 0; i < size; i++) { + u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + int n = ((int*) view)[0]; + char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ + bitmap_set(b, i, strncmp(s, target, n) < 0); + } + break; + } + + } + rc = SQLITE_OK; + + done: + sqlite3_finalize(stmt); + return rc; + +} + /** * @brief Fill in bitmap of chunk values, whether or not the values match a metadata constraint * @@ -6008,95 +6106,7 @@ int vec0_set_metadata_filter_bitmap( break; } case VEC0_METADATA_COLUMN_KIND_TEXT: { - const char * target = (const char *) sqlite3_value_text(value); - int targetn = sqlite3_value_bytes(value); - - switch(op) { - case VEC0_METADATA_OPERATOR_EQ: { - sqlite3_stmt * stmt = NULL; - for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n != targetn) { - bitmap_set(b, i, 0); - continue; - } - int prefix_cmp = strncmp(s, target, min(n, 12)); - if(n <= 12) { - bitmap_set(b, i, prefix_cmp == 0); - } - // if the prefix doesnt match, the rest of the string wont match - else if(prefix_cmp) { - bitmap_set(b, i, 0); - } - // need to consult - else { - char *slong; - int slongn; - rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &slongn, &slong); - if(rc != SQLITE_OK) { - goto done; - } - assert(n == slongn); - bitmap_set(b, i, strncmp(slong, target, n) == 0); - } - } - sqlite3_finalize(stmt); - break; - } - case VEC0_METADATA_OPERATOR_NE: { - for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, target, n) != 0); - } - break; - } - case VEC0_METADATA_OPERATOR_GT: { - for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, target, n) > 0); - } - break; - } - case VEC0_METADATA_OPERATOR_GE: { - for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, target, n) >= 0); - } - break; - } - case VEC0_METADATA_OPERATOR_LE: { - for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, target, n) <= 0); - } - break; - } - case VEC0_METADATA_OPERATOR_LT: { - for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, target, n) < 0); - } - break; - } - - } + vec0_metadata_filter_text(p, value, buffer, size, op, b, metadata_idx, rowids); break; } } diff --git a/tests/__snapshots__/test-metadata.ambr b/tests/__snapshots__/test-metadata.ambr index 7213724..51b1a2a 100644 --- a/tests/__snapshots__/test-metadata.ambr +++ b/tests/__snapshots__/test-metadata.ambr @@ -1434,6 +1434,7 @@ # name: test_stress.1 OrderedDict({ 'sql': ''' + select movie_id, title, @@ -1448,6 +1449,7 @@ and num_reviews between 100 and 500 and mean_rating > 3.5 and k = 5; + ''', 'rows': list([ OrderedDict({ @@ -1875,12 +1877,6 @@ ]), }) # --- -# name: test_text_knn.10 - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', - }) -# --- # name: test_text_knn.2 dict({ 'v_chunks': OrderedDict({ diff --git a/tests/test-metadata.py b/tests/test-metadata.py index 3e2fb95..2b37b61 100644 --- a/tests/test-metadata.py +++ b/tests/test-metadata.py @@ -120,16 +120,6 @@ def test_text_knn(db, snapshot): == snapshot() ) - # this break KNN :( - db.execute("insert into v(vector, name) values ('[3.0]', '1234567890123')") - assert ( - exec( - db, - "select rowid, name, distance from v where vector match '[.01]' and k = 5 and name != 'aaa'", - ) - == snapshot() - ) - def test_long_text_updates(db, snapshot): db.execute(