From a1a64427fc49b4f8fb24569678c9cc1050a2b4ab Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Thu, 14 Nov 2024 16:36:53 -0800 Subject: [PATCH] boolean comparison handling --- TODO | 1 - sqlite-vec.c | 53 +++++++--- tests/__snapshots__/test-metadata.ambr | 131 ++++++++++++++++++++++++- tests/test-metadata.py | 29 ++++-- 4 files changed, 189 insertions(+), 25 deletions(-) diff --git a/TODO b/TODO index 5e1c2f1..b8e09c0 100644 --- a/TODO +++ b/TODO @@ -13,7 +13,6 @@ - perf: LEFT JOIN aux table to rowids query in vec0_cursor for rowid/point stmts, to avoid N lookup queries # metadata filtering -- boolean comparisons - text comparisons (long) - skip invalid validity entries in knn filter? - null! diff --git a/sqlite-vec.c b/sqlite-vec.c index ee1a15c..69ec6d2 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -2093,7 +2093,7 @@ typedef enum { VEC0_METADATA_COLUMN_KIND_INTEGER, VEC0_METADATA_COLUMN_KIND_FLOAT, VEC0_METADATA_COLUMN_KIND_TEXT, - // TODO: blob, date, datetime + // future: blob, date, datetime } vec0_metadata_column_kind; /** @@ -5480,7 +5480,6 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } - // TODO: when aux branch is merge, move this loop logic to above loop for (int i = 0; i < pIdxInfo->nConstraint; i++) { if (!pIdxInfo->aConstraint[i].usable) continue; @@ -5533,15 +5532,22 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } } - if(value) { - pIdxInfo->aConstraintUsage[i].argvIndex = argvIndex++; - pIdxInfo->aConstraintUsage[i].omit = 1; - sqlite3_str_appendchar(idxStr, 1, VEC0_IDXSTR_KIND_METADATA_CONSTRAINT); - sqlite3_str_appendchar(idxStr, 1, 'A' + metadata_idx); - sqlite3_str_appendchar(idxStr, 1, value); - sqlite3_str_appendchar(idxStr, 1, '_'); + if(p->metadata_columns[metadata_idx].kind == VEC0_METADATA_COLUMN_KIND_BOOLEAN) { + if(!(value == VEC0_METADATA_OPERATOR_EQ || value == VEC0_METADATA_OPERATOR_NE)) { + // IMP: V10145_26984 + rc = SQLITE_ERROR; + vtab_set_error(pVTab, "ONLY EQUALS (=) or NOT_EQUALS (!=) operators are allowed on boolean metadata columns."); + goto done; + } } + pIdxInfo->aConstraintUsage[i].argvIndex = argvIndex++; + pIdxInfo->aConstraintUsage[i].omit = 1; + sqlite3_str_appendchar(idxStr, 1, VEC0_IDXSTR_KIND_METADATA_CONSTRAINT); + sqlite3_str_appendchar(idxStr, 1, 'A' + metadata_idx); + sqlite3_str_appendchar(idxStr, 1, value); + sqlite3_str_appendchar(idxStr, 1, '_'); + } @@ -5867,11 +5873,18 @@ int vec0_set_metadata_filter_bitmap( if(!buffer) { return SQLITE_NOMEM; } - sqlite3_blob_read(blob, buffer, blobSize, 0); + rc = sqlite3_blob_read(blob, buffer, blobSize, 0); + if(rc != SQLITE_OK) { + goto done; + } switch(kind) { case VEC0_METADATA_COLUMN_KIND_BOOLEAN: { - for(int i = 0; i < size; i++) { - // TODO boolean comparisions + int target = sqlite3_value_int(value); + if( (target && op == VEC0_METADATA_OPERATOR_EQ) || (!target && op == VEC0_METADATA_OPERATOR_NE)) { + for(int i = 0; i < size; i++) { bitmap_set(b, i, bitmap_get((u8*) buffer, i)); } + } + else { + for(int i = 0; i < size; i++) { bitmap_set(b, i, !bitmap_get((u8*) buffer, i)); } } break; } @@ -5938,8 +5951,17 @@ int vec0_set_metadata_filter_bitmap( break; } case VEC0_METADATA_COLUMN_KIND_TEXT: { - // TODO check for and handle large strings + // TODO: handle large strings. For now just raise a generic error + for(int i = 0; i < size; i++) { + u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + int n = ((int*) view)[0]; + if(n > 12) { + rc = SQLITE_ERROR; + goto done; + } + } const char * target = (const char *) sqlite3_value_text(value); + switch(op) { case VEC0_METADATA_OPERATOR_EQ: { for(int i = 0; i < size; i++) { @@ -5999,8 +6021,9 @@ int vec0_set_metadata_filter_bitmap( break; } } - sqlite3_free(buffer); - return SQLITE_OK; + done: + sqlite3_free(buffer); + return rc; } int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, diff --git a/tests/__snapshots__/test-metadata.ambr b/tests/__snapshots__/test-metadata.ambr index 6c01534..6a119a7 100644 --- a/tests/__snapshots__/test-metadata.ambr +++ b/tests/__snapshots__/test-metadata.ambr @@ -1638,13 +1638,140 @@ ]), }) # --- -# name: test_stress.8 +# name: test_stress[bool-eq-false] OrderedDict({ - 'sql': "select movie_id, mean_rating, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited = TRUE", + 'sql': "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited = FALSE", 'rows': list([ + OrderedDict({ + 'movie_id': 16, + 'is_favorited': 0, + 'distance': 84.0, + }), + OrderedDict({ + 'movie_id': 14, + 'is_favorited': 0, + 'distance': 86.0, + }), + OrderedDict({ + 'movie_id': 12, + 'is_favorited': 0, + 'distance': 88.0, + }), + OrderedDict({ + 'movie_id': 10, + 'is_favorited': 0, + 'distance': 90.0, + }), + OrderedDict({ + 'movie_id': 8, + 'is_favorited': 0, + 'distance': 92.0, + }), ]), }) # --- +# name: test_stress[bool-eq-true] + OrderedDict({ + 'sql': "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited = TRUE", + 'rows': list([ + OrderedDict({ + 'movie_id': 25, + 'is_favorited': 1, + 'distance': 75.0, + }), + OrderedDict({ + 'movie_id': 24, + 'is_favorited': 1, + 'distance': 76.0, + }), + OrderedDict({ + 'movie_id': 23, + 'is_favorited': 1, + 'distance': 77.0, + }), + OrderedDict({ + 'movie_id': 22, + 'is_favorited': 1, + 'distance': 78.0, + }), + OrderedDict({ + 'movie_id': 21, + 'is_favorited': 1, + 'distance': 79.0, + }), + ]), + }) +# --- +# name: test_stress[bool-ne-false] + OrderedDict({ + 'sql': "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited != FALSE", + 'rows': list([ + OrderedDict({ + 'movie_id': 25, + 'is_favorited': 1, + 'distance': 75.0, + }), + OrderedDict({ + 'movie_id': 24, + 'is_favorited': 1, + 'distance': 76.0, + }), + OrderedDict({ + 'movie_id': 23, + 'is_favorited': 1, + 'distance': 77.0, + }), + OrderedDict({ + 'movie_id': 22, + 'is_favorited': 1, + 'distance': 78.0, + }), + OrderedDict({ + 'movie_id': 21, + 'is_favorited': 1, + 'distance': 79.0, + }), + ]), + }) +# --- +# name: test_stress[bool-ne-true] + OrderedDict({ + 'sql': "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited != TRUE", + 'rows': list([ + OrderedDict({ + 'movie_id': 16, + 'is_favorited': 0, + 'distance': 84.0, + }), + OrderedDict({ + 'movie_id': 14, + 'is_favorited': 0, + 'distance': 86.0, + }), + OrderedDict({ + 'movie_id': 12, + 'is_favorited': 0, + 'distance': 88.0, + }), + OrderedDict({ + 'movie_id': 10, + 'is_favorited': 0, + 'distance': 90.0, + }), + OrderedDict({ + 'movie_id': 8, + 'is_favorited': 0, + 'distance': 92.0, + }), + ]), + }) +# --- +# name: test_stress[bool-other-op] + dict({ + 'error': 'OperationalError', + 'message': 'ONLY EQUALS (=) or NOT_EQUALS (!=) operators are allowed on boolean metadata columns.', + }) +# --- # name: test_types[illegal-boolean] dict({ 'error': 'OperationalError', diff --git a/tests/test-metadata.py b/tests/test-metadata.py index f91c46c..b7ba5ce 100644 --- a/tests/test-metadata.py +++ b/tests/test-metadata.py @@ -355,13 +355,28 @@ def test_stress(db, snapshot): == snapshot() ) - assert ( - exec( - db, - "select movie_id, mean_rating, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited = TRUE", - ) - == snapshot() - ) + assert exec( + db, + "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited = TRUE", + ) == snapshot(name="bool-eq-true") + assert exec( + db, + "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited != TRUE", + ) == snapshot(name="bool-ne-true") + assert exec( + db, + "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited = FALSE", + ) == snapshot(name="bool-eq-false") + assert exec( + db, + "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited != FALSE", + ) == snapshot(name="bool-ne-false") + + # EVIDENCE-OF: V10145_26984 + assert exec( + db, + "select movie_id, is_favorited, distance from vec_movies where synopsis_embedding match '[100]' and k = 5 and is_favorited >= 999", + ) == snapshot(name="bool-other-op") def exec(db, sql, parameters=[]):