From 4ba167c315ef32f5b024a10e8f2cd76a166c0f85 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Mon, 18 Nov 2024 12:15:25 -0800 Subject: [PATCH] text knn NE --- sqlite-vec.c | 47 ++++++-- tests/__snapshots__/test-metadata.ambr | 160 ++++++++++++++++++++++--- tests/test-metadata.py | 2 +- 3 files changed, 185 insertions(+), 24 deletions(-) diff --git a/sqlite-vec.c b/sqlite-vec.c index 6a62cf3..3dedf37 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -5885,13 +5885,16 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * sqlite3_blob_close(rowidsBlob); switch(op) { + int nPrefix; + char * sPrefix; char *sFull; int nFull; + u8 * view; case VEC0_METADATA_OPERATOR_EQ: { for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int nPrefix = ((int*) view)[0]; - char * sPrefix = (char *) &view[4]; + view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + nPrefix = ((int*) view)[0]; + sPrefix = (char *) &view[4]; // for EQ the text lengths must match if(nPrefix != nTarget) { @@ -5925,11 +5928,39 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * } case VEC0_METADATA_OPERATOR_NE: { for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - char * s = (char *) &view[4]; - if(n > VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {rc = SQLITE_ERROR;goto done;} /* TODO */ - bitmap_set(b, i, strncmp(s, sTarget, n) != 0); + view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + nPrefix = ((int*) view)[0]; + sPrefix = (char *) &view[4]; + + // for NE if text lengths dont match, it never will + if(nPrefix != nTarget) { + bitmap_set(b, i, 1); + continue; + } + + int cmpPrefix = strncmp(sPrefix, sTarget, min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH)); + + // for short strings, use the prefix comparison direclty + if(nPrefix <= VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) { + bitmap_set(b, i, cmpPrefix != 0); + continue; + } + // for NE on longs strings, if prefixes dont match, then long string wont + if(cmpPrefix) { + bitmap_set(b, i, 1); + continue; + } + // consult the full string + rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull); + if(rc != SQLITE_OK) { + goto done; + } + if(nPrefix != nFull) { + rc = SQLITE_ERROR; + goto done; + } + bitmap_set(b, i, strncmp(sFull, sTarget, nFull) != 0); + } break; } diff --git a/tests/__snapshots__/test-metadata.ambr b/tests/__snapshots__/test-metadata.ambr index 4ec9eca..537df26 100644 --- a/tests/__snapshots__/test-metadata.ambr +++ b/tests/__snapshots__/test-metadata.ambr @@ -783,33 +783,163 @@ }) # --- # name: test_long_text_knn[ne-bbbb] - dict({ - 'error': 'OperationalError', - 'message': 'unrecognized token: "!"', + OrderedDict({ + 'sql': "select * from v where vector match X'11111111' and k = 5 and name != ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccccccccccc_ccc', + }), + OrderedDict({ + 'rowid': 5, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccc', + }), + OrderedDict({ + 'rowid': 4, + 'vector': b'\x11\x11\x11\x11', + 'name': 'bbbbbbbbbbbb_bbb', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaaaaaaaaaa_aaa', + }), + OrderedDict({ + 'rowid': 1, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaa', + }), + ]), }) # --- # name: test_long_text_knn[ne-bbbbbbbbbbbb_aaa] - dict({ - 'error': 'OperationalError', - 'message': 'unrecognized token: "!"', + OrderedDict({ + 'sql': "select * from v where vector match X'11111111' and k = 5 and name != ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccccccccccc_ccc', + }), + OrderedDict({ + 'rowid': 5, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccc', + }), + OrderedDict({ + 'rowid': 4, + 'vector': b'\x11\x11\x11\x11', + 'name': 'bbbbbbbbbbbb_bbb', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x11\x11\x11\x11', + 'name': 'bbbb', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaaaaaaaaaa_aaa', + }), + ]), }) # --- # name: test_long_text_knn[ne-bbbbbbbbbbbb_bbb] - dict({ - 'error': 'OperationalError', - 'message': 'unrecognized token: "!"', + OrderedDict({ + 'sql': "select * from v where vector match X'11111111' and k = 5 and name != ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccccccccccc_ccc', + }), + OrderedDict({ + 'rowid': 5, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccc', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x11\x11\x11\x11', + 'name': 'bbbb', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaaaaaaaaaa_aaa', + }), + OrderedDict({ + 'rowid': 1, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaa', + }), + ]), }) # --- # name: test_long_text_knn[ne-bbbbbbbbbbbb_ccc] - dict({ - 'error': 'OperationalError', - 'message': 'unrecognized token: "!"', + OrderedDict({ + 'sql': "select * from v where vector match X'11111111' and k = 5 and name != ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccccccccccc_ccc', + }), + OrderedDict({ + 'rowid': 5, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccc', + }), + OrderedDict({ + 'rowid': 4, + 'vector': b'\x11\x11\x11\x11', + 'name': 'bbbbbbbbbbbb_bbb', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x11\x11\x11\x11', + 'name': 'bbbb', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaaaaaaaaaa_aaa', + }), + ]), }) # --- # name: test_long_text_knn[ne-longlonglonglonglonglonglong] - dict({ - 'error': 'OperationalError', - 'message': 'unrecognized token: "!"', + OrderedDict({ + 'sql': "select * from v where vector match X'11111111' and k = 5 and name != ?", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccccccccccc_ccc', + }), + OrderedDict({ + 'rowid': 5, + 'vector': b'\x11\x11\x11\x11', + 'name': 'cccc', + }), + OrderedDict({ + 'rowid': 4, + 'vector': b'\x11\x11\x11\x11', + 'name': 'bbbbbbbbbbbb_bbb', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x11\x11\x11\x11', + 'name': 'bbbb', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaaaaaaaaaa_aaa', + }), + ]), }) # --- # name: test_long_text_updates diff --git a/tests/test-metadata.py b/tests/test-metadata.py index 73affac..ed7fd2a 100644 --- a/tests/test-metadata.py +++ b/tests/test-metadata.py @@ -152,7 +152,7 @@ def test_long_text_knn(db, snapshot): "bbbbbbbbbbbb_ccc", "longlonglonglonglonglonglong", ] - ops = ["=", "!-", "<", "<=", ">", ">="] + ops = ["=", "!=", "<", "<=", ">", ">="] op_names = ["eq", "ne", "lt", "le", "gt", "ge"] for test in tests: