diff --git a/sqlite-vec.c b/sqlite-vec.c index f2c01c4..4a369c1 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -5968,19 +5968,19 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; nPrefix = ((int*) view)[0]; sPrefix = (char *) &view[4]; - int cmpPrefix = strncmp(sPrefix, sTarget, min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH)); + int cmpPrefix = strncmp(sPrefix, sTarget, min(min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH), nTarget)); - // for short strings, use the prefix comparison direclty - if(nPrefix <= VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) { - bitmap_set(b, i, cmpPrefix > 0); - continue; - } - - // for GT, only need to consult full string if EQ - if(cmpPrefix != 0) { - bitmap_set(b, i, cmpPrefix > 0); + if(nPrefix < VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) { + // if prefix match, check which is longer + if(cmpPrefix == 0) { + bitmap_set(b, i, nPrefix > nTarget); + } + else { + bitmap_set(b, i, cmpPrefix > 0); + } continue; } + // TODO(perf): may not need to compare full text in some cases rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull); if(rc != SQLITE_OK) { @@ -5996,11 +5996,32 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * } case VEC0_METADATA_OPERATOR_GE: { for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, sTarget, n) >= 0); + view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + nPrefix = ((int*) view)[0]; + sPrefix = (char *) &view[4]; + int cmpPrefix = strncmp(sPrefix, sTarget, min(min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH), nTarget)); + + if(nPrefix < VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) { + // if prefix match, check which is longer + if(cmpPrefix == 0) { + bitmap_set(b, i, nPrefix >= nTarget); + } + else { + bitmap_set(b, i, cmpPrefix >= 0); + } + continue; + } + // TODO(perf): may not need to compare full text in some cases + + rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull); + if(rc != SQLITE_OK) { + goto done; + } + if(nPrefix != nFull) { + rc = SQLITE_ERROR; + goto done; + } + bitmap_set(b, i, strncmp(sFull, sTarget, nFull) >= 0); } break; } diff --git a/tests/__snapshots__/test-metadata.ambr b/tests/__snapshots__/test-metadata.ambr index 33d443c..05b6a75 100644 --- a/tests/__snapshots__/test-metadata.ambr +++ b/tests/__snapshots__/test-metadata.ambr @@ -617,6 +617,13 @@ ]), }) # --- +# name: test_long_text_knn[eq-bb] + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name = ?", + 'rows': list([ + ]), + }) +# --- # name: test_long_text_knn[eq-bbbb] OrderedDict({ 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name = ?", @@ -629,6 +636,13 @@ ]), }) # --- +# name: test_long_text_knn[eq-bbbbbb] + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name = ?", + 'rows': list([ + ]), + }) +# --- # name: test_long_text_knn[eq-bbbbbbbbbbbb_aaa] OrderedDict({ 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name = ?", @@ -662,34 +676,175 @@ ]), }) # --- +# name: test_long_text_knn[ge-bb] + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + ]), + }) +# --- # name: test_long_text_knn[ge-bbbb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + ]), + }) +# --- +# name: test_long_text_knn[ge-bbbbbb] + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + ]), }) # --- # name: test_long_text_knn[ge-bbbbbbbbbbbb_aaa] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + ]), }) # --- # name: test_long_text_knn[ge-bbbbbbbbbbbb_bbb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + ]), }) # --- # name: test_long_text_knn[ge-bbbbbbbbbbbb_ccc] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + ]), }) # --- # name: test_long_text_knn[ge-longlonglonglonglonglonglong] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name >= ?", + 'rows': list([ + ]), + }) +# --- +# name: test_long_text_knn[gt-bb] + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name > ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + ]), }) # --- # name: test_long_text_knn[gt-bbbb] @@ -714,6 +869,28 @@ ]), }) # --- +# name: test_long_text_knn[gt-bbbbbb] + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name > ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + ]), + }) +# --- # name: test_long_text_knn[gt-bbbbbbbbbbbb_aaa] OrderedDict({ 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name > ?", @@ -777,12 +954,24 @@ ]), }) # --- +# name: test_long_text_knn[le-bb] + dict({ + 'error': 'OperationalError', + 'message': 'Could not filter metadata fields', + }) +# --- # name: test_long_text_knn[le-bbbb] dict({ 'error': 'OperationalError', 'message': 'Could not filter metadata fields', }) # --- +# name: test_long_text_knn[le-bbbbbb] + dict({ + 'error': 'OperationalError', + 'message': 'Could not filter metadata fields', + }) +# --- # name: test_long_text_knn[le-bbbbbbbbbbbb_aaa] dict({ 'error': 'OperationalError', @@ -807,12 +996,24 @@ 'message': 'Could not filter metadata fields', }) # --- +# name: test_long_text_knn[lt-bb] + dict({ + 'error': 'OperationalError', + 'message': 'Could not filter metadata fields', + }) +# --- # name: test_long_text_knn[lt-bbbb] dict({ 'error': 'OperationalError', 'message': 'Could not filter metadata fields', }) # --- +# name: test_long_text_knn[lt-bbbbbb] + dict({ + 'error': 'OperationalError', + 'message': 'Could not filter metadata fields', + }) +# --- # name: test_long_text_knn[lt-bbbbbbbbbbbb_aaa] dict({ 'error': 'OperationalError', @@ -837,6 +1038,38 @@ 'message': 'Could not filter metadata fields', }) # --- +# name: test_long_text_knn[ne-bb] + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name != ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + ]), + }) +# --- # name: test_long_text_knn[ne-bbbb] OrderedDict({ 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name != ?", @@ -869,6 +1102,38 @@ ]), }) # --- +# name: test_long_text_knn[ne-bbbbbb] + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name != ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + ]), + }) +# --- # name: test_long_text_knn[ne-bbbbbbbbbbbb_aaa] OrderedDict({ 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name != ?", diff --git a/tests/test-metadata.py b/tests/test-metadata.py index edf37d2..09eb468 100644 --- a/tests/test-metadata.py +++ b/tests/test-metadata.py @@ -147,6 +147,8 @@ def test_long_text_knn(db, snapshot): tests = [ "bbbb", + "bb", + "bbbbbb", "bbbbbbbbbbbb_bbb", "bbbbbbbbbbbb_aaa", "bbbbbbbbbbbb_ccc",