diff --git a/sqlite-vec.c b/sqlite-vec.c index 64e5442..747dbaf 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -3036,8 +3036,8 @@ int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, i64 rowid, // TODO: test / evidence-of sqlite3_bind_int64(pVtab->stmtRowidsGetChunkPosition, 1, rowid); rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition); - if (rc == SQLITE_ROW) { - return SQLITE_ERROR; + if (rc != SQLITE_ROW) { + goto cleanup; } sqlite3_value *value = sqlite3_column_value(pVtab->stmtRowidsGetChunkPosition, 0); @@ -3050,7 +3050,44 @@ cleanup: return rc; } -// TODO make sure callees use the return value of this function +int vec0_rowid_from_id(vec0_vtab *p, sqlite3_value *valueId, i64 *rowid) { + sqlite3_stmt *stmt = NULL; + int rc; + char *zSql; + zSql = sqlite3_mprintf("SELECT rowid" + " FROM " VEC0_SHADOW_ROWIDS_NAME " WHERE id = ?", + p->schemaName, p->tableName); + if (!zSql) { + rc = SQLITE_NOMEM; + goto cleanup; + } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + goto cleanup; + } + sqlite3_bind_value(stmt, 1, valueId); + rc = sqlite3_step(stmt); + if (rc == SQLITE_DONE) { + rc = SQLITE_EMPTY; + goto cleanup; + } + if (rc != SQLITE_ROW) { + goto cleanup; + } + *rowid = sqlite3_column_int64(stmt, 0); + rc = sqlite3_step(stmt); + if (rc != SQLITE_DONE) { + goto cleanup; + } + + rc = SQLITE_OK; + +cleanup: + sqlite3_finalize(stmt); + return rc; +} + int vec0_result_id(vec0_vtab *p, sqlite3_context *context, i64 rowid) { if (!p->pkIsText) { sqlite3_result_int64(context, rowid); @@ -3085,31 +3122,59 @@ int vec0_result_id(vec0_vtab *p, sqlite3_context *context, i64 rowid) { int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx, void **outVector, int *outVectorSize) { int rc; + i64 chunk_id; + i64 chunk_offset; + size_t size; + void *buf = NULL; + int blobOffset; assert((vector_column_idx >= 0) && (vector_column_idx < pVtab->numVectorColumns)); sqlite3_bind_int64(pVtab->stmtRowidsGetChunkPosition, 1, rowid); rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition); + if (rc == SQLITE_DONE) { + rc = SQLITE_EMPTY; + goto cleanup; + } if (rc != SQLITE_ROW) { - vtab_set_error(&pVtab->base, "fuck"); // TODO + vtab_set_error(&pVtab->base, "Could not find a row with id %lld", rowid); rc = SQLITE_ERROR; goto cleanup; } - i64 chunk_id = sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 1); - i64 chunk_offset = sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 2); + chunk_id = sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 1); + chunk_offset = sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 2); rc = sqlite3_blob_reopen(pVtab->vectorBlobs[vector_column_idx], chunk_id); - todo_assert(rc == SQLITE_OK); - size_t size = - vector_column_byte_size(pVtab->vector_columns[vector_column_idx]); - int blobOffset = chunk_offset * size; + if (rc != SQLITE_OK) { + vtab_set_error( + &pVtab->base, + "Could not fetch vector data for %lld, reopening blob failed", rowid); + rc = SQLITE_ERROR; + goto cleanup; + } + + size = vector_column_byte_size(pVtab->vector_columns[vector_column_idx]); + blobOffset = chunk_offset * size; + + buf = sqlite3_malloc(size); + if (!buf) { + rc = SQLITE_ERROR; + goto cleanup; + } - void *buf = sqlite3_malloc(size); - todo_assert(buf); rc = sqlite3_blob_read(pVtab->vectorBlobs[vector_column_idx], buf, size, blobOffset); - todo_assert(rc == SQLITE_OK); + if (rc != SQLITE_OK) { + sqlite3_free(buf); + buf = NULL; + vtab_set_error( + &pVtab->base, + "Could not fetch vector data for %lld, reading from blob failed", + rowid); + rc = SQLITE_ERROR; + goto cleanup; + } *outVector = buf; if (outVectorSize) { @@ -3273,15 +3338,15 @@ struct vec0_query_fullscan_data { sqlite3_stmt *rowids_stmt; i8 done; }; -int vec0_query_fullscan_data_clear( +void vec0_query_fullscan_data_clear( struct vec0_query_fullscan_data *fullscan_data) { - int rc; + if (!fullscan_data) + return; + if (fullscan_data->rowids_stmt) { - rc = sqlite3_finalize(fullscan_data->rowids_stmt); - todo_assert(rc == SQLITE_OK); + sqlite3_finalize(fullscan_data->rowids_stmt); fullscan_data->rowids_stmt = NULL; } - return SQLITE_OK; } struct vec0_query_knn_data { @@ -3292,7 +3357,10 @@ struct vec0_query_knn_data { f32 *distances; i64 current_idx; }; -int vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) { +void vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) { + if (!knn_data) + return; + if (knn_data->rowids) { sqlite3_free(knn_data->rowids); knn_data->rowids = NULL; @@ -3301,7 +3369,6 @@ int vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) { sqlite3_free(knn_data->distances); knn_data->distances = NULL; } - return SQLITE_OK; } struct vec0_query_point_data { @@ -3310,6 +3377,8 @@ struct vec0_query_point_data { int done; }; void vec0_query_point_data_clear(struct vec0_query_point_data *point_data) { + if (!point_data) + return; for (int i = 0; i < VEC0_MAX_VECTOR_COLUMNS; i++) { sqlite3_free(point_data->vectors[i]); point_data->vectors[i] = NULL; @@ -3326,17 +3395,6 @@ struct vec0_cursor { struct vec0_query_point_data *point_data; }; -#define SET_VTAB_ERROR(msg) \ - do { \ - sqlite3_free(pVTab->zErrMsg); \ - pVTab->zErrMsg = sqlite3_mprintf("%s", msg); \ - } while (0) -#define SET_VTAB_CURSOR_ERROR(msg) \ - do { \ - sqlite3_free(pVtabCursor->pVtab->zErrMsg); \ - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf("%s", msg); \ - } while (0) - #define VEC_CONSTRUCTOR_ERROR "vec0 constructor error: " static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_vtab **ppVtab, char **pzErr, bool isCreate) { @@ -3780,21 +3838,21 @@ static int vec0Open(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { } static int vec0Close(sqlite3_vtab_cursor *cur) { - int rc; vec0_cursor *pCur = (vec0_cursor *)cur; if (pCur->fullscan_data) { - rc = vec0_query_fullscan_data_clear(pCur->fullscan_data); - todo_assert(rc == SQLITE_OK); + vec0_query_fullscan_data_clear(pCur->fullscan_data); sqlite3_free(pCur->fullscan_data); + pCur->fullscan_data = NULL; } if (pCur->knn_data) { - rc = vec0_query_knn_data_clear(pCur->knn_data); - todo_assert(rc == SQLITE_OK); + vec0_query_knn_data_clear(pCur->knn_data); sqlite3_free(pCur->knn_data); + pCur->knn_data = NULL; } if (pCur->point_data) { vec0_query_point_data_clear(pCur->point_data); sqlite3_free(pCur->point_data); + pCur->point_data = NULL; } sqlite3_free(pCur); return SQLITE_OK; @@ -3849,7 +3907,8 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { if (op == SQLITE_INDEX_CONSTRAINT_MATCH && vec0_column_idx_is_vector(p, iColumn)) { if (iMatchTerm > -1) { - // TODO only 1 match operator at a time + vtab_set_error( + pVTab, "only 1 MATCH operator is allowed in a single vec0 query"); return SQLITE_ERROR; } iMatchTerm = i; @@ -3860,7 +3919,11 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == VEC0_COLUMN_ID) { if (vtabIn) { - todo_assert(iRowidInTerm == -1); + if (iRowidInTerm != -1) { + vtab_set_error(pVTab, "only 1 'rowid in (..)' operator is allowed in " + "a single vec0 query"); + return SQLITE_ERROR; + } iRowidInTerm = i; } else { @@ -3873,35 +3936,36 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } if (iMatchTerm >= 0) { if (iLimitTerm < 0 && iKTerm < 0) { - // TODO: error, match on vector1 should require a limit for KNN. right? + vtab_set_error( + pVTab, + "A LIMIT or 'k = ?' constraint is required on vec0 knn queries."); return SQLITE_ERROR; } if (iLimitTerm >= 0 && iKTerm >= 0) { + vtab_set_error(pVTab, "Only LIMIT or 'k =?' can be provided, not both"); return SQLITE_ERROR; } - if (pIdxInfo->nOrderBy < 1) { - // TODO error, `ORDER BY DISTANCE required - SET_VTAB_ERROR("ORDER BY distance required"); - return SQLITE_CONSTRAINT; - } - if (pIdxInfo->nOrderBy > 1) { - // TODO error, orderByConsumed is all or nothing, only 1 order by allowed - SET_VTAB_ERROR("more than 1 ORDER BY clause provided"); - return SQLITE_CONSTRAINT; - } - if (pIdxInfo->aOrderBy[0].iColumn != vec0_column_distance_idx(p)) { - // TODO error, ORDER BY must be on column - SET_VTAB_ERROR("ORDER BY must be on the distance column"); - return SQLITE_CONSTRAINT; - } - if (pIdxInfo->aOrderBy[0].desc) { - // TODO KNN should be ascending, is descending possible? - SET_VTAB_ERROR("Only ascending in ORDER BY distance clause is supported, " - "DESC is not supported yet."); - return SQLITE_CONSTRAINT; + + if (pIdxInfo->nOrderBy) { + if (pIdxInfo->nOrderBy > 1) { + vtab_set_error(pVTab, "Only a single 'ORDER BY distance' clause is " + "allowed on vec0 KNN queries"); + return SQLITE_ERROR; + } + if (pIdxInfo->aOrderBy[0].iColumn != vec0_column_distance_idx(p)) { + vtab_set_error(pVTab, + "Only a single 'ORDER BY distance' clause is allowed on " + "vec0 KNN queries, not on other columns"); + return SQLITE_ERROR; + } + if (pIdxInfo->aOrderBy[0].desc) { + vtab_set_error( + pVTab, "Only ascending in ORDER BY distance clause is supported, " + "DESC is not supported yet."); + return SQLITE_ERROR; + } } - pIdxInfo->orderByConsumed = 1; pIdxInfo->aConstraintUsage[iMatchTerm].argvIndex = 1; pIdxInfo->aConstraintUsage[iMatchTerm].omit = 1; if (iLimitTerm >= 0) { @@ -4073,6 +4137,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, char *err; rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, &cleanup, &err); + todo_assert(rc == SQLITE_OK); todo_assert(elementType == vector_column->element_type); todo_assert(dimensions == vector_column->dimensions); @@ -4296,64 +4361,105 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, return SQLITE_OK; } -int vec0Filter_fullscan(vec0_cursor *pCur, vec0_vtab *p, int idxNum, - const char *idxStr, int argc, sqlite3_value **argv) { - UNUSED_PARAMETER(idxNum); - UNUSED_PARAMETER(idxStr); - UNUSED_PARAMETER(argc); - UNUSED_PARAMETER(argv); +int vec0Filter_fullscan(vec0_vtab *p, vec0_cursor *pCur) { int rc; char *zSql; + struct vec0_query_fullscan_data *fullscan_data; - pCur->query_plan = SQLITE_VEC0_QUERYPLAN_FULLSCAN; - struct vec0_query_fullscan_data *fullscan_data = - sqlite3_malloc(sizeof(struct vec0_query_fullscan_data)); + fullscan_data = sqlite3_malloc(sizeof(*fullscan_data)); if (!fullscan_data) { return SQLITE_NOMEM; } - memset(fullscan_data, 0, sizeof(struct vec0_query_fullscan_data)); + memset(fullscan_data, 0, sizeof(*fullscan_data)); + zSql = sqlite3_mprintf(" SELECT rowid " " FROM " VEC0_SHADOW_ROWIDS_NAME " ORDER by chunk_id, chunk_offset ", p->schemaName, p->tableName); - todo_assert(zSql); + if (!zSql) { + rc = SQLITE_NOMEM; + goto error; + } rc = sqlite3_prepare_v2(p->db, zSql, -1, &fullscan_data->rowids_stmt, NULL); sqlite3_free(zSql); - todo_assert(rc == SQLITE_OK); - rc = sqlite3_step(fullscan_data->rowids_stmt); - fullscan_data->done = rc == SQLITE_DONE; - if (!(rc == SQLITE_ROW || rc == SQLITE_DONE)) { - vec0_query_fullscan_data_clear(fullscan_data); - return SQLITE_ERROR; + if (rc != SQLITE_OK) { + // IMP: V09901_26739 + vtab_set_error(&p->base, "Error preparing rowid scan: %s", + sqlite3_errmsg(p->db)); + goto error; } + + rc = sqlite3_step(fullscan_data->rowids_stmt); + + // DONE when there's no rowids, ROW when there are, both "success" + if (!(rc == SQLITE_ROW || rc == SQLITE_DONE)) { + goto error; + } + + fullscan_data->done = rc == SQLITE_DONE; + pCur->query_plan = SQLITE_VEC0_QUERYPLAN_FULLSCAN; pCur->fullscan_data = fullscan_data; return SQLITE_OK; + +error: + vec0_query_fullscan_data_clear(fullscan_data); + sqlite3_free(fullscan_data); + return rc; } -int vec0Filter_point(vec0_cursor *pCur, vec0_vtab *p, int idxNum, - const char *idxStr, int argc, sqlite3_value **argv) { - UNUSED_PARAMETER(idxNum); - UNUSED_PARAMETER(idxStr); +int vec0Filter_point(vec0_cursor *pCur, vec0_vtab *p, int argc, + sqlite3_value **argv) { int rc; assert(argc == 1); - i64 rowid = sqlite3_value_int64(argv[0]); + i64 rowid; + struct vec0_query_point_data *point_data = NULL; - pCur->query_plan = SQLITE_VEC0_QUERYPLAN_POINT; - struct vec0_query_point_data *point_data = - sqlite3_malloc(sizeof(struct vec0_query_point_data)); + point_data = sqlite3_malloc(sizeof(*point_data)); if (!point_data) { - return SQLITE_NOMEM; + rc = SQLITE_NOMEM; + goto error; + } + memset(point_data, 0, sizeof(*point_data)); + + if (p->pkIsText) { + rc = vec0_rowid_from_id(p, argv[0], &rowid); + if (rc == SQLITE_EMPTY) { + goto eof; + } + if (rc != SQLITE_OK) { + goto error; + } + } else { + rowid = sqlite3_value_int64(argv[0]); } - memset(point_data, 0, sizeof(struct vec0_query_point_data)); for (int i = 0; i < p->numVectorColumns; i++) { rc = vec0_get_vector_data(p, rowid, i, &point_data->vectors[i], NULL); - assert(rc == SQLITE_OK); + if (rc == SQLITE_EMPTY) { + goto eof; + } + if (rc != SQLITE_OK) { + goto error; + } } + point_data->rowid = rowid; point_data->done = 0; pCur->point_data = point_data; + pCur->query_plan = SQLITE_VEC0_QUERYPLAN_POINT; return SQLITE_OK; + +eof: + point_data->rowid = rowid; + point_data->done = 1; + pCur->point_data = point_data; + pCur->query_plan = SQLITE_VEC0_QUERYPLAN_POINT; + return SQLITE_OK; + +error: + vec0_query_point_data_clear(point_data); + sqlite3_free(point_data); + return rc; } static int vec0Filter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, @@ -4361,23 +4467,26 @@ static int vec0Filter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, vec0_cursor *pCur = (vec0_cursor *)pVtabCursor; vec0_vtab *p = (vec0_vtab *)pVtabCursor->pVtab; if (strcmp(idxStr, VEC0_QUERY_PLAN_FULLSCAN) == 0) { - return vec0Filter_fullscan(pCur, p, idxNum, idxStr, argc, argv); + return vec0Filter_fullscan(p, pCur); } else if (strncmp(idxStr, "knn:", 4) == 0) { return vec0Filter_knn(pCur, p, idxNum, idxStr, argc, argv); } else if (strcmp(idxStr, VEC0_QUERY_PLAN_POINT) == 0) { - return vec0Filter_point(pCur, p, idxNum, idxStr, argc, argv); + return vec0Filter_point(pCur, p, argc, argv); } else { - SET_VTAB_CURSOR_ERROR("unknown idxStr"); + vtab_set_error(pVtabCursor->pVtab, "unknown idxStr '%s'", idxStr); return SQLITE_ERROR; } } static int vec0Rowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { - UNUSED_PARAMETER(cur); - UNUSED_PARAMETER(pRowid); vec0_cursor *pCur = (vec0_cursor *)cur; - todo_assert(pCur->query_plan == SQLITE_VEC0_QUERYPLAN_POINT); - todo_assert(pCur->point_data); + if ((pCur->query_plan != SQLITE_VEC0_QUERYPLAN_POINT) || + (!pCur->point_data)) { + vtab_set_error( + cur->pVtab, + "Internal sqlite-vec error: exepcted point query plan in vec0Rowid"); + return SQLITE_ERROR; + } *pRowid = pCur->point_data->rowid; return SQLITE_OK; } @@ -4386,48 +4495,58 @@ static int vec0Next(sqlite3_vtab_cursor *cur) { vec0_cursor *pCur = (vec0_cursor *)cur; switch (pCur->query_plan) { case SQLITE_VEC0_QUERYPLAN_FULLSCAN: { - todo_assert(pCur->fullscan_data); + if (!pCur->fullscan_data) { + return SQLITE_ERROR; + } int rc = sqlite3_step(pCur->fullscan_data->rowids_stmt); if (rc == SQLITE_DONE) { pCur->fullscan_data->done = 1; return SQLITE_OK; } if (rc == SQLITE_ROW) { - // TODO error handle return SQLITE_OK; } return SQLITE_ERROR; } case SQLITE_VEC0_QUERYPLAN_KNN: { - todo_assert(pCur->knn_data); + if (!pCur->knn_data) { + return SQLITE_ERROR; + } + pCur->knn_data->current_idx++; return SQLITE_OK; } case SQLITE_VEC0_QUERYPLAN_POINT: { - todo_assert(pCur->point_data); + if (!pCur->point_data) { + return SQLITE_ERROR; + } pCur->point_data->done = 1; return SQLITE_OK; } - default: { - todo("point next impl"); - } } + return SQLITE_ERROR; } static int vec0Eof(sqlite3_vtab_cursor *cur) { vec0_cursor *pCur = (vec0_cursor *)cur; switch (pCur->query_plan) { case SQLITE_VEC0_QUERYPLAN_FULLSCAN: { - todo_assert(pCur->fullscan_data); + if (!pCur->fullscan_data) { + return 1; + } return pCur->fullscan_data->done; } case SQLITE_VEC0_QUERYPLAN_KNN: { - todo_assert(pCur->knn_data); + if (!pCur->knn_data) { + return 1; + } return (pCur->knn_data->current_idx >= pCur->knn_data->k) || (pCur->knn_data->distances[pCur->knn_data->current_idx] == FLT_MAX); } case SQLITE_VEC0_QUERYPLAN_POINT: { - todo_assert(pCur->point_data); + if (!pCur->point_data) { + return 1; + } return pCur->point_data->done; } } @@ -4435,21 +4554,26 @@ static int vec0Eof(sqlite3_vtab_cursor *cur) { static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur, sqlite3_context *context, int i) { - todo_assert(pCur->fullscan_data); + if (!pCur->fullscan_data) { + sqlite3_result_error( + context, "Internal sqlite-vec error: fullscan_data is NULL.", -1); + return SQLITE_ERROR; + } i64 rowid = sqlite3_column_int64(pCur->fullscan_data->rowids_stmt, 0); if (i == VEC0_COLUMN_ID) { - vec0_result_id(pVtab, context, rowid); + return vec0_result_id(pVtab, context, rowid); } else if (vec0_column_idx_is_vector(pVtab, i)) { void *v; int sz; int vector_idx = vec0_column_idx_to_vector_idx(pVtab, i); int rc = vec0_get_vector_data(pVtab, rowid, vector_idx, &v, &sz); - todo_assert(rc == SQLITE_OK); - sqlite3_result_blob(context, v, sz, SQLITE_TRANSIENT); + if (rc != SQLITE_OK) { + return rc; + } + sqlite3_result_blob(context, v, sz, sqlite3_free); sqlite3_result_subtype(context, pVtab->vector_columns[vector_idx].element_type); - sqlite3_free(v); } else if (i == vec0_column_distance_idx(pVtab)) { sqlite3_result_null(context); } else { @@ -4460,16 +4584,18 @@ static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur, static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur, sqlite3_context *context, int i) { - todo_assert(pCur->point_data); + if (!pCur->point_data) { + sqlite3_result_error(context, + "Internal sqlite-vec error: point_data is NULL.", -1); + return SQLITE_ERROR; + } if (i == VEC0_COLUMN_ID) { - vec0_result_id(pVtab, context, pCur->point_data->rowid); - return SQLITE_OK; + return vec0_result_id(pVtab, context, pCur->point_data->rowid); } if (i == vec0_column_distance_idx(pVtab)) { sqlite3_result_null(context); return SQLITE_OK; } - // TODO only have 1st vector data if (vec0_column_idx_is_vector(pVtab, i)) { if (sqlite3_vtab_nochange(context)) { sqlite3_result_null(context); @@ -4490,11 +4616,14 @@ static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur, static int vec0Column_knn(vec0_vtab *pVtab, vec0_cursor *pCur, sqlite3_context *context, int i) { - todo_assert(pCur->knn_data); + if (!pCur->knn_data) { + sqlite3_result_error(context, + "Internal sqlite-vec error: knn_data is NULL.", -1); + return SQLITE_ERROR; + } if (i == VEC0_COLUMN_ID) { i64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx]; - vec0_result_id(pVtab, context, rowid); - return SQLITE_OK; + return vec0_result_id(pVtab, context, rowid); } if (i == vec0_column_distance_idx(pVtab)) { sqlite3_result_double( @@ -5259,7 +5388,6 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite_int64 rowid) { // 5. Delete value in _rowids table // 1. get chunk_id and chunk_offset from _rowids - // TODO how to make this fail without failing the point query rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset); if (rc != SQLITE_OK) { return rc; @@ -5427,10 +5555,8 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, vtab_set_error(pVTab, "UPDATE operation on rowids with vec0 is not supported."); return SQLITE_ERROR; - } - // unknown operation - else { - SET_VTAB_ERROR("Unrecognized xUpdate operation provided for vec0."); + } else { + vtab_set_error(pVTab, "Unrecognized xUpdate operation provided for vec0."); return SQLITE_ERROR; } } diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 60b44b5..e6cf5cd 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -16,6 +16,8 @@ EXT_PATH = "./dist/vec0" SUPPORTS_SUBTYPE = sqlite3.sqlite_version_info[1] > 38 SUPPORTS_DROP_COLUMN = sqlite3.sqlite_version_info[1] >= 35 +SUPPORTS_VTAB_IN = sqlite3.sqlite_version_info[1] >= 38 +SUPPORTS_VTAB_LIMIT = sqlite3.sqlite_version_info[1] >= 41 def bitmap_full(n: int) -> bytearray: @@ -1133,38 +1135,138 @@ def test_vec0_updates(): # ] +def test_vec0_point(): + db = connect(EXT_PATH) + db.execute("CREATE VIRTUAL TABLE t USING vec0(a float[1], b float[1])") + db.execute( + "INSERT INTO t VALUES (1, X'AABBCCDD', X'00112233'), (2, X'AABBCCDD', X'99887766');" + ) + + assert execute_all(db, "select * from t where rowid = 1") == [ + { + "a": b"\xaa\xbb\xcc\xdd", + "b": b'\x00\x11"3', + "rowid": 1, + } + ] + assert execute_all(db, "select * from t where rowid = 999") == [] + + db.execute( + "CREATE VIRTUAL TABLE t2 USING vec0(id text primary key, a float[1], b float[1])" + ) + db.execute( + "INSERT INTO t2 VALUES ('A', X'AABBCCDD', X'00112233'), ('B', X'AABBCCDD', X'99887766');" + ) + + assert execute_all(db, "select * from t2 where id = 'A'") == [ + { + "a": b"\xaa\xbb\xcc\xdd", + "b": b'\x00\x11"3', + "id": "A", + } + ] + + assert execute_all(db, "select * from t2 where id = 'xxx'") == [] + + def test_vec0_text_pk(): db = connect(EXT_PATH) db.execute( """ create virtual table t using vec0( t_id text primary key, - aaa float[8], - bbb float8[8] + aaa float[1], + bbb float8[1] ); """ ) + assert execute_all(db, "select * from t") == [] + + with _raises( + "The t virtual table was declared with a TEXT primary key, but a non-TEXT value was provided in an INSERT." + ): + db.execute("INSERT INTO t VALUES (1, X'AABBCCDD', X'AABBCCDD')") + db.executemany( "INSERT INTO t VALUES (:t_id, :aaa, :bbb)", [ { "t_id": "t_1", - "aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", - "bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + "aaa": "[.1]", + "bbb": "[-.1]", }, { "t_id": "t_2", - "aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", - "bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + "aaa": "[.2]", + "bbb": "[-.2]", }, { "t_id": "t_3", - "aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", - "bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + "aaa": "[.3]", + "bbb": "[-.3]", }, ], ) - assert execute_all(db, "select * from t") == [] + assert execute_all(db, "select t_id from t") == [ + {"t_id": "t_1"}, + {"t_id": "t_2"}, + {"t_id": "t_3"}, + ] + assert execute_all(db, "select * from t") == [ + {"t_id": "t_1", "aaa": _f32([0.1]), "bbb": _f32([-0.1])}, + {"t_id": "t_2", "aaa": _f32([0.2]), "bbb": _f32([-0.2])}, + {"t_id": "t_3", "aaa": _f32([0.3]), "bbb": _f32([-0.3])}, + ] + + # EVIDENCE-OF: V09901_26739 vec0 full scan catches _rowid prep error + db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_READ, "t_rowids", "rowid")) + with _raises( + "Error preparing rowid scan: access to t_rowids.rowid is prohibited", + sqlite3.DatabaseError, + ): + db.execute("select * from t") + db.set_authorizer(None) + + +def test_vec0_best_index(): + db = connect(EXT_PATH) + db.execute( + """ + create virtual table t using vec0( + aaa float[1], + bbb float8[1] + ); + """ + ) + + with _raises("only 1 MATCH operator is allowed in a single vec0 query"): + db.execute("select * from t where aaa match NULL and bbb match NULL") + + if SUPPORTS_VTAB_IN: + with _raises( + "only 1 'rowid in (..)' operator is allowed in a single vec0 query" + ): + db.execute("select * from t where rowid in(4,5,6) and rowid in (1, 2,3)") + + with _raises("A LIMIT or 'k = ?' constraint is required on vec0 knn queries."): + db.execute("select * from t where aaa MATCH ?") + + with _raises("Only LIMIT or 'k =?' can be provided, not both"): + db.execute("select * from t where aaa MATCH ? and k = 10 limit 20") + + with _raises( + "Only a single 'ORDER BY distance' clause is allowed on vec0 KNN queries" + ): + db.execute( + "select * from t where aaa MATCH NULL and k = 10 order by distance, distance" + ) + + with _raises( + "Only ascending in ORDER BY distance clause is supported, DESC is not supported yet." + ): + db.execute( + "select * from t where aaa MATCH NULL and k = 10 order by distance desc" + ) def authorizer_deny_on(operation, x1, x2=None): @@ -1610,6 +1712,13 @@ def test_smoke(): "select * from vec_xyz where a match X'' and k = 10 order by distance" ), ) + if SUPPORTS_VTAB_LIMIT: + assert re.match( + "SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 0:knn:", + explain_query_plan( + "select * from vec_xyz where a match X'' order by distance limit 10" + ), + ) assert re.match( "SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 0:fullscan", explain_query_plan("select * from vec_xyz"),