vtab_in handling

This commit is contained in:
Alex Garcia 2024-11-18 22:43:24 -08:00
parent 0db2e52974
commit 7b67c78530
6 changed files with 646 additions and 9 deletions

3
.gitignore vendored
View file

@ -28,3 +28,6 @@ tmp/
poetry.lock poetry.lock
*.jsonl *.jsonl
memstat.c
memstat.*

1
TODO
View file

@ -22,3 +22,4 @@
- remaining TODO items - remaining TODO items
- skip invalid validity entries in knn filter? - skip invalid validity entries in knn filter?
- dictionary encoding? - dictionary encoding?
- partition `x in (...)` handling

View file

@ -5265,6 +5265,7 @@ typedef enum {
VEC0_METADATA_OPERATOR_LT = 'd', VEC0_METADATA_OPERATOR_LT = 'd',
VEC0_METADATA_OPERATOR_GE = 'e', VEC0_METADATA_OPERATOR_GE = 'e',
VEC0_METADATA_OPERATOR_NE = 'f', VEC0_METADATA_OPERATOR_NE = 'f',
VEC0_METADATA_OPERATOR_IN = 'g',
} vec0_metadata_operator; } vec0_metadata_operator;
static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
@ -5498,7 +5499,33 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
switch(op) { switch(op) {
case SQLITE_INDEX_CONSTRAINT_EQ: { case SQLITE_INDEX_CONSTRAINT_EQ: {
value = VEC0_METADATA_OPERATOR_EQ; int vtabIn = 0;
#if COMPILER_SUPPORTS_VTAB_IN
if (sqlite3_libversion_number() >= 3038000) {
vtabIn = sqlite3_vtab_in(pIdxInfo, i, -1);
}
#endif
if(vtabIn) {
switch(p->metadata_columns[metadata_idx].kind) {
case VEC0_METADATA_COLUMN_KIND_FLOAT:
case VEC0_METADATA_COLUMN_KIND_BOOLEAN: {
// IMP: TODO
rc = SQLITE_ERROR;
vtab_set_error(pVTab, "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.");
goto done;
break;
}
case VEC0_METADATA_COLUMN_KIND_INTEGER:
case VEC0_METADATA_COLUMN_KIND_TEXT: {
break;
}
}
value = VEC0_METADATA_OPERATOR_IN;
sqlite3_vtab_in(pIdxInfo, i, 1);
}
else {
value = VEC0_PARTITION_OPERATOR_EQ;
}
break; break;
} }
case SQLITE_INDEX_CONSTRAINT_GT: { case SQLITE_INDEX_CONSTRAINT_GT: {
@ -5852,7 +5879,24 @@ int vec0_chunks_iter(vec0_vtab * p, const char * idxStr, int argc, sqlite3_value
return rc; return rc;
} }
int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * buffer, int size, vec0_metadata_operator op, u8* b, int metadata_idx, int chunk_rowid) { // a single `xxx in (...)` constraint on a metadata column. TEXT or INTEGER only for now.
struct Vec0MetadataIn{
// index `i` of the argv[i] value the constraint is bound to; the filter
// routines look an entry up by comparing this field against their argv index
int argv_idx;
// metadata column index of the constraint, derived from idxStr + argv_idx
int metadata_idx;
// array of the copied `(...)` values from sqlite3_vtab_in_first()/sqlite3_vtab_in_next().
// Elements are `i64` for INTEGER columns and `struct Vec0MetadataInTextEntry`
// for TEXT columns.
struct Array array;
};
// Array elements for `xxx in (...)` values for a text column. basically just a string
struct Vec0MetadataInTextEntry {
// length of the value in bytes, as reported by sqlite3_value_bytes()
int n;
// copy of the value made with sqlite3_mprintf(); released with sqlite3_free()
char * zString;
};
int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void * buffer, int size, vec0_metadata_operator op, u8* b, int metadata_idx, int chunk_rowid, struct Array * aMetadataIn, int argv_idx) {
int rc; int rc;
sqlite3_stmt * stmt = NULL; sqlite3_stmt * stmt = NULL;
i64 * rowids = NULL; i64 * rowids = NULL;
@ -6088,6 +6132,66 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void *
break; break;
} }
case VEC0_METADATA_OPERATOR_IN: {
  // Locate the pre-collected `(...)` values for the constraint bound to
  // argv[argv_idx]. A pointer is used instead of an index so that "not found"
  // is unambiguous (an unsigned index initialised to -1 can never test < 0).
  struct Vec0MetadataIn * metadataIn = NULL;
  if(aMetadataIn) {
    for(size_t i = 0; i < aMetadataIn->length; i++) {
      struct Vec0MetadataIn * candidate = &(((struct Vec0MetadataIn *) aMetadataIn->z)[i]);
      if(candidate->argv_idx == argv_idx) {
        metadataIn = candidate;
        break;
      }
    }
  }
  if(!metadataIn) {
    // No values were collected for this constraint: report an error rather
    // than reading past the end of the array or aborting the host process.
    rc = SQLITE_ERROR;
    goto done;
  }
  struct Array * aTarget = &(metadataIn->array);

  int nPrefix;
  char * sPrefix;
  char *sFull;
  int nFull;
  u8 * view;
  for(int i = 0; i < size; i++) {
    // Each view starts with the value's length as an int, followed by the
    // leading bytes of the string starting at offset 4.
    view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
    nPrefix = ((int*) view)[0];
    sPrefix = (char *) &view[4];
    for(size_t target_idx = 0; target_idx < aTarget->length; target_idx++) {
      struct Vec0MetadataInTextEntry * entry = &(((struct Vec0MetadataInTextEntry*)aTarget->z)[target_idx]);
      // Different lengths can never be equal.
      if(entry->n != nPrefix) {
        continue;
      }
      int cmpPrefix = strncmp(sPrefix, entry->zString, min(nPrefix, VEC0_METADATA_TEXT_VIEW_DATA_LENGTH));
      // Short values are stored entirely in the view, so the prefix
      // comparison is the whole comparison.
      if(nPrefix <= VEC0_METADATA_TEXT_VIEW_DATA_LENGTH) {
        if(cmpPrefix == 0) {
          bitmap_set(b, i, 1);
          break;
        }
        continue;
      }
      // Long value whose prefix already differs: skip fetching the full text.
      if(cmpPrefix) {
        continue;
      }
      // Prefix matches: fetch the full value and compare all of it.
      rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &nFull, &sFull);
      if(rc != SQLITE_OK) {
        goto done;
      }
      // The view's length and the full value's length must agree.
      if(nPrefix != nFull) {
        rc = SQLITE_ERROR;
        goto done;
      }
      if(strncmp(sFull, entry->zString, nFull) == 0) {
        bitmap_set(b, i, 1);
        break;
      }
    }
  }
  break;
}
} }
rc = SQLITE_OK; rc = SQLITE_OK;
@ -6118,7 +6222,8 @@ int vec0_set_metadata_filter_bitmap(
sqlite3_blob * blob, sqlite3_blob * blob,
i64 chunk_rowid, i64 chunk_rowid,
u8* b, u8* b,
int size) { int size,
struct Array * aMetadataIn, int argv_idx) {
// TODO: shouldn't this skip in-valid entries from the chunk's validity bitmap? // TODO: shouldn't this skip in-valid entries from the chunk's validity bitmap?
int rc; int rc;
@ -6198,6 +6303,31 @@ int vec0_set_metadata_filter_bitmap(
for(int i = 0; i < size; i++) { bitmap_set(b, i, array[i] != target); } for(int i = 0; i < size; i++) { bitmap_set(b, i, array[i] != target); }
break; break;
} }
case VEC0_METADATA_OPERATOR_IN: {
  // Locate the pre-collected `(...)` integer values for the constraint bound
  // to argv[argv_idx].
  struct Vec0MetadataIn * metadataIn = NULL;
  if(aMetadataIn) {
    for(size_t i = 0; i < aMetadataIn->length; i++) {
      struct Vec0MetadataIn * candidate = &((struct Vec0MetadataIn *) aMetadataIn->z)[i];
      if(candidate->argv_idx == argv_idx) {
        metadataIn = candidate;
        break;
      }
    }
  }
  if(!metadataIn) {
    // No values were collected for this constraint: fail the query instead
    // of aborting the process that loaded the extension.
    rc = SQLITE_ERROR;
    goto done;
  }
  struct Array * aTarget = &(metadataIn->array);

  // Set bit i when array[i] equals any of the IN values.
  for(int i = 0; i < size; i++) {
    for(size_t target_idx = 0; target_idx < aTarget->length; target_idx++) {
      if( ((i64*)aTarget->z)[target_idx] == array[i]) {
        bitmap_set(b, i, 1);
        break;
      }
    }
  }
  break;
}
} }
break; break;
} }
@ -6229,11 +6359,15 @@ int vec0_set_metadata_filter_bitmap(
for(int i = 0; i < size; i++) { bitmap_set(b, i, array[i] != target); } for(int i = 0; i < size; i++) { bitmap_set(b, i, array[i] != target); }
break; break;
} }
case VEC0_METADATA_OPERATOR_IN: {
// should never be reached
break;
}
} }
break; break;
} }
case VEC0_METADATA_COLUMN_KIND_TEXT: { case VEC0_METADATA_COLUMN_KIND_TEXT: {
rc = vec0_metadata_filter_text(p, value, buffer, size, op, b, metadata_idx, chunk_rowid); rc = vec0_metadata_filter_text(p, value, buffer, size, op, b, metadata_idx, chunk_rowid, aMetadataIn, argv_idx);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
goto done; goto done;
} }
@ -6248,6 +6382,7 @@ int vec0_set_metadata_filter_bitmap(
int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
struct VectorColumnDefinition *vector_column, struct VectorColumnDefinition *vector_column,
int vectorColumnIdx, struct Array *arrayRowidsIn, int vectorColumnIdx, struct Array *arrayRowidsIn,
struct Array * aMetadataIn,
const char * idxStr, int argc, sqlite3_value ** argv, const char * idxStr, int argc, sqlite3_value ** argv,
void *queryVector, i64 k, i64 **out_topk_rowids, void *queryVector, i64 k, i64 **out_topk_rowids,
f32 **out_topk_distances, i64 *out_used) { f32 **out_topk_distances, i64 *out_used) {
@ -6472,7 +6607,7 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
} }
bitmap_clear(bmMetadata, p->chunk_size); bitmap_clear(bmMetadata, p->chunk_size);
rc = vec0_set_metadata_filter_bitmap(p, metadata_idx, operator, argv[i], metadataBlobs[metadata_idx], chunk_id, bmMetadata, p->chunk_size); rc = vec0_set_metadata_filter_bitmap(p, metadata_idx, operator, argv[i], metadataBlobs[metadata_idx], chunk_id, bmMetadata, p->chunk_size, aMetadataIn, i);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
vtab_set_error(&p->base, "Could not filter metadata fields"); vtab_set_error(&p->base, "Could not filter metadata fields");
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
@ -6619,6 +6754,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
return SQLITE_NOMEM; return SQLITE_NOMEM;
} }
memset(knn_data, 0, sizeof(*knn_data)); memset(knn_data, 0, sizeof(*knn_data));
// array of `struct Vec0MetadataIn`, IF there are any `xxx in (...)` metadata constraints
struct Array * aMetadataIn = NULL;
int query_idx =-1; int query_idx =-1;
int k_idx = -1; int k_idx = -1;
@ -6738,6 +6876,95 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
} }
#endif #endif
#if COMPILER_SUPPORTS_VTAB_IN
// Copy the right-hand side values of every `xxx in (...)` metadata constraint
// into aMetadataIn, one `struct Vec0MetadataIn` per constraint, so the chunk
// filters can look them up by argv index.
for(int i = 0; i < argc; i++) {
  if(!(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_METADATA_CONSTRAINT && idxStr[1 + (i*4) + 2] == VEC0_METADATA_OPERATOR_IN)) {
    continue;
  }

  int metadata_idx = idxStr[1 + (i*4) + 1] - 'A';

  // Lazily create the outer array the first time an IN constraint is seen.
  if(!aMetadataIn) {
    aMetadataIn = sqlite3_malloc(sizeof(*aMetadataIn));
    if(!aMetadataIn) {
      rc = SQLITE_NOMEM;
      goto cleanup;
    }
    memset(aMetadataIn, 0, sizeof(*aMetadataIn));
    rc = array_init(aMetadataIn, sizeof(struct Vec0MetadataIn), 8);
    if(rc != SQLITE_OK) {
      goto cleanup;
    }
  }

  struct Vec0MetadataIn item;
  memset(&item, 0, sizeof(item));
  item.metadata_idx = metadata_idx;
  item.argv_idx = i;

  int isText = 0;
  switch(p->metadata_columns[metadata_idx].kind) {
    case VEC0_METADATA_COLUMN_KIND_INTEGER: {
      isText = 0;
      break;
    }
    case VEC0_METADATA_COLUMN_KIND_TEXT: {
      isText = 1;
      break;
    }
    default: {
      // Other column kinds are rejected while planning the query; fail the
      // query instead of aborting the host process if one shows up anyway.
      vtab_set_error(&p->base, "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.");
      rc = SQLITE_ERROR;
      goto cleanup;
    }
  }

  rc = array_init(&item.array, isText ? sizeof(struct Vec0MetadataInTextEntry) : sizeof(i64), 16);
  if(rc != SQLITE_OK) {
    goto cleanup;
  }

  // rcIn tracks the IN-value iteration, rc tracks copying into item.array.
  sqlite3_value *inValue = NULL;
  int rcIn = sqlite3_vtab_in_first(argv[i], &inValue);
  while(rcIn == SQLITE_OK && inValue) {
    if(isText) {
      struct Vec0MetadataInTextEntry text;
      // sqlite3_value_text() must be called before sqlite3_value_bytes().
      const char * s = (const char *) sqlite3_value_text(inValue);
      text.n = sqlite3_value_bytes(inValue);
      text.zString = sqlite3_mprintf("%.*s", text.n, s);
      if(!text.zString) {
        rc = SQLITE_NOMEM;
        break;
      }
      rc = array_append(&item.array, &text);
      if(rc != SQLITE_OK) {
        // The copy never made it into the array, so release it here.
        sqlite3_free(text.zString);
        break;
      }
    }
    else {
      i64 v = sqlite3_value_int64(inValue);
      rc = array_append(&item.array, &v);
      if(rc != SQLITE_OK) {
        break;
      }
    }
    rcIn = sqlite3_vtab_in_next(argv[i], &inValue);
  }

  // A complete iteration ends with SQLITE_DONE; anything else is an error.
  if(rc == SQLITE_OK && rcIn != SQLITE_DONE) {
    rc = (rcIn == SQLITE_OK) ? SQLITE_ERROR : rcIn;
    vtab_set_error(&p->base, "Error reading the values of an 'xxx in (...)' constraint on a metadata column.");
  }

  if(rc == SQLITE_OK) {
    rc = array_append(aMetadataIn, &item);
  }

  if(rc != SQLITE_OK) {
    // `item` is not owned by aMetadataIn yet, so the shared cleanup path will
    // not see it: release the copied strings and the array here.
    if(isText) {
      for(size_t j = 0; j < item.array.length; j++) {
        sqlite3_free(((struct Vec0MetadataInTextEntry *) item.array.z)[j].zString);
      }
    }
    array_cleanup(&item.array);
    goto cleanup;
  }
}
#endif
rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks); rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
// IMP: V06942_23781 // IMP: V06942_23781
@ -6750,7 +6977,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
f32 *topk_distances = NULL; f32 *topk_distances = NULL;
i64 k_used = 0; i64 k_used = 0;
rc = vec0Filter_knn_chunks_iter(p, stmtChunks, vector_column, vectorColumnIdx, rc = vec0Filter_knn_chunks_iter(p, stmtChunks, vector_column, vectorColumnIdx,
arrayRowidsIn, idxStr, argc, argv, queryVector, k, &topk_rowids, arrayRowidsIn, aMetadataIn, idxStr, argc, argv, queryVector, k, &topk_rowids,
&topk_distances, &k_used); &topk_distances, &k_used);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
goto cleanup; goto cleanup;
@ -6771,6 +6998,21 @@ cleanup:
array_cleanup(arrayRowidsIn); array_cleanup(arrayRowidsIn);
sqlite3_free(arrayRowidsIn); sqlite3_free(arrayRowidsIn);
queryVectorCleanup(queryVector); queryVectorCleanup(queryVector);
// Release everything collected for `xxx in (...)` metadata constraints.
if(aMetadataIn) {
  struct Vec0MetadataIn * collected = (struct Vec0MetadataIn *) aMetadataIn->z;
  for(size_t idx = 0; idx < aMetadataIn->length; idx++) {
    struct Vec0MetadataIn * current = &collected[idx];
    // Only TEXT constraints hold per-element strings that need freeing.
    if(p->metadata_columns[current->metadata_idx].kind == VEC0_METADATA_COLUMN_KIND_TEXT) {
      struct Vec0MetadataInTextEntry * strings = (struct Vec0MetadataInTextEntry *) current->array.z;
      for(size_t k = 0; k < current->array.length; k++) {
        sqlite3_free(strings[k].zString);
      }
    }
    array_cleanup(&current->array);
  }
  array_cleanup(aMetadataIn);
}
sqlite3_free(aMetadataIn);
return rc; return rc;
} }
@ -7049,7 +7291,8 @@ static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur,
int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i); int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i);
int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
sqlite3_result_error(context, "fuck todo", -1); // TODO handle
sqlite3_result_error(context, "fuck", -1);
} }
} }
return SQLITE_OK; return SQLITE_OK;
@ -7121,7 +7364,8 @@ static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur,
int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i); int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i);
int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
sqlite3_result_error(context, "fuck todo", -1); // TODO handle
sqlite3_result_error(context, "fuck", -1);
} }
} }
@ -7188,7 +7432,8 @@ static int vec0Column_knn(vec0_vtab *pVtab, vec0_cursor *pCur,
i64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx]; i64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx];
int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
sqlite3_result_error(context, "fuck todo", -1); // TODO: handle
sqlite3_result_error(context, "fuck", -1);
} }
} }

View file

@ -5,6 +5,54 @@
.mode qbox .mode qbox
.load ./memstat
.echo on
select name, value from sqlite_memstat where name = 'MEMORY_USED';
create virtual table v using vec0(
vector float[1],
name1 text,
name2 text,
age int,
chunk_size=8
);
select name, value from sqlite_memstat where name = 'MEMORY_USED';
insert into v(vector, name1, name2, age) values
('[1]', 'alex', 'xxxx', 1),
('[2]', 'alex', 'aaaa', 2),
('[3]', 'alex', 'aaaa', 3),
('[4]', 'brian', 'aaaa', 1),
('[5]', 'brian', 'aaaa', 2),
('[6]', 'brian', 'aaaa', 3),
('[7]', 'craig', 'aaaa', 1),
('[8]', 'craig', 'xxxx', 2),
('[9]', 'craig', 'xxxx', 3),
('[10]', '123456789012345', 'xxxx', 3);
select name, value from sqlite_memstat where name = 'MEMORY_USED';
select rowid, name1, name2, age, vec_to_json(vector)
from v
where vector match '[0]'
and k = 5
and name1 in ('alex', 'brian', 'craig')
--and name2 in ('aaaa', 'xxxx')
and age in (1, 2, 3, 2222,3333,4444);
select name, value from sqlite_memstat where name = 'MEMORY_USED';
select rowid, name1, name2, age, vec_to_json(vector)
from v
where vector match '[0]'
and k = 5
and name1 in ('123456789012345', 'superfluous');
.exit
create virtual table v using vec0( create virtual table v using vec0(
vector float[1], vector float[1],
+description text +description text

View file

@ -3806,3 +3806,260 @@
}), }),
}) })
# --- # ---
# name: test_vtab_in[allow-int-all]
OrderedDict({
'sql': "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)",
'rows': list([
OrderedDict({
'rowid': 1,
'n': 999,
'distance': 1.0,
}),
OrderedDict({
'rowid': 2,
'n': 555,
'distance': 2.0,
}),
OrderedDict({
'rowid': 3,
'n': 999,
'distance': 3.0,
}),
OrderedDict({
'rowid': 4,
'n': 555,
'distance': 4.0,
}),
OrderedDict({
'rowid': 5,
'n': 999,
'distance': 5.0,
}),
OrderedDict({
'rowid': 6,
'n': 555,
'distance': 6.0,
}),
OrderedDict({
'rowid': 7,
'n': 999,
'distance': 7.0,
}),
OrderedDict({
'rowid': 8,
'n': 555,
'distance': 8.0,
}),
]),
})
# ---
# name: test_vtab_in[allow-int-superfluous]
OrderedDict({
'sql': "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)",
'rows': list([
OrderedDict({
'rowid': 2,
'n': 555,
'distance': 2.0,
}),
OrderedDict({
'rowid': 4,
'n': 555,
'distance': 4.0,
}),
OrderedDict({
'rowid': 6,
'n': 555,
'distance': 6.0,
}),
OrderedDict({
'rowid': 8,
'n': 555,
'distance': 8.0,
}),
]),
})
# ---
# name: test_vtab_in[allow-text-all]
OrderedDict({
'sql': "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')",
'rows': list([
OrderedDict({
'rowid': 1,
't': 'aaaa',
'distance': 1.0,
}),
OrderedDict({
'rowid': 2,
't': 'aaaa',
'distance': 2.0,
}),
OrderedDict({
'rowid': 3,
't': 'aaaa',
'distance': 3.0,
}),
OrderedDict({
'rowid': 4,
't': 'aaaa',
'distance': 4.0,
}),
OrderedDict({
'rowid': 5,
't': 'zzzz',
'distance': 5.0,
}),
OrderedDict({
'rowid': 6,
't': 'zzzz',
'distance': 6.0,
}),
OrderedDict({
'rowid': 7,
't': 'zzzz',
'distance': 7.0,
}),
OrderedDict({
'rowid': 8,
't': 'zzzz',
'distance': 8.0,
}),
]),
})
# ---
# name: test_vtab_in[allow-text-superfluous]
OrderedDict({
'sql': "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')",
'rows': list([
OrderedDict({
'rowid': 1,
't': 'aaaa',
'distance': 1.0,
}),
OrderedDict({
'rowid': 2,
't': 'aaaa',
'distance': 2.0,
}),
OrderedDict({
'rowid': 3,
't': 'aaaa',
'distance': 3.0,
}),
OrderedDict({
'rowid': 4,
't': 'aaaa',
'distance': 4.0,
}),
]),
})
# ---
# name: test_vtab_in[block-bool]
dict({
'error': 'OperationalError',
'message': "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.",
})
# ---
# name: test_vtab_in[block-float]
dict({
'error': 'OperationalError',
'message': "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.",
})
# ---
# name: test_vtab_in_long_text[all]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))",
'rows': list([
OrderedDict({
'rowid': 1,
't': 'aaaa',
}),
OrderedDict({
'rowid': 2,
't': 'aaaaaaaaaaaa_aaa',
}),
OrderedDict({
'rowid': 3,
't': 'bbbb',
}),
OrderedDict({
'rowid': 4,
't': 'bbbbbbbbbbbb_bbb',
}),
OrderedDict({
'rowid': 5,
't': 'cccc',
}),
OrderedDict({
'rowid': 6,
't': 'cccccccccccc_ccc',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-aaaa]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 1,
't': 'aaaa',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-aaaaaaaaaaaa_aaa]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 2,
't': 'aaaaaaaaaaaa_aaa',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-bbbb]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 3,
't': 'bbbb',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-bbbbbbbbbbbb_bbb]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 4,
't': 'bbbbbbbbbbbb_bbb',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-cccc]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 5,
't': 'cccc',
}),
]),
})
# ---
# name: test_vtab_in_long_text[individual-cccccccccccc_ccc]
OrderedDict({
'sql': "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
'rows': list([
OrderedDict({
'rowid': 6,
't': 'cccccccccccc_ccc',
}),
]),
})
# ---

