mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
fix tests, KNN/rowids in
This commit is contained in:
parent
4cb891a0b2
commit
8369dfe72d
2 changed files with 29 additions and 8 deletions
36
sqlite-vec.c
36
sqlite-vec.c
|
|
@ -4189,6 +4189,8 @@ void vec0_query_point_data_clear(struct vec0_query_point_data *point_data) {
|
|||
}
|
||||
|
||||
typedef enum {
|
||||
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!
|
||||
|
||||
VEC0_QUERY_PLAN_FULLSCAN = '1',
|
||||
VEC0_QUERY_PLAN_POINT = '2',
|
||||
VEC0_QUERY_PLAN_KNN = '3',
|
||||
|
|
@ -4649,6 +4651,8 @@ static int vec0Close(sqlite3_vtab_cursor *cur) {
|
|||
// All the different type of "values" provided to argv/argc in vec0Filter.
|
||||
// These enums denote the use and purpose of all of them.
|
||||
typedef enum {
|
||||
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!
|
||||
|
||||
VEC0_IDXSTR_KIND_KNN_MATCH = '{',
|
||||
VEC0_IDXSTR_KIND_KNN_K = '}',
|
||||
VEC0_IDXSTR_KIND_KNN_ROWID_IN = '[',
|
||||
|
|
@ -4659,6 +4663,8 @@ typedef enum {
|
|||
// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
|
||||
// support, but as characters that fit nicely in idxstr.
|
||||
typedef enum {
|
||||
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!
|
||||
|
||||
VEC0_PARTITION_OPERATOR_EQ = 'a',
|
||||
VEC0_PARTITION_OPERATOR_GT = 'b',
|
||||
VEC0_PARTITION_OPERATOR_LE = 'c',
|
||||
|
|
@ -5428,8 +5434,7 @@ cleanup:
|
|||
|
||||
int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
||||
const char *idxStr, int argc, sqlite3_value **argv) {
|
||||
UNUSED_PARAMETER(idxStr);
|
||||
assert(argc >= 2);
|
||||
assert(argc == (strlen(idxStr)-1) / 4);
|
||||
int rc;
|
||||
struct vec0_query_knn_data *knn_data;
|
||||
|
||||
|
|
@ -5450,9 +5455,26 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
|||
}
|
||||
memset(knn_data, 0, sizeof(*knn_data));
|
||||
|
||||
int query_idx =-1;
|
||||
int k_idx = -1;
|
||||
int rowid_in_idx = -1;
|
||||
for(int i = 0; i < argc; i++) {
|
||||
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MATCH) {
|
||||
query_idx = i;
|
||||
}
|
||||
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_K) {
|
||||
k_idx = i;
|
||||
}
|
||||
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_ROWID_IN) {
|
||||
rowid_in_idx = i;
|
||||
}
|
||||
}
|
||||
assert(query_idx >= 0);
|
||||
assert(k_idx >= 0);
|
||||
|
||||
// make sure the query vector matches the vector column (type dimensions etc.)
|
||||
// TODO not argv[0], source idx from idxStr
|
||||
rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
|
||||
rc = vector_from_value(argv[query_idx], &queryVector, &dimensions, &elementType,
|
||||
&queryVectorCleanup, &pzError);
|
||||
|
||||
if (rc != SQLITE_OK) {
|
||||
|
|
@ -5485,7 +5507,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
|||
}
|
||||
|
||||
// TODO not argv[1], source idx from idxStr
|
||||
i64 k = sqlite3_value_int64(argv[1]);
|
||||
i64 k = sqlite3_value_int64(argv[k_idx]);
|
||||
if (k < 0) {
|
||||
vtab_set_error(
|
||||
&p->base, "k value in knn queries must be greater than or equal to 0.");
|
||||
|
|
@ -5515,7 +5537,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
|||
// NULL if none were provided, which means a "full" scan.
|
||||
#if COMPILER_SUPPORTS_VTAB_IN
|
||||
// TODO fix
|
||||
if (false /*argc > 2*/) {
|
||||
if (rowid_in_idx >= 0) {
|
||||
sqlite3_value *item;
|
||||
int rc;
|
||||
arrayRowidsIn = sqlite3_malloc(sizeof(*arrayRowidsIn));
|
||||
|
|
@ -5529,8 +5551,8 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
|||
if (rc != SQLITE_OK) {
|
||||
goto cleanup;
|
||||
}
|
||||
for (rc = sqlite3_vtab_in_first(argv[2], &item); rc == SQLITE_OK && item;
|
||||
rc = sqlite3_vtab_in_next(argv[2], &item)) {
|
||||
for (rc = sqlite3_vtab_in_first(argv[rowid_in_idx], &item); rc == SQLITE_OK && item;
|
||||
rc = sqlite3_vtab_in_next(argv[rowid_in_idx], &item)) {
|
||||
i64 rowid;
|
||||
if (p->pkIsText) {
|
||||
rc = vec0_rowid_from_id(p, item, &rowid);
|
||||
|
|
|
|||
|
|
@ -1420,7 +1420,6 @@ def test_vec0_point():
|
|||
assert execute_all(db, "select * from t2 where id = 'xxx'") == []
|
||||
|
||||
|
||||
# @pytest.mark.skip(reason="TODO failing locally for some reason")
|
||||
def test_vec0_text_pk():
|
||||
db = connect(EXT_PATH)
|
||||
db.execute(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue