From 018e9789de2b647241b5eb7bc70cd27a6e2b21e8 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Mon, 18 Nov 2024 14:09:07 -0800 Subject: [PATCH] text knn LT/LE --- TODO | 1 - sqlite-vec.c | 62 ++++- tests/__snapshots__/test-metadata.ambr | 328 +++++++++++++++++++++---- 3 files changed, 338 insertions(+), 53 deletions(-) diff --git a/TODO b/TODO index 6d2c596..895eb34 100644 --- a/TODO +++ b/TODO @@ -13,7 +13,6 @@ - perf: LEFT JOIN aux table to rowids query in vec0_cursor for rowid/point stmts, to avoid N lookup queries # metadata filtering -- text comparisons (long) - `v in (...)` handling - [ ] test accessing aux values when rowid is different than 1,2,3 etc. - [ ] add `xyz_info` shadow table with version etc. diff --git a/sqlite-vec.c b/sqlite-vec.c index 4a369c1..c02a59c 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -6027,21 +6027,63 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * } case VEC0_METADATA_OPERATOR_LE: { for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, sTarget, n) <= 0); + view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + nPrefix = ((int*) view)[0]; + sPrefix = (char *) &view[4]; + int cmpPrefix = strncmp(sPrefix, sTarget, min(min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH), nTarget)); + + if(nPrefix < VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) { + // if prefix match, check which is longer + if(cmpPrefix == 0) { + bitmap_set(b, i, nPrefix <= nTarget); + } + else { + bitmap_set(b, i, cmpPrefix <= 0); + } + continue; + } + // TODO(perf): may not need to compare full text in some cases + + rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull); + if(rc != SQLITE_OK) { + goto done; + } + if(nPrefix != nFull) { + rc = SQLITE_ERROR; + goto done; + } + bitmap_set(b, i, strncmp(sFull, sTarget, nFull) <= 0); } break; } case VEC0_METADATA_OPERATOR_LT: { for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, sTarget, n) < 0); + view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + nPrefix = ((int*) view)[0]; + sPrefix = (char *) &view[4]; + int cmpPrefix = strncmp(sPrefix, sTarget, min(min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH), nTarget)); + + if(nPrefix < VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) { + // if prefix match, check which is longer + if(cmpPrefix == 0) { + bitmap_set(b, i, nPrefix < nTarget); + } + else { + bitmap_set(b, i, cmpPrefix < 0); + } + continue; + } + // TODO(perf): may not need to compare full text in some cases + + rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull); + if(rc != SQLITE_OK) { + goto done; + } + if(nPrefix != nFull) { + rc = SQLITE_ERROR; + goto done; + } + bitmap_set(b, i, strncmp(sFull, sTarget, nFull) < 0); } break; } diff --git a/tests/__snapshots__/test-metadata.ambr b/tests/__snapshots__/test-metadata.ambr index 05b6a75..94b61b5 100644 --- a/tests/__snapshots__/test-metadata.ambr +++ b/tests/__snapshots__/test-metadata.ambr @@ -955,87 +955,331 @@ }) # --- # name: test_long_text_knn[le-bb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name <= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[le-bbbb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name <= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[le-bbbbbb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name <= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[le-bbbbbbbbbbbb_aaa] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name <= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[le-bbbbbbbbbbbb_bbb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name <= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[le-bbbbbbbbbbbb_ccc] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name <= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[le-longlonglonglonglonglonglong] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name <= ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + ]), }) # --- # name: test_long_text_knn[lt-bb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name < ?", + 'rows': list([ + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[lt-bbbb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name < ?", + 'rows': list([ + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[lt-bbbbbb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name < ?", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[lt-bbbbbbbbbbbb_aaa] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name < ?", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[lt-bbbbbbbbbbbb_bbb] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name < ?", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[lt-bbbbbbbbbbbb_ccc] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name < ?", + 'rows': list([ + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + OrderedDict({ + 'rowid': 1, + 'name': 'aaaa', + 'distance': 99.0, + }), + ]), }) # --- # name: test_long_text_knn[lt-longlonglonglonglonglonglong] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select rowid, name, distance from v where vector match '[100]' and k = 5 and name < ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'name': 'cccccccccccc_ccc', + 'distance': 94.0, + }), + OrderedDict({ + 'rowid': 5, + 'name': 'cccc', + 'distance': 95.0, + }), + OrderedDict({ + 'rowid': 4, + 'name': 'bbbbbbbbbbbb_bbb', + 'distance': 96.0, + }), + OrderedDict({ + 'rowid': 3, + 'name': 'bbbb', + 'distance': 97.0, + }), + OrderedDict({ + 'rowid': 2, + 'name': 'aaaaaaaaaaaa_aaa', + 'distance': 98.0, + }), + ]), }) # --- # name: test_long_text_knn[ne-bb]