View file

@ -1,5 +1,7 @@
import pytest
import sqlite3 import sqlite3
from collections import OrderedDict from collections import OrderedDict
import json
def test_constructor_limit(db, snapshot): def test_constructor_limit(db, snapshot):
@ -284,6 +286,87 @@ def test_knn(db, snapshot):
) )
# sqlite3_vtab_in() was added in SQLite 3.38.0. Compare the whole version tuple
# so the result does not depend on the minor version number alone.
SUPPORTS_VTAB_IN = sqlite3.sqlite_version_info >= (3, 38, 0)
@pytest.mark.skipif(
    not SUPPORTS_VTAB_IN, reason="requires vtab `x in (...)` support in SQLite >=3.38"
)
def test_vtab_in(db, snapshot):
    """Snapshot KNN queries with `x in (...)` constraints on metadata columns.

    Boolean and float columns are queried first ("block-*" snapshots), then
    integer and text columns ("allow-*" snapshots).
    """
    db.execute(
        "create virtual table v using vec0(vector float[1], n int, t text, b boolean, f float, chunk_size=8)"
    )

    # rowids 1..8: n alternates 999/555, t is 'aaaa' for the first four rows
    # and 'zzzz' for the rest, b and f are constant.
    rows = [
        (
            rowid,
            f"[{rowid}]",
            999 if rowid % 2 == 1 else 555,
            "aaaa" if rowid <= 4 else "zzzz",
            0,
            1.1,
        )
        for rowid in range(1, 9)
    ]
    db.executemany(
        "insert into v(rowid, vector, n, t, b, f) values (?, ?, ?, ?, ?, ?)",
        rows,
    )

    # (snapshot name, query), in the order the snapshots are checked
    cases = [
        (
            "block-bool",
            "select * from v where vector match '[0]' and k = 8 and b in (1, 0)",
        ),
        (
            "block-float",
            "select * from v where vector match '[0]' and k = 8 and f in (1.1, 0.0)",
        ),
        (
            "allow-int-all",
            "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, 999)",
        ),
        (
            "allow-int-superfluous",
            "select rowid, n, distance from v where vector match '[0]' and k = 8 and n in (555, -1, -2)",
        ),
        (
            "allow-text-all",
            "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'zzzz')",
        ),
        (
            "allow-text-superfluous",
            "select rowid, t, distance from v where vector match '[0]' and k = 8 and t in ('aaaa', 'foo', 'bar')",
        ),
    ]
    for snapshot_name, sql in cases:
        assert exec(db, sql) == snapshot(name=snapshot_name)


def test_vtab_in_long_text(db, snapshot):
    """Snapshot `t in (...)` lookups over a mix of 4- and 16-character text values.

    Each stored value is looked up on its own next to a non-matching value,
    then all values are looked up together through a json_each() subquery.
    """
    db.execute(
        "create virtual table v using vec0(vector float[1], t text, chunk_size=8)"
    )

    data = [
        (1, "aaaa"),
        (2, "aaaaaaaaaaaa_aaa"),
        (3, "bbbb"),
        (4, "bbbbbbbbbbbb_bbb"),
        (5, "cccc"),
        (6, "cccccccccccc_ccc"),
    ]
    texts = [text for _, text in data]

    # NOTE: the `:vector` bind parameter carries the text value for column `t`;
    # the vector itself is derived from the rowid via printf().
    insert_params = [{"rowid": rowid, "vector": text} for rowid, text in data]
    db.executemany(
        "insert into v(rowid, vector, t) values (:rowid, printf('[%d]', :rowid), :vector)",
        insert_params,
    )

    for lookup in texts:
        result = exec(
            db,
            "select rowid, t from v where vector match '[0]' and k = 10 and t in (?, 'nonsense')",
            [lookup],
        )
        assert result == snapshot(name=f"individual-{lookup}")

    all_result = exec(
        db,
        "select rowid, t from v where vector match '[0]' and k = 10 and t in (select value from json_each(?))",
        [json.dumps(texts)],
    )
    assert all_result == snapshot(name="all")


def test_idxstr(db, snapshot): def test_idxstr(db, snapshot):
db.execute( db.execute(
""" """