Fix tests; locate KNN match/k/rowid-in arguments via idxStr

This commit is contained in:
Alex Garcia 2024-11-12 21:57:59 -08:00
parent 4cb891a0b2
commit 8369dfe72d
2 changed files with 29 additions and 8 deletions

View file

@ -4189,6 +4189,8 @@ void vec0_query_point_data_clear(struct vec0_query_point_data *point_data) {
}
typedef enum {
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!
VEC0_QUERY_PLAN_FULLSCAN = '1',
VEC0_QUERY_PLAN_POINT = '2',
VEC0_QUERY_PLAN_KNN = '3',
@ -4649,6 +4651,8 @@ static int vec0Close(sqlite3_vtab_cursor *cur) {
// All the different types of "values" provided to argv/argc in vec0Filter.
// These enums denote the use and purpose of all of them.
typedef enum {
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!
VEC0_IDXSTR_KIND_KNN_MATCH = '{',
VEC0_IDXSTR_KIND_KNN_K = '}',
VEC0_IDXSTR_KIND_KNN_ROWID_IN = '[',
@ -4659,6 +4663,8 @@ typedef enum {
// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
// support, but as characters that fit nicely in idxstr.
typedef enum {
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!
VEC0_PARTITION_OPERATOR_EQ = 'a',
VEC0_PARTITION_OPERATOR_GT = 'b',
VEC0_PARTITION_OPERATOR_LE = 'c',
@ -5428,8 +5434,7 @@ cleanup:
int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
const char *idxStr, int argc, sqlite3_value **argv) {
UNUSED_PARAMETER(idxStr);
assert(argc >= 2);
assert(argc == (strlen(idxStr)-1) / 4);
int rc;
struct vec0_query_knn_data *knn_data;
@ -5450,9 +5455,26 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
}
memset(knn_data, 0, sizeof(*knn_data));
int query_idx =-1;
int k_idx = -1;
int rowid_in_idx = -1;
for(int i = 0; i < argc; i++) {
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MATCH) {
query_idx = i;
}
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_K) {
k_idx = i;
}
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_ROWID_IN) {
rowid_in_idx = i;
}
}
assert(query_idx >= 0);
assert(k_idx >= 0);
// make sure the query vector matches the vector column (type, dimensions, etc.)
// TODO not argv[0], source idx from idxStr
rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
rc = vector_from_value(argv[query_idx], &queryVector, &dimensions, &elementType,
&queryVectorCleanup, &pzError);
if (rc != SQLITE_OK) {
@ -5485,7 +5507,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
}
// TODO not argv[1], source idx from idxStr
i64 k = sqlite3_value_int64(argv[1]);
i64 k = sqlite3_value_int64(argv[k_idx]);
if (k < 0) {
vtab_set_error(
&p->base, "k value in knn queries must be greater than or equal to 0.");
@ -5515,7 +5537,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
// NULL if none were provided, which means a "full" scan.
#if COMPILER_SUPPORTS_VTAB_IN
// TODO fix
if (false /*argc > 2*/) {
if (rowid_in_idx >= 0) {
sqlite3_value *item;
int rc;
arrayRowidsIn = sqlite3_malloc(sizeof(*arrayRowidsIn));
@ -5529,8 +5551,8 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
if (rc != SQLITE_OK) {
goto cleanup;
}
for (rc = sqlite3_vtab_in_first(argv[2], &item); rc == SQLITE_OK && item;
rc = sqlite3_vtab_in_next(argv[2], &item)) {
for (rc = sqlite3_vtab_in_first(argv[rowid_in_idx], &item); rc == SQLITE_OK && item;
rc = sqlite3_vtab_in_next(argv[rowid_in_idx], &item)) {
i64 rowid;
if (p->pkIsText) {
rc = vec0_rowid_from_id(p, item, &rowid);

View file

@ -1420,7 +1420,6 @@ def test_vec0_point():
assert execute_all(db, "select * from t2 where id = 'xxx'") == []
# @pytest.mark.skip(reason="TODO failing locally for some reason")
def test_vec0_text_pk():
db = connect(EXT_PATH)
db.execute(