diff --git a/.gitignore b/.gitignore index ad7d0d0..0268d5d 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,6 @@ tmp/ poetry.lock *.jsonl + +memstat.c +memstat.* diff --git a/TODO b/TODO index 895eb34..828d0f4 100644 --- a/TODO +++ b/TODO @@ -22,3 +22,4 @@ - remaining TODO items - skip invalid validity entries in knn filter? - dictionary encoding? + - partition `x in (...)` handling diff --git a/sqlite-vec.c b/sqlite-vec.c index c02a59c..0e0d58a 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -5265,6 +5265,7 @@ typedef enum { VEC0_METADATA_OPERATOR_LT = 'd', VEC0_METADATA_OPERATOR_GE = 'e', VEC0_METADATA_OPERATOR_NE = 'f', + VEC0_METADATA_OPERATOR_IN = 'g', } vec0_metadata_operator; static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { @@ -5498,7 +5499,33 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { switch(op) { case SQLITE_INDEX_CONSTRAINT_EQ: { - value = VEC0_METADATA_OPERATOR_EQ; + int vtabIn = 0; + #if COMPILER_SUPPORTS_VTAB_IN + if (sqlite3_libversion_number() >= 3038000) { + vtabIn = sqlite3_vtab_in(pIdxInfo, i, -1); + } + #endif + if(vtabIn) { + switch(p->metadata_columns[metadata_idx].kind) { + case VEC0_METADATA_COLUMN_KIND_FLOAT: + case VEC0_METADATA_COLUMN_KIND_BOOLEAN: { + // IMP: TODO + rc = SQLITE_ERROR; + vtab_set_error(pVTab, "'xxx in (...)' is only available on INTEGER or TEXT metadata columns."); + goto done; + break; + } + case VEC0_METADATA_COLUMN_KIND_INTEGER: + case VEC0_METADATA_COLUMN_KIND_TEXT: { + break; + } + } + value = VEC0_METADATA_OPERATOR_IN; + sqlite3_vtab_in(pIdxInfo, i, 1); + } + else { + value = VEC0_PARTITION_OPERATOR_EQ; + } break; } case SQLITE_INDEX_CONSTRAINT_GT: { @@ -5852,7 +5879,24 @@ int vec0_chunks_iter(vec0_vtab * p, const char * idxStr, int argc, sqlite3_value return rc; } -int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * buffer, int size, vec0_metadata_operator op, u8* b, int metadata_idx, int chunk_rowid) { +// a single `xxx in (...)` constraint on a metadata column. TEXT or INTEGER only for now. +struct Vec0MetadataIn{ + // index of argv[i]` the constraint is on + int argv_idx; + // metadata column index of the constraint, derived from idxStr + argv_idx + int metadata_idx; + // array of the copied `(...)` values from sqlite3_vtab_in_first()/sqlite3_vtab_in_next() + struct Array array; +}; + +// Array elements for `xxx in (...)` values for a text column. basically just a string +struct Vec0MetadataInTextEntry { + int n; + char * zString; +}; + + +int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * buffer, int size, vec0_metadata_operator op, u8* b, int metadata_idx, int chunk_rowid, struct Array * aMetadataIn, int argv_idx) { int rc; sqlite3_stmt * stmt = NULL; i64 * rowids = NULL; @@ -6088,6 +6132,66 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * break; } + case VEC0_METADATA_OPERATOR_IN: { + size_t metadataInIdx = -1; + for(size_t i = 0; i < aMetadataIn->length; i++) { + struct Vec0MetadataIn * metadataIn = &(((struct Vec0MetadataIn *) aMetadataIn->z)[i]); + if(metadataIn->argv_idx == argv_idx) { + metadataInIdx = i; + break; + } + } + if(metadataInIdx < 0) { + abort(); // TODO + } + + struct Vec0MetadataIn * metadataIn = &((struct Vec0MetadataIn *) aMetadataIn->z)[metadataInIdx]; + struct Array * aTarget = &(metadataIn->array); + + + int nPrefix; + char * sPrefix; + char *sFull; + int nFull; + u8 * view; + for(int i = 0; i < size; i++) { + view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; + nPrefix = ((int*) view)[0]; + sPrefix = (char *) &view[4]; + for(size_t target_idx = 0; target_idx < aTarget->length; target_idx++) { + struct Vec0MetadataInTextEntry * entry = &(((struct Vec0MetadataInTextEntry*)aTarget->z)[target_idx]); + if(entry->n != nPrefix) { + continue; + } + int cmpPrefix = strncmp(sPrefix, entry->zString, min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH)); + if(nPrefix <= VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) { + if(cmpPrefix == 0) { + bitmap_set(b, i, 1); + break; + } + continue; + } + if(cmpPrefix) { + continue; + } + + rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull); + if(rc != SQLITE_OK) { + goto done; + } + if(nPrefix != nFull) { + rc = SQLITE_ERROR; + goto done; + } + if(strncmp(sFull, entry->zString, nFull) == 0) { + bitmap_set(b, i, 1); + break; + } + } + } + break; + } + } rc = SQLITE_OK; @@ -6118,7 +6222,8 @@ int vec0_set_metadata_filter_bitmap( sqlite3_blob * blob, i64 chunk_rowid, u8* b, - int size) { + int size, + struct Array * aMetadataIn, int argv_idx) { // TODO: shouldn't this skip in-valid entries from the chunk's validity bitmap? int rc; @@ -6198,6 +6303,31 @@ int vec0_set_metadata_filter_bitmap( for(int i = 0; i < size; i++) { bitmap_set(b, i, array[i] != target); } break; } + case VEC0_METADATA_OPERATOR_IN: { + int metadataInIdx = -1; + for(size_t i = 0; i < aMetadataIn->length; i++) { + struct Vec0MetadataIn * metadataIn = &((struct Vec0MetadataIn *) aMetadataIn->z)[i]; + if(metadataIn->argv_idx == argv_idx) { + metadataInIdx = i; + break; + } + } + if(metadataInIdx < 0) { + abort(); // TODO + } + struct Vec0MetadataIn * metadataIn = &((struct Vec0MetadataIn *) aMetadataIn->z)[metadataInIdx]; + struct Array * aTarget = &(metadataIn->array); + + for(int i = 0; i < size; i++) { + for(size_t target_idx = 0; target_idx < aTarget->length; target_idx++) { + if( ((i64*)aTarget->z)[target_idx] == array[i]) { + bitmap_set(b, i, 1); + break; + } + } + } + break; + } } break; } @@ -6229,11 +6359,15 @@ int vec0_set_metadata_filter_bitmap( for(int i = 0; i < size; i++) { bitmap_set(b, i, array[i] != target); } break; } + case VEC0_METADATA_OPERATOR_IN: { + // should never be reached + break; + } } break; } case VEC0_METADATA_COLUMN_KIND_TEXT: { - rc = vec0_metadata_filter_text(p, value, buffer, size, op, b, metadata_idx, chunk_rowid); + rc = vec0_metadata_filter_text(p, value, buffer, size, op, b, metadata_idx, chunk_rowid, aMetadataIn, argv_idx); if(rc != SQLITE_OK) { goto done; } @@ -6248,6 +6382,7 @@ int vec0_set_metadata_filter_bitmap( int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, struct VectorColumnDefinition *vector_column, int vectorColumnIdx, struct Array *arrayRowidsIn, + struct Array * aMetadataIn, const char * idxStr, int argc, sqlite3_value ** argv, void *queryVector, i64 k, i64 **out_topk_rowids, f32 **out_topk_distances, i64 *out_used) { @@ -6472,7 +6607,7 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, } bitmap_clear(bmMetadata, p->chunk_size); - rc = vec0_set_metadata_filter_bitmap(p, metadata_idx, operator, argv[i], metadataBlobs[metadata_idx], chunk_id, bmMetadata, p->chunk_size); + rc = vec0_set_metadata_filter_bitmap(p, metadata_idx, operator, argv[i], metadataBlobs[metadata_idx], chunk_id, bmMetadata, p->chunk_size, aMetadataIn, i); if(rc != SQLITE_OK) { vtab_set_error(&p->base, "Could not filter metadata fields"); if(rc != SQLITE_OK) { @@ -6619,6 +6754,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, return SQLITE_NOMEM; } memset(knn_data, 0, sizeof(*knn_data)); + // array of `struct Vec0MetadataIn`, IF there are any `xxx in (...)` metadata constraints + struct Array * aMetadataIn = NULL; + int query_idx =-1; int k_idx = -1; @@ -6738,6 +6876,95 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, } #endif + #if COMPILER_SUPPORTS_VTAB_IN + for(int i = 0; i < argc; i++) { + if(!(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_METADATA_CONSTRAINT && idxStr[1 + (i*4) + 2] == VEC0_METADATA_OPERATOR_IN)) { + continue; + } + int metadata_idx = idxStr[1 + (i*4) + 1] - 'A'; + if(!aMetadataIn) { + aMetadataIn = sqlite3_malloc(sizeof(*aMetadataIn)); + if(!aMetadataIn) { + rc = SQLITE_NOMEM; + goto cleanup; + } + memset(aMetadataIn, 0, sizeof(*aMetadataIn)); + rc = array_init(aMetadataIn, sizeof(struct Vec0MetadataIn), 8); + if(rc != SQLITE_OK) { + goto cleanup; + } + } + + struct Vec0MetadataIn item; + memset(&item, 0, sizeof(item)); + item.metadata_idx=metadata_idx; + item.argv_idx = i; + + switch(p->metadata_columns[metadata_idx].kind) { + case VEC0_METADATA_COLUMN_KIND_INTEGER: { + rc = array_init(&item.array, sizeof(i64), 16); + if(rc != SQLITE_OK) { + goto cleanup; + } + sqlite3_value *entry; + for (rc = sqlite3_vtab_in_first(argv[i], &entry); rc == SQLITE_OK && entry; rc = sqlite3_vtab_in_next(argv[i], &entry)) { + i64 v = sqlite3_value_int64(entry); + rc = array_append(&item.array, &v); + if (rc != SQLITE_OK) { + goto cleanup; + } + } + + if (rc != SQLITE_DONE) { + vtab_set_error(&p->base, "fuck"); // TODO + goto cleanup; + } + + break; + } + case VEC0_METADATA_COLUMN_KIND_TEXT: { + rc = array_init(&item.array, sizeof(struct Vec0MetadataInTextEntry), 16); + if(rc != SQLITE_OK) { + goto cleanup; + } + sqlite3_value *entry; + for (rc = sqlite3_vtab_in_first(argv[i], &entry); rc == SQLITE_OK && entry; rc = sqlite3_vtab_in_next(argv[i], &entry)) { + const char * s = (const char *) sqlite3_value_text(entry); + int n = sqlite3_value_bytes(entry); + + struct Vec0MetadataInTextEntry entry; + // TODO if this exits early, does it get properly cleaned up + entry.zString = sqlite3_mprintf("%.*s", n, s); + if(!entry.zString) { + rc = SQLITE_NOMEM; + goto cleanup; + } + entry.n = n; + rc = array_append(&item.array, &entry); + if (rc != SQLITE_OK) { + goto cleanup; + } + } + + if (rc != SQLITE_DONE) { + vtab_set_error(&p->base, "fuck"); // TODO + goto cleanup; + } + + break; + } + default: { + abort(); + } + } + + rc = array_append(aMetadataIn, &item); + if(rc != SQLITE_OK) { + abort(); // TODO + } + } + #endif + rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks); if (rc != SQLITE_OK) { // IMP: V06942_23781 @@ -6750,7 +6977,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, f32 *topk_distances = NULL; i64 k_used = 0; rc = vec0Filter_knn_chunks_iter(p, stmtChunks, vector_column, vectorColumnIdx, - arrayRowidsIn, idxStr, argc, argv, queryVector, k, &topk_rowids, + arrayRowidsIn, aMetadataIn, idxStr, argc, argv, queryVector, k, &topk_rowids, &topk_distances, &k_used); if (rc != SQLITE_OK) { goto cleanup; @@ -6771,6 +6998,21 @@ cleanup: array_cleanup(arrayRowidsIn); sqlite3_free(arrayRowidsIn); queryVectorCleanup(queryVector); + if(aMetadataIn) { + for(size_t i = 0; i < aMetadataIn->length; i++) { + struct Vec0MetadataIn* item = &((struct Vec0MetadataIn *) aMetadataIn->z)[i]; + for(size_t j = 0; j < item->array.length; j++) { + if(p->metadata_columns[item->metadata_idx].kind == VEC0_METADATA_COLUMN_KIND_TEXT) { + struct Vec0MetadataInTextEntry entry = ((struct Vec0MetadataInTextEntry*)item->array.z)[j]; + sqlite3_free(entry.zString); + } + } + array_cleanup(&item->array); + } + array_cleanup(aMetadataIn); + } + + sqlite3_free(aMetadataIn); return rc; } @@ -7049,7 +7291,8 @@ static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur, int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i); int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); if(rc != SQLITE_OK) { - sqlite3_result_error(context, "fuck todo", -1); + // TODO handle + sqlite3_result_error(context, "fuck", -1); } } return SQLITE_OK; @@ -7121,7 +7364,8 @@ static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur, int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i); int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); if(rc != SQLITE_OK) { - sqlite3_result_error(context, "fuck todo", -1); + // TODO handle + sqlite3_result_error(context, "fuck", -1); } } @@ -7188,7 +7432,8 @@ static int vec0Column_knn(vec0_vtab *pVtab, vec0_cursor *pCur, i64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx]; int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); if(rc != SQLITE_OK) { - sqlite3_result_error(context, "fuck todo", -1); + // TODO: handle + sqlite3_result_error(context, "fuck", -1); } } diff --git a/test.sql b/test.sql index 1d15c69..8cd3f30 100644 --- a/test.sql +++ b/test.sql @@ -5,6 +5,54 @@ .mode qbox +.load ./memstat +.echo on + +select name, value from sqlite_memstat where name = 'MEMORY_USED'; + +create virtual table v using vec0( + vector float[1], + name1 text, + name2 text, + age int, + chunk_size=8 +); + +select name, value from sqlite_memstat where name = 'MEMORY_USED'; + +insert into v(vector, name1, name2, age) values + ('[1]', 'alex', 'xxxx', 1), + ('[2]', 'alex', 'aaaa', 2), + ('[3]', 'alex', 'aaaa', 3), + ('[4]', 'brian', 'aaaa', 1), + ('[5]', 'brian', 'aaaa', 2), + ('[6]', 'brian', 'aaaa', 3), + ('[7]', 'craig', 'aaaa', 1), + ('[8]', 'craig', 'xxxx', 2), + ('[9]', 'craig', 'xxxx', 3), + ('[10]', '123456789012345', 'xxxx', 3); + +select name, value from sqlite_memstat where name = 'MEMORY_USED'; + +select rowid, name1, name2, age, vec_to_json(vector) +from v +where vector match '[0]' + and k = 5 + and name1 in ('alex', 'brian', 'craig') + --and name2 in ('aaaa', 'xxxx') + and age in (1, 2, 3, 2222,3333,4444); + +select name, value from sqlite_memstat where name = 'MEMORY_USED'; + +select rowid, name1, name2, age, vec_to_json(vector) +from v +where vector match '[0]' + and k = 5 + and name1 in ('123456789012345', 'superfluous'); + + +.exit + create virtual table v using vec0( vector float[1], +description text diff --git a/tests/__snapshots__/test-metadata.ambr b/tests/__snapshots__/test-metadata.ambr index 94b61b5..c616279 100644 --- a/tests/__snapshots__/test-metadata.ambr +++ b/tests/__snapshots__/test-metadata.ambr @@ -3806,3 +3806,260 @@ }), }) # --- +# name: test_vtab_in[allow-int-all] + OrderedDict({ + 'sql': "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)", + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'n': 999, + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 2, + 'n': 555, + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 3, + 'n': 999, + 'distance': 3.0, + }), + OrderedDict({ + 'rowid': 4, + 'n': 555, + 'distance': 4.0, + }), + OrderedDict({ + 'rowid': 5, + 'n': 999, + 'distance': 5.0, + }), + OrderedDict({ + 'rowid': 6, + 'n': 555, + 'distance': 6.0, + }), + OrderedDict({ + 'rowid': 7, + 'n': 999, + 'distance': 7.0, + }), + OrderedDict({ + 'rowid': 8, + 'n': 555, + 'distance': 8.0, + }), + ]), + }) +# --- +# name: test_vtab_in[allow-int-superfluous] + OrderedDict({ + 'sql': "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)", + 'rows': list([ + OrderedDict({ + 'rowid': 2, + 'n': 555, + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 4, + 'n': 555, + 'distance': 4.0, + }), + OrderedDict({ + 'rowid': 6, + 'n': 555, + 'distance': 6.0, + }), + OrderedDict({ + 'rowid': 8, + 'n': 555, + 'distance': 8.0, + }), + ]), + }) +# --- +# name: test_vtab_in[allow-text-all] + OrderedDict({ + 'sql': "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')", + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 't': 'aaaa', + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 2, + 't': 'aaaa', + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 3, + 't': 'aaaa', + 'distance': 3.0, + }), + OrderedDict({ + 'rowid': 4, + 't': 'aaaa', + 'distance': 4.0, + }), + OrderedDict({ + 'rowid': 5, + 't': 'zzzz', + 'distance': 5.0, + }), + OrderedDict({ + 'rowid': 6, + 't': 'zzzz', + 'distance': 6.0, + }), + OrderedDict({ + 'rowid': 7, + 't': 'zzzz', + 'distance': 7.0, + }), + OrderedDict({ + 'rowid': 8, + 't': 'zzzz', + 'distance': 8.0, + }), + ]), + }) +# --- +# name: test_vtab_in[allow-text-superfluous] + OrderedDict({ + 'sql': "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')", + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 't': 'aaaa', + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 2, + 't': 'aaaa', + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 3, + 't': 'aaaa', + 'distance': 3.0, + }), + OrderedDict({ + 'rowid': 4, + 't': 'aaaa', + 'distance': 4.0, + }), + ]), + }) +# --- +# name: test_vtab_in[block-bool] + dict({ + 'error': 'OperationalError', + 'message': "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.", + }) +# --- +# name: test_vtab_in[block-float] + dict({ + 'error': 'OperationalError', + 'message': "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.", + }) +# --- +# name: test_vtab_in_long_text[all] + OrderedDict({ + 'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))", + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 't': 'aaaa', + }), + OrderedDict({ + 'rowid': 2, + 't': 'aaaaaaaaaaaa_aaa', + }), + OrderedDict({ + 'rowid': 3, + 't': 'bbbb', + }), + OrderedDict({ + 'rowid': 4, + 't': 'bbbbbbbbbbbb_bbb', + }), + OrderedDict({ + 'rowid': 5, + 't': 'cccc', + }), + OrderedDict({ + 'rowid': 6, + 't': 'cccccccccccc_ccc', + }), + ]), + }) +# --- +# name: test_vtab_in_long_text[individual-aaaa] + OrderedDict({ + 'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')", + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 't': 'aaaa', + }), + ]), + }) +# --- +# name: test_vtab_in_long_text[individual-aaaaaaaaaaaa_aaa] + OrderedDict({ + 'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')", + 'rows': list([ + OrderedDict({ + 'rowid': 2, + 't': 'aaaaaaaaaaaa_aaa', + }), + ]), + }) +# --- +# name: test_vtab_in_long_text[individual-bbbb] + OrderedDict({ + 'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 't': 'bbbb', + }), + ]), + }) +# --- +# name: test_vtab_in_long_text[individual-bbbbbbbbbbbb_bbb] + OrderedDict({ + 'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')", + 'rows': list([ + OrderedDict({ + 'rowid': 4, + 't': 'bbbbbbbbbbbb_bbb', + }), + ]), + }) +# --- +# name: test_vtab_in_long_text[individual-cccc] + OrderedDict({ + 'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')", + 'rows': list([ + OrderedDict({ + 'rowid': 5, + 't': 'cccc', + }), + ]), + }) +# --- +# name: test_vtab_in_long_text[individual-cccccccccccc_ccc] + OrderedDict({ + 'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')", + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 't': 'cccccccccccc_ccc', + }), + ]), + }) +# --- diff --git a/tests/test-metadata.py b/tests/test-metadata.py index 09eb468..ce55e59 100644 --- a/tests/test-metadata.py +++ b/tests/test-metadata.py @@ -1,5 +1,7 @@ +import pytest import sqlite3 from collections import OrderedDict +import json def test_constructor_limit(db, snapshot): @@ -284,6 +286,87 @@ def test_knn(db, snapshot): ) +SUPPORTS_VTAB_IN = sqlite3.sqlite_version_info[1] >= 38 + + +@pytest.mark.skipif( + not SUPPORTS_VTAB_IN, reason="requires vtab `x in (...)` support in SQLite >=3.38" +) +def test_vtab_in(db, snapshot): + db.execute( + "create virtual table v using vec0(vector float[1], n int, t text, b boolean, f float, chunk_size=8)" + ) + db.executemany( + "insert into v(rowid, vector, n, t, b, f) values (?, ?, ?, ?, ?, ?)", + [ + (1, "[1]", 999, "aaaa", 0, 1.1), + (2, "[2]", 555, "aaaa", 0, 1.1), + (3, "[3]", 999, "aaaa", 0, 1.1), + (4, "[4]", 555, "aaaa", 0, 1.1), + (5, "[5]", 999, "zzzz", 0, 1.1), + (6, "[6]", 555, "zzzz", 0, 1.1), + (7, "[7]", 999, "zzzz", 0, 1.1), + (8, "[8]", 555, "zzzz", 0, 1.1), + ], + ) + assert exec( + db, "select * from v where vector match '[0]' and k = 8 and b in (1, 0)" + ) == snapshot(name="block-bool") + + assert exec( + db, "select * from v where vector match '[0]' and k = 8 and f in (1.1, 0.0)" + ) == snapshot(name="block-float") + + assert exec( + db, + "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)", + ) == snapshot(name="allow-int-all") + assert exec( + db, + "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)", + ) == snapshot(name="allow-int-superfluous") + + assert exec( + db, + "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')", + ) == snapshot(name="allow-text-all") + assert exec( + db, + "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')", + ) == snapshot(name="allow-text-superfluous") + + +def test_vtab_in_long_text(db, snapshot): + db.execute( + "create virtual table v using vec0(vector float[1], t text, chunk_size=8)" + ) + data = [ + (1, "aaaa"), + (2, "aaaaaaaaaaaa_aaa"), + (3, "bbbb"), + (4, "bbbbbbbbbbbb_bbb"), + (5, "cccc"), + (6, "cccccccccccc_ccc"), + ] + db.executemany( + "insert into v(rowid, vector, t) values (:rowid, printf('[%d]', :rowid), :vector)", + [{"rowid": row[0], "vector": row[1]} for row in data], + ) + + for _, lookup in data: + assert exec( + db, + "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')", + [lookup], + ) == snapshot(name=f"individual-{lookup}") + + assert exec( + db, + "select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))", + [json.dumps([row[1] for row in data])], + ) == snapshot(name="all") + + def test_idxstr(db, snapshot): db.execute( """