From 8369dfe72d5e39f1bc2cff46b6c41d98dca0adab Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Tue, 12 Nov 2024 21:57:59 -0800 Subject: [PATCH] fix tests, KNN/rowids in --- sqlite-vec.c | 36 +++++++++++++++++++++++++++++------- tests/test-loadable.py | 1 - 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/sqlite-vec.c b/sqlite-vec.c index 87ba10a..e852322 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -4189,6 +4189,8 @@ void vec0_query_point_data_clear(struct vec0_query_point_data *point_data) { } typedef enum { + // If any values are updated, please update the ARCHITECTURE.md docs accordingly! + VEC0_QUERY_PLAN_FULLSCAN = '1', VEC0_QUERY_PLAN_POINT = '2', VEC0_QUERY_PLAN_KNN = '3', @@ -4649,6 +4651,8 @@ static int vec0Close(sqlite3_vtab_cursor *cur) { // All the different type of "values" provided to argv/argc in vec0Filter. // These enums denote the use and purpose of all of them. typedef enum { + // If any values are updated, please update the ARCHITECTURE.md docs accordingly! + VEC0_IDXSTR_KIND_KNN_MATCH = '{', VEC0_IDXSTR_KIND_KNN_K = '}', VEC0_IDXSTR_KIND_KNN_ROWID_IN = '[', @@ -4659,6 +4663,8 @@ typedef enum { // The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns // support, but as characters that fit nicely in idxstr. typedef enum { + // If any values are updated, please update the ARCHITECTURE.md docs accordingly! + VEC0_PARTITION_OPERATOR_EQ = 'a', VEC0_PARTITION_OPERATOR_GT = 'b', VEC0_PARTITION_OPERATOR_LE = 'c', @@ -5428,8 +5434,7 @@ cleanup: int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { - UNUSED_PARAMETER(idxStr); - assert(argc >= 2); + assert(argc == (strlen(idxStr)-1) / 4); int rc; struct vec0_query_knn_data *knn_data; @@ -5450,9 +5455,26 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, } memset(knn_data, 0, sizeof(*knn_data)); + int query_idx =-1; + int k_idx = -1; + int rowid_in_idx = -1; + for(int i = 0; i < argc; i++) { + if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MATCH) { + query_idx = i; + } + if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_K) { + k_idx = i; + } + if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_ROWID_IN) { + rowid_in_idx = i; + } + } + assert(query_idx >= 0); + assert(k_idx >= 0); + // make sure the query vector matches the vector column (type dimensions etc.) // TODO not argv[0], source idx from idxStr - rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, + rc = vector_from_value(argv[query_idx], &queryVector, &dimensions, &elementType, &queryVectorCleanup, &pzError); if (rc != SQLITE_OK) { @@ -5485,7 +5507,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, } // TODO not argv[1], source idx from idxStr - i64 k = sqlite3_value_int64(argv[1]); + i64 k = sqlite3_value_int64(argv[k_idx]); if (k < 0) { vtab_set_error( &p->base, "k value in knn queries must be greater than or equal to 0."); @@ -5515,7 +5537,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, // NULL if none were provided, which means a "full" scan. #if COMPILER_SUPPORTS_VTAB_IN // TODO fix - if (false /*argc > 2*/) { + if (rowid_in_idx >= 0) { sqlite3_value *item; int rc; arrayRowidsIn = sqlite3_malloc(sizeof(*arrayRowidsIn)); @@ -5529,8 +5551,8 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, if (rc != SQLITE_OK) { goto cleanup; } - for (rc = sqlite3_vtab_in_first(argv[2], &item); rc == SQLITE_OK && item; - rc = sqlite3_vtab_in_next(argv[2], &item)) { + for (rc = sqlite3_vtab_in_first(argv[rowid_in_idx], &item); rc == SQLITE_OK && item; + rc = sqlite3_vtab_in_next(argv[rowid_in_idx], &item)) { i64 rowid; if (p->pkIsText) { rc = vec0_rowid_from_id(p, item, &rowid); diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 31a97ae..30171fe 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -1420,7 +1420,6 @@ def test_vec0_point(): assert execute_all(db, "select * from t2 where id = 'xxx'") == [] -# @pytest.mark.skip(reason="TODO failing locally for some reason") def test_vec0_text_pk(): db = connect(EXT_PATH) db.execute(