diff --git a/.gitignore b/.gitignore index 404b43b..a64dfe2 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ sift/ *.tar.gz *.db +*.npy *.bin *.out venv/ diff --git a/sqlite-vec.c b/sqlite-vec.c index fb670cc..5de42d0 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -77,17 +77,8 @@ typedef size_t usize; #define UNUSED_PARAMETER(X) (void)(X) #endif -#ifndef todo_assert -#define todo_assert(X) assert(X) -#endif - #define countof(x) (sizeof(x) / sizeof((x)[0])) - -#define todo(msg) \ - do { \ - fprintf(stderr, "TODO: %s\n", msg); \ - exit(1); \ - } while (0) +#define min(a, b) (((a) <= (b)) ? (a) : (b)) enum VectorElementType { SQLITE_VEC_ELEMENT_TYPE_FLOAT32 = 223 + 0, @@ -456,13 +447,15 @@ int array_append(struct Array *array, const void *element) { return SQLITE_NOMEM; } } - memcpy(& ((unsigned char *) array->z)[array->length * array->element_size], element, - array->element_size); + memcpy(&((unsigned char *)array->z)[array->length * array->element_size], + element, array->element_size); array->length++; return SQLITE_OK; } void array_cleanup(struct Array *array) { + if (!array) + return; array->element_size = 0; array->length = 0; array->capacity = 0; @@ -1082,13 +1075,22 @@ static void vec_quantize_i8(sqlite3_context *context, int argc, sqlite3_value **argv) { f32 *srcVector; size_t dimensions; - fvec_cleanup cleanup; + fvec_cleanup srcCleanup; char *err; - int rc = fvec_from_value(argv[0], &srcVector, &dimensions, &cleanup, &err); - todo_assert(rc == SQLITE_OK); + i8 *out = NULL; + int rc = fvec_from_value(argv[0], &srcVector, &dimensions, &srcCleanup, &err); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, err, -1); + sqlite3_free(err); + return; + } + int sz = dimensions * sizeof(i8); - i8 *out = sqlite3_malloc(sz); - todo_assert(out); + out = sqlite3_malloc(sz); + if (!out) { + rc = SQLITE_NOMEM; + goto cleanup; + } memset(out, 0, sz); if (argc == 2) { @@ -1100,9 +1102,8 @@ static void vec_quantize_i8(sqlite3_context *context, int argc, "2nd argument to vec_quantize_i8() must be 'unit', " "or ranges must be provided.", -1); - cleanup(srcVector); sqlite3_free(out); - return; + goto cleanup; } f32 step = (1.0 - (-1.0)) / 255; for (size_t i = 0; i < dimensions; i++) { @@ -1113,13 +1114,19 @@ static void vec_quantize_i8(sqlite3_context *context, int argc, // size_t d; // fvec_cleanup minCleanup, maxCleanup; // int rc = fvec_from_value(argv[1], ) - todo("ranges"); + + sqlite3_free(out); + // TODO + sqlite3_result_error( + context, "ranges parameter not supported in vec_quantize_i8 yet.", -1); + goto cleanup; } - cleanup(srcVector); sqlite3_result_blob(context, out, dimensions * sizeof(i8), sqlite3_free); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); - return; + +cleanup: + srcCleanup(srcVector); } static void vec_quantize_binary(sqlite3_context *context, int argc, @@ -1310,6 +1317,7 @@ static void vec_slice(sqlite3_context *context, int argc, int start = sqlite3_value_int(argv[1]); int end = sqlite3_value_int(argv[2]); + if (start < 0) { sqlite3_result_error(context, "slice 'start' index must be a postive number.", -1); @@ -1337,7 +1345,13 @@ static void vec_slice(sqlite3_context *context, int argc, "slice 'start' index is greater than 'end' index", -1); goto done; } - // TODO check start == end + if (start == end) { + sqlite3_result_error(context, + "slice 'start' index is equal to the 'end' index, " + "vectors must have non-zero length", + -1); + goto done; + } size_t n = end - start; switch (elementType) { @@ -1346,13 +1360,13 @@ static void vec_slice(sqlite3_context *context, int argc, f32 *out = sqlite3_malloc(outSize); if (!out) { sqlite3_result_error_nomem(context); - return; + goto done; } memset(out, 0, outSize); for (size_t i = 0; i < n; i++) { out[i] = ((f32 *)vector)[start + i]; } - sqlite3_result_blob(context, out, n * sizeof(f32), sqlite3_free); + sqlite3_result_blob(context, out, outSize, sqlite3_free); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); goto done; } @@ -1367,7 +1381,7 @@ static void vec_slice(sqlite3_context *context, int argc, for (size_t i = 0; i < n; i++) { out[i] = ((i8 *)vector)[start + i]; } - sqlite3_result_blob(context, out, n * sizeof(i8), sqlite3_free); + sqlite3_result_blob(context, out, outSize, sqlite3_free); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); goto done; } @@ -1390,7 +1404,7 @@ static void vec_slice(sqlite3_context *context, int argc, for (size_t i = 0; i < n / CHAR_BIT; i++) { out[i] = ((u8 *)vector)[(start / CHAR_BIT) + i]; } - sqlite3_result_blob(context, out, n / CHAR_BIT, sqlite3_free); + sqlite3_result_blob(context, out, outSize, sqlite3_free); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); goto done; } @@ -1916,7 +1930,7 @@ static int vec_eachConnect(sqlite3 *db, void *pAux, int argc, UNUSED_PARAMETER(pAux); UNUSED_PARAMETER(argc); UNUSED_PARAMETER(argv); - UNUSED_PARAMETER(pzErr); // TODO use + UNUSED_PARAMETER(pzErr); vec_each_vtab *pNew; int rc; @@ -1952,13 +1966,15 @@ static int vec_eachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { static int vec_eachClose(sqlite3_vtab_cursor *cur) { vec_each_cursor *pCur = (vec_each_cursor *)cur; + pCur->cleanup(pCur->vector); sqlite3_free(pCur); return SQLITE_OK; } static int vec_eachBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { - int hasVector; + UNUSED_PARAMETER(pVTab); + int hasVector = 0; for (int i = 0; i < pIdxInfo->nConstraint; i++) { const struct sqlite3_index_constraint *pCons = &pIdxInfo->aConstraint[i]; // printf("i=%d iColumn=%d, op=%d, usable=%d\n", i, pCons->iColumn, @@ -1975,8 +1991,7 @@ static int vec_eachBestIndex(sqlite3_vtab *pVTab, } } if (!hasVector) { - pVTab->zErrMsg = sqlite3_mprintf("vector argument is required"); - return SQLITE_ERROR; + return SQLITE_CONSTRAINT; } pIdxInfo->estimatedCost = (double)100000; @@ -1992,6 +2007,11 @@ static int vec_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, assert(argc == 1); vec_each_cursor *pCur = (vec_each_cursor *)pVtabCursor; + if (pCur->vector) { + pCur->cleanup(pCur->vector); + pCur->vector = NULL; + } + char *pzErrMsg; int rc = vector_from_value(argv[0], &pCur->vector, &pCur->dimensions, &pCur->vector_type, &pCur->cleanup, &pzErrMsg); @@ -2726,14 +2746,17 @@ static int vec_npy_eachColumnBuffer(vec_npy_each_cursor *pCur, case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { sqlite3_result_blob( context, - & ((unsigned char *) pCur->vector)[pCur->iRowid * pCur->nDimensions * sizeof(f32)], + &((unsigned char *) + pCur->vector)[pCur->iRowid * pCur->nDimensions * sizeof(f32)], pCur->nDimensions * sizeof(f32), SQLITE_STATIC); break; } case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_BIT: { - todo("bit array npy column"); + // TODO + sqlite3_result_error(context, + "vec_npy_each only supports float32 vectors", -1); break; } } @@ -2749,14 +2772,19 @@ static int vec_npy_eachColumnFile(vec_npy_each_cursor *pCur, case VEC_NPY_EACH_COLUMN_VECTOR: { switch (pCur->elementType) { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { - sqlite3_result_blob(context, - & ((unsigned char *)pCur->chunksBuffer)[pCur->currentChunkIndex * pCur->nDimensions * sizeof(f32)], - pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT); + sqlite3_result_blob( + context, + &((unsigned char *) + pCur->chunksBuffer)[pCur->currentChunkIndex * + pCur->nDimensions * sizeof(f32)], + pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT); break; } case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_BIT: { - todo("bit array npy column"); + // TODO + sqlite3_result_error(context, + "vec_npy_each only supports float32 vectors", -1); break; } } @@ -3026,8 +3054,8 @@ int vec0_column_k_idx(vec0_vtab *pVtab) { */ int vec0_column_idx_is_vector(vec0_vtab *pVtab, int column_idx) { return column_idx >= VEC0_COLUMN_VECTORN_START && - column_idx <= (VEC0_COLUMN_VECTORN_START + pVtab->numVectorColumns - - 1); // TODO is -1 necessary here? + column_idx <= + (VEC0_COLUMN_VECTORN_START + pVtab->numVectorColumns - 1); } /** @@ -3153,6 +3181,7 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx, sqlite3_bind_int64(pVtab->stmtRowidsGetChunkPosition, 1, rowid); rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition); if (rc == SQLITE_DONE) { + // TODO error message on callers rc = SQLITE_EMPTY; goto cleanup; } @@ -3371,6 +3400,7 @@ void vec0_query_fullscan_data_clear( struct vec0_query_knn_data { i64 k; + i64 k_used; // Array of rowids of size k. Must be freed with sqlite3_free(). i64 *rowids; // Array of distances of size k. Must be freed with sqlite3_free(). @@ -3805,14 +3835,15 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { vec0_free_resources(p); - // TODO evidence-of here - + // later: can't evidence-of here, bc always gives "SQL logic error" instead of + // provided error zSql = sqlite3_mprintf("DROP TABLE " VEC0_SHADOW_CHUNKS_NAME, p->schemaName, p->tableName); rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); sqlite3_free((void *)zSql); if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { rc = SQLITE_ERROR; + vtab_set_error(pVtab, "could not drop chunks shadow table"); goto done; } sqlite3_finalize(stmt); @@ -3840,9 +3871,13 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { } stmt = NULL; rc = SQLITE_OK; + done: - sqlite3_free(p); sqlite3_finalize(stmt); + vec0_free(p); + if (rc == SQLITE_OK) { + sqlite3_free(p); + } return rc; } @@ -3857,8 +3892,7 @@ static int vec0Open(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { return SQLITE_OK; } -static int vec0Close(sqlite3_vtab_cursor *cur) { - vec0_cursor *pCur = (vec0_cursor *)cur; +void vec0CursorClear(vec0_cursor *pCur) { if (pCur->fullscan_data) { vec0_query_fullscan_data_clear(pCur->fullscan_data); sqlite3_free(pCur->fullscan_data); @@ -3874,6 +3908,11 @@ static int vec0Close(sqlite3_vtab_cursor *cur) { sqlite3_free(pCur->point_data); pCur->point_data = NULL; } +} + +static int vec0Close(sqlite3_vtab_cursor *cur) { + vec0_cursor *pCur = (vec0_cursor *)cur; + vec0CursorClear(pCur); sqlite3_free(pCur); return SQLITE_OK; } @@ -4036,62 +4075,91 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { // forward delcaration bc vec0Filter uses it static int vec0Next(sqlite3_vtab_cursor *cur); -// TODO: Ya this shit is slow -void dethrone(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size, - i32 *chunk_top_idx, f32 *chunk_distances, i64 *chunk_rowids, - - i64 **out_rowids, f32 **out_distances) { - *out_rowids = sqlite3_malloc(k * sizeof(i64)); - todo_assert(out_rowids); - *out_distances = sqlite3_malloc(k * sizeof(f32)); - todo_assert(out_distances); - - size_t ptrA = 0; - size_t ptrB = 0; - for (int i = 0; i < k; i++) { - if (chunk_distances[chunk_top_idx[ptrA]] < base_distances[ptrB]) { - (*out_rowids)[i] = chunk_rowids[chunk_top_idx[ptrA]]; - (*out_distances)[i] = chunk_distances[chunk_top_idx[ptrA]]; - // TODO if ptrA at chunk_size-1 is always minimum, won't it always repeat? - if (ptrA < (chunk_size - 1)) { - ptrA++; - } - } else { - (*out_rowids)[i] = base_rowids[ptrB]; - (*out_distances)[i] = base_distances[ptrB]; - ptrB++; +void merge_sorted_lists(f32 *a, i64 *a_rowids, i64 a_length, f32 *b, + i64 *b_rowids, i32 *b_top_idxs, i64 b_length, f32 *out, + i64 *out_rowids, i64 out_length, i64 *out_used) { + // assert((a_length >= out_length) || (b_length >= out_length)); + i64 ptrA = 0; + i64 ptrB = 0; + for (int i = 0; i < out_length; i++) { + if ((ptrA >= a_length) && (ptrB >= b_length)) { + *out_used = i; + return; } - } -} - -/* -// TODO is this better?? from vec_expo experiment -void dethrone2(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size, - i32 *chunk_top_idx, f32 *chunk_distances, i64 *chunk_rowids, - - i64 **out_rowids, f32 **out_distances) { - *out_rowids = sqlite3_malloc(k * sizeof(i64)); - todo_assert(*out_rowids); - *out_distances = sqlite3_malloc(k * sizeof(f32)); - todo_assert(*out_distances); - - size_t ptrA = 0; - size_t ptrB = 0; - for (int i = 0; i < k; i++) { - if (ptrA < chunk_size && - (ptrB >= k || - chunk_distances[chunk_top_idx[ptrA]] < base_distances[ptrB])) { - (*out_rowids)[i] = chunk_rowids[chunk_top_idx[ptrA]]; - (*out_distances)[i] = chunk_distances[chunk_top_idx[ptrA]]; + if (ptrA >= a_length) { + out[i] = b[b_top_idxs[ptrB]]; + out_rowids[i] = b_rowids[b_top_idxs[ptrB]]; + ptrB++; + } else if (ptrB >= b_length) { + out[i] = a[ptrA]; + out_rowids[i] = a_rowids[ptrA]; ptrA++; - } else if (ptrB < k) { - (*out_rowids)[i] = base_rowids[ptrB]; - (*out_distances)[i] = base_distances[ptrB]; - ptrB++; + } else { + if (a[ptrA] <= b[b_top_idxs[ptrB]]) { + out[i] = a[ptrA]; + out_rowids[i] = a_rowids[ptrA]; + ptrA++; + } else { + out[i] = b[b_top_idxs[ptrB]]; + out_rowids[i] = b_rowids[b_top_idxs[ptrB]]; + ptrB++; + } } } + + *out_used = out_length; +} + +u8 *bitmap_new(i32 n) { + assert(n % 8 == 0); + u8 *p = sqlite3_malloc(n * sizeof(u8) / CHAR_BIT); + if (p) { + memset(p, 0, n * sizeof(u8) / CHAR_BIT); + } + return p; +} +u8 *bitmap_new_from(i32 n, u8 *from) { + assert(n % 8 == 0); + u8 *p = sqlite3_malloc(n * sizeof(u8) / CHAR_BIT); + if (p) { + memcpy(p, from, n / CHAR_BIT); + } + return p; +} + +void bitmap_copy(u8 *base, u8 *from, i32 n) { + assert(n % 8 == 0); + memcpy(base, from, n / CHAR_BIT); +} + +void bitmap_and_inplace(u8 *base, u8 *other, i32 n) { + assert((n % 8) == 0); + for (int i = 0; i < n / CHAR_BIT; i++) { + base[i] = base[i] & other[i]; + } +} + +void bitmap_set(u8 *bitmap, i32 position, int value) { + bitmap[position / CHAR_BIT] |= value << (position % CHAR_BIT); +} + +int bitmap_get(u8 *bitmap, i32 position) { + return (((bitmap[position / CHAR_BIT]) >> (position % CHAR_BIT)) & 1); +} + +void bitmap_clear(u8 *bitmap, i32 n) { + assert((n % 8) == 0); + memset(bitmap, 0, n / CHAR_BIT); +} + +void bitmap_debug(u8 *bitmap, i32 n) { + for (int i = 0; i < n; i++) { + printf("%d", bitmap_get(bitmap, i)); + if (i > 0 && (i % 8 == 0)) + printf("|"); + } + printf("\n"); } -*/ /** * @brief Finds the minimum k items in distances, and writes the indicies to @@ -4099,286 +4167,479 @@ void dethrone2(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size, * * @param distances input f32 array of size n, the items to consider. * @param n: size of distances array. - * @param out: Output array of size k, will contain the minumum k element - * indicies + * @param out: Output array of size k, will contain at most k element indicies * @param k: Size of output array * @return int */ -int min_idx(const f32 *distances, i32 n, i32 *out, i32 k) { - todo_assert(k > 0); - todo_assert(k <= n); +int min_idx(const f32 *distances, i32 n, u8 *candidates, i32 *out, i32 k, + u8 *bTaken, i32 *k_used) { + assert(k > 0); + assert(k <= n); - unsigned char *taken = malloc(n * sizeof(unsigned char)); - todo_assert(taken); - memset(taken, 0, n); + bitmap_clear(bTaken, n); for (int ik = 0; ik < k; ik++) { int min_idx = 0; - while (min_idx < n && taken[min_idx]) { + while (min_idx < n && + (bitmap_get(bTaken, min_idx) || !bitmap_get(candidates, min_idx))) { min_idx++; } - todo_assert(min_idx < n); + if (min_idx >= n) { + *k_used = ik; + return SQLITE_OK; + } for (int i = 0; i < n; i++) { - if (distances[i] < distances[min_idx] && !taken[i]) { + if (distances[i] <= distances[min_idx] && !bitmap_get(bTaken, i) && + (bitmap_get(candidates, i))) { min_idx = i; } } out[ik] = min_idx; - taken[min_idx] = 1; + bitmap_set(bTaken, min_idx, 1); } - free(taken); + *k_used = k; return SQLITE_OK; } +int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, + struct VectorColumnDefinition *vector_column, + int vectorColumnIdx, struct Array *arrayRowidsIn, + void *queryVector, i64 k, i64 **out_topk_rowids, + f32 **out_topk_distances, i64 *out_used) { + // for each chunk, get top min(k, chunk_size) rowid + distances to query vec. + // then reconcile all topk_chunks for a true top k. + // output only rowids + distances for now + + int rc = SQLITE_OK; + sqlite3_blob *blobVectors = NULL; + + void *baseVectors = NULL; // memory: chunk_size * dimensions * element_size + + // OWNED BY CALLER ON SUCCESS + i64 *topk_rowids = NULL; // memory: k * 4 + // OWNED BY CALLER ON SUCCESS + f32 *topk_distances = NULL; // memory: k * 4 + + i64 *tmp_topk_rowids = NULL; // memory: k * 4 + f32 *tmp_topk_distances = NULL; // memory: k * 4 + f32 *chunk_distances = NULL; // memory: chunk_size * 4 + u8 *b = NULL; // memory: chunk_size / 8 + u8 *bTaken = NULL; // memory: chunk_size / 8 + i32 *chunk_topk_idxs = NULL; // memory: k * 4 + u8 *bmRowids = NULL; // memory: chunk_size / 8 + // // total: a lot??? + + // 6 * (k * 4) + (k * 2) + (chunk_size / 8) + (chunk_size * dimensions * 4) + + topk_rowids = sqlite3_malloc(k * sizeof(i64)); + if (!topk_rowids) { + rc = SQLITE_NOMEM; + goto cleanup; + } + memset(topk_rowids, 0, k * sizeof(i64)); + + topk_distances = sqlite3_malloc(k * sizeof(f32)); + if (!topk_distances) { + rc = SQLITE_NOMEM; + goto cleanup; + } + memset(topk_distances, 0, k * sizeof(f32)); + + tmp_topk_rowids = sqlite3_malloc(k * sizeof(i64)); + if (!tmp_topk_rowids) { + rc = SQLITE_NOMEM; + goto cleanup; + } + memset(tmp_topk_rowids, 0, k * sizeof(i64)); + + tmp_topk_distances = sqlite3_malloc(k * sizeof(f32)); + if (!tmp_topk_distances) { + rc = SQLITE_NOMEM; + goto cleanup; + } + memset(tmp_topk_distances, 0, k * sizeof(f32)); + + i64 k_used = 0; + i64 baseVectorsSize = p->chunk_size * vector_column_byte_size(*vector_column); + baseVectors = sqlite3_malloc(baseVectorsSize); + if (!baseVectors) { + rc = SQLITE_NOMEM; + goto cleanup; + } + + chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(f32)); + if (!chunk_distances) { + rc = SQLITE_NOMEM; + goto cleanup; + } + + b = bitmap_new(p->chunk_size); + if (!b) { + rc = SQLITE_NOMEM; + goto cleanup; + } + + bTaken = bitmap_new(p->chunk_size); + if (!bTaken) { + rc = SQLITE_NOMEM; + goto cleanup; + } + + chunk_topk_idxs = sqlite3_malloc(k * sizeof(i32)); + if (!chunk_topk_idxs) { + rc = SQLITE_NOMEM; + goto cleanup; + } + + bmRowids = arrayRowidsIn ? bitmap_new(p->chunk_size) : NULL; + if (arrayRowidsIn && !bmRowids) { + rc = SQLITE_NOMEM; + goto cleanup; + } + + while (true) { + rc = sqlite3_step(stmtChunks); + if (rc == SQLITE_DONE) { + break; + } + if (rc != SQLITE_ROW) { + vtab_set_error(&p->base, "chunks iter error"); + rc = SQLITE_ERROR; + goto cleanup; + } + memset(chunk_distances, 0, p->chunk_size * sizeof(f32)); + memset(chunk_topk_idxs, 0, k * sizeof(i32)); + bitmap_clear(b, p->chunk_size); + + i64 chunk_id = sqlite3_column_int64(stmtChunks, 0); + unsigned char *chunkValidity = + (unsigned char *)sqlite3_column_blob(stmtChunks, 1); + i64 validitySize = sqlite3_column_bytes(stmtChunks, 1); + if (validitySize != p->chunk_size / CHAR_BIT) { + // IMP: V05271_22109 + vtab_set_error( + &p->base, + "chunk validity size doesn't match - expected %lld, found %lld", + p->chunk_size / CHAR_BIT, validitySize); + rc = SQLITE_ERROR; + goto cleanup; + } + + i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2); + i64 rowidsSize = sqlite3_column_bytes(stmtChunks, 2); + if (rowidsSize != p->chunk_size * sizeof(i64)) { + // IMP: V02796_19635 + vtab_set_error(&p->base, "rowids size doesn't match"); + vtab_set_error( + &p->base, + "chunk rowids size doesn't match - expected %lld, found %lld", + p->chunk_size * sizeof(i64), rowidsSize); + rc = SQLITE_ERROR; + goto cleanup; + } + + // open the vector chunk blob for the current chunk + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowVectorChunksNames[vectorColumnIdx], + "vectors", chunk_id, 0, &blobVectors); + if (rc != SQLITE_OK) { + vtab_set_error(&p->base, "could not open vectors blob for chunk %lld", + chunk_id); + rc = SQLITE_ERROR; + goto cleanup; + } + + i64 currentBaseVectorsSize = sqlite3_blob_bytes(blobVectors); + i64 expectedBaseVectorsSize = + p->chunk_size * vector_column_byte_size(*vector_column); + if (currentBaseVectorsSize != expectedBaseVectorsSize) { + // IMP: V16465_00535 + vtab_set_error( + &p->base, + "vectors blob size doesn't match - expected %lld, found %lld", + expectedBaseVectorsSize, currentBaseVectorsSize); + rc = SQLITE_ERROR; + goto cleanup; + } + rc = sqlite3_blob_read(blobVectors, baseVectors, currentBaseVectorsSize, 0); + + if (rc != SQLITE_OK) { + vtab_set_error(&p->base, "vectors blob read error for %lld", chunk_id); + rc = SQLITE_ERROR; + goto cleanup; + } + + bitmap_copy(b, chunkValidity, p->chunk_size); + if (arrayRowidsIn) { + bitmap_clear(bmRowids, p->chunk_size); + + for (int i = 0; i < p->chunk_size; i++) { + if (!bitmap_get(chunkValidity, i)) { + continue; + } + i64 rowid = chunkRowids[i]; + void *in = bsearch(&rowid, arrayRowidsIn->z, arrayRowidsIn->length, + sizeof(i64), _cmp); + bitmap_set(bmRowids, i, in ? 1 : 0); + } + bitmap_and_inplace(b, bmRowids, p->chunk_size); + } + + for (int i = 0; i < p->chunk_size; i++) { + if (!bitmap_get(b, i)) { + continue; + }; + + f32 result; + switch (vector_column->element_type) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + const f32 *base_i = + ((f32 *)baseVectors) + (i * vector_column->dimensions); + switch (vector_column->distance_metric) { + case VEC0_DISTANCE_METRIC_L2: { + result = distance_l2_sqr_float(base_i, (f32 *)queryVector, + &vector_column->dimensions); + break; + } + case VEC0_DISTANCE_METRIC_COSINE: { + result = distance_cosine_float(base_i, (f32 *)queryVector, + &vector_column->dimensions); + break; + } + } + break; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + const i8 *base_i = + ((i8 *)baseVectors) + (i * vector_column->dimensions); + switch (vector_column->distance_metric) { + case VEC0_DISTANCE_METRIC_L2: { + result = distance_l2_sqr_int8(base_i, (i8 *)queryVector, + &vector_column->dimensions); + break; + } + case VEC0_DISTANCE_METRIC_COSINE: { + result = distance_cosine_int8(base_i, (i8 *)queryVector, + &vector_column->dimensions); + break; + } + } + + break; + } + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + const u8 *base_i = + ((u8 *)baseVectors) + (i * (vector_column->dimensions / CHAR_BIT)); + result = distance_hamming(base_i, (u8 *)queryVector, + &vector_column->dimensions); + break; + } + } + + chunk_distances[i] = result; + } + + int used1; + min_idx(chunk_distances, p->chunk_size, b, chunk_topk_idxs, + min(k, p->chunk_size), bTaken, &used1); + + i64 used; + merge_sorted_lists(topk_distances, topk_rowids, k_used, chunk_distances, + chunkRowids, chunk_topk_idxs, + min(min(k, p->chunk_size), used1), tmp_topk_distances, + tmp_topk_rowids, k, &used); + + for (int i = 0; i < used; i++) { + topk_rowids[i] = tmp_topk_rowids[i]; + topk_distances[i] = tmp_topk_distances[i]; + } + k_used = used; + sqlite3_blob_close(blobVectors); + blobVectors = NULL; + } + + *out_topk_rowids = topk_rowids; + *out_topk_distances = topk_distances; + *out_used = k_used; + rc = SQLITE_OK; + +cleanup: + if (rc != SQLITE_OK) { + sqlite3_free(topk_rowids); + sqlite3_free(topk_distances); + } + sqlite3_free(chunk_topk_idxs); + sqlite3_free(tmp_topk_rowids); + sqlite3_free(tmp_topk_distances); + sqlite3_free(b); + sqlite3_free(bTaken); + sqlite3_free(bmRowids); + sqlite3_free(baseVectors); + sqlite3_free(chunk_distances); + sqlite3_blob_close(blobVectors); + return rc; +} + int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { - UNUSED_PARAMETER(idxNum); UNUSED_PARAMETER(idxStr); assert(argc >= 2); int rc; - pCur->query_plan = SQLITE_VEC0_QUERYPLAN_KNN; - struct vec0_query_knn_data *knn_data = - sqlite3_malloc(sizeof(struct vec0_query_knn_data)); - if (!knn_data) { - return SQLITE_NOMEM; - } - memset(knn_data, 0, sizeof(struct vec0_query_knn_data)); + struct vec0_query_knn_data *knn_data; int vectorColumnIdx = idxNum; struct VectorColumnDefinition *vector_column = &p->vector_columns[vectorColumnIdx]; + struct Array *arrayRowidsIn = NULL; + sqlite3_stmt *stmtChunks = NULL; void *queryVector; size_t dimensions; enum VectorElementType elementType; - vector_cleanup cleanup; - char *err; + vector_cleanup queryVectorCleanup = vector_cleanup_noop; + char *pzError; + knn_data = sqlite3_malloc(sizeof(*knn_data)); + if (!knn_data) { + return SQLITE_NOMEM; + } + memset(knn_data, 0, sizeof(*knn_data)); + rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, - &cleanup, &err); - todo_assert(rc == SQLITE_OK); - todo_assert(elementType == vector_column->element_type); - todo_assert(dimensions == vector_column->dimensions); + &queryVectorCleanup, &pzError); + + if (rc != SQLITE_OK) { + vtab_set_error(&p->base, + "Query vector on the \"%.*s\" column is invalid: %z", + vector_column->name_length, vector_column->name, pzError); + rc = SQLITE_ERROR; + goto cleanup; + } + if (elementType != vector_column->element_type) { + vtab_set_error( + &p->base, + "Query vector for the \"%.*s\" column is expected to be of type " + "%s, but a %s vector was provided.", + vector_column->name_length, vector_column->name, + vector_subtype_name(vector_column->element_type), + vector_subtype_name(elementType)); + rc = SQLITE_ERROR; + goto cleanup; + } + if (dimensions != vector_column->dimensions) { + vtab_set_error( + &p->base, + "Dimension mismatch for inserted vector for the \"%.*s\" column. " + "Expected %d dimensions but received %d.", + vector_column->name_length, vector_column->name, + vector_column->dimensions, dimensions); + rc = SQLITE_ERROR; + goto cleanup; + } i64 k = sqlite3_value_int64(argv[1]); - todo_assert(k >= 0); + if (k < 0) { + vtab_set_error( + &p->base, "k value in knn queries must be greater than or equal to 0."); + rc = SQLITE_ERROR; + goto cleanup; + } + if (k == 0) { knn_data->k = 0; pCur->knn_data = knn_data; - return SQLITE_OK; + pCur->query_plan = SQLITE_VEC0_QUERYPLAN_KNN; + rc = SQLITE_OK; + goto cleanup; } // handle when a `rowid in (...)` operation was provided // Array of all the rowids that appear in any `rowid in (...)` constraint. // NULL if none were provided, which means a "full" scan. - struct Array *arrayRowidsIn = NULL; if (argc > 2) { sqlite3_value *item; int rc; - arrayRowidsIn = sqlite3_malloc(sizeof(struct Array)); - todo_assert(arrayRowidsIn); + arrayRowidsIn = sqlite3_malloc(sizeof(*arrayRowidsIn)); + if (!arrayRowidsIn) { + rc = SQLITE_NOMEM; + goto cleanup; + } + memset(arrayRowidsIn, 0, sizeof(*arrayRowidsIn)); + rc = array_init(arrayRowidsIn, sizeof(i64), 32); - todo_assert(rc == SQLITE_OK); + if (rc != SQLITE_OK) { + goto cleanup; + } + for (rc = sqlite3_vtab_in_first(argv[2], &item); rc == SQLITE_OK && item; rc = sqlite3_vtab_in_next(argv[2], &item)) { - i64 rowid = sqlite3_value_int64(item); + i64 rowid; + if (p->pkIsText) { + rc = vec0_rowid_from_id(p, item, &rowid); + if (rc != SQLITE_OK) { + goto cleanup; + } + } else { + rowid = sqlite3_value_int64(item); + } rc = array_append(arrayRowidsIn, &rowid); - todo_assert(rc == SQLITE_OK); + if (rc != SQLITE_OK) { + goto cleanup; + } + } + if (rc != SQLITE_DONE) { + vtab_set_error(&p->base, "error processing rowid in (...) array"); + goto cleanup; } - todo_assert(rc == SQLITE_DONE); qsort(arrayRowidsIn->z, arrayRowidsIn->length, arrayRowidsIn->element_size, _cmp); } - i64 *topk_rowids = sqlite3_malloc(k * sizeof(i64)); - todo_assert(topk_rowids); - for (int i = 0; i < k; i++) { - // TODO do we need to ensure that rowid is never -1? - topk_rowids[i] = -1; + char *zSql; + zSql = sqlite3_mprintf("select chunk_id, validity, rowids " + " from " VEC0_SHADOW_CHUNKS_NAME, + p->schemaName, p->tableName); + if (!zSql) { + rc = SQLITE_NOMEM; + goto cleanup; } - f32 *topk_distances = sqlite3_malloc(k * sizeof(f32)); - todo_assert(topk_distances); - for (int i = 0; i < k; i++) { - topk_distances[i] = FLT_MAX; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtChunks, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + // IMP: V06942_23781 + vtab_set_error(&p->base, "Error preparing stmtChunk: %s", + sqlite3_errmsg(p->db)); + goto cleanup; } - // for each chunk, get top min(k, chunk_size) rowid + distances to query vec. - // then reconcile all topk_chunks for a true top k. - // output only rowids + distances for now - - { - sqlite3_blob *blobVectors; - sqlite3_stmt *stmtChunks; - char *zSql; - zSql = sqlite3_mprintf("select chunk_id, validity, rowids " - " from " VEC0_SHADOW_CHUNKS_NAME, - p->schemaName, p->tableName); - todo_assert(zSql); - rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtChunks, NULL); - sqlite3_free(zSql); - todo_assert(rc == SQLITE_OK); - - void *baseVectors = NULL; - i64 baseVectorsSize = 0; - - while (true) { - rc = sqlite3_step(stmtChunks); - if (rc == SQLITE_DONE) - break; - if (rc != SQLITE_ROW) { - todo("chunks iter error"); - } - i64 chunk_id = sqlite3_column_int64(stmtChunks, 0); - unsigned char *chunkValidity = - (unsigned char *)sqlite3_column_blob(stmtChunks, 1); - i64 validitySize = sqlite3_column_bytes(stmtChunks, 1); - todo_assert(validitySize == p->chunk_size / CHAR_BIT); - i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2); - i64 rowidsSize = sqlite3_column_bytes(stmtChunks, 2); - todo_assert(rowidsSize == p->chunk_size * sizeof(i64)); - - // open the vector chunk blob for the current chunk - rc = sqlite3_blob_open(p->db, p->schemaName, - p->shadowVectorChunksNames[vectorColumnIdx], - "vectors", chunk_id, 0, &blobVectors); - todo_assert(rc == SQLITE_OK); - i64 currentBaseVectorsSize = sqlite3_blob_bytes(blobVectors); - todo_assert((unsigned long)currentBaseVectorsSize == - p->chunk_size * vector_column_byte_size(*vector_column)); - - if (currentBaseVectorsSize > baseVectorsSize) { - if (baseVectors) { - sqlite3_free(baseVectors); - } - baseVectors = sqlite3_malloc(currentBaseVectorsSize); - todo_assert(baseVectors); - baseVectorsSize = currentBaseVectorsSize; - } - rc = sqlite3_blob_read(blobVectors, baseVectors, currentBaseVectorsSize, - 0); - todo_assert(rc == SQLITE_OK); - - // TODO realloc here, like baseVectors - f32 *chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(f32)); - todo_assert(chunk_distances); - - for (int i = 0; i < p->chunk_size; i++) { - - // Ensure the current vector is "valid" in the validity bitmap. - // If not, skip and continue on - if (!(((chunkValidity[i / CHAR_BIT]) >> (i % CHAR_BIT)) & 1)) { - chunk_distances[i] = FLT_MAX; - continue; - }; - // If pre-filtering, make sure the rowid appears in the `rowid in (...)` - // list. - if (arrayRowidsIn) { - i64 rowid = chunkRowids[i]; - void *in = bsearch(&rowid, arrayRowidsIn->z, arrayRowidsIn->length, - sizeof(i64), _cmp); - if (!in) { - chunk_distances[i] = FLT_MAX; - continue; - } - } - - f32 result; - switch (vector_column->element_type) { - case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { - const f32 *base_i = - ((f32 *)baseVectors) + (i * vector_column->dimensions); - switch (vector_column->distance_metric) { - case VEC0_DISTANCE_METRIC_L2: { - result = distance_l2_sqr_float(base_i, (f32 *)queryVector, - &vector_column->dimensions); - break; - } - case VEC0_DISTANCE_METRIC_COSINE: { - result = distance_cosine_float(base_i, (f32 *)queryVector, - &vector_column->dimensions); - break; - } - } - - // result = distance_cosine(base_i, (f32 *) queryVector, & - // vector_column->dimensions); - break; - } - case SQLITE_VEC_ELEMENT_TYPE_INT8: { - const i8 *base_i = - ((i8 *)baseVectors) + (i * vector_column->dimensions); - switch (vector_column->distance_metric) { - case VEC0_DISTANCE_METRIC_L2: { - result = distance_l2_sqr_int8(base_i, (i8 *)queryVector, - &vector_column->dimensions); - - break; - } - case VEC0_DISTANCE_METRIC_COSINE: { - result = distance_cosine_int8(base_i, (i8 *)queryVector, - &vector_column->dimensions); - break; - } - } - - break; - } - case SQLITE_VEC_ELEMENT_TYPE_BIT: { - const u8 *base_i = ((u8 *)baseVectors) + - (i * (vector_column->dimensions / CHAR_BIT)); - result = distance_hamming(base_i, (u8 *)queryVector, - &vector_column->dimensions); - break; - } - } - - chunk_distances[i] = result; - } - - // now that we have the distances - i32 *chunk_topk_idxs = sqlite3_malloc(k * sizeof(i32)); - todo_assert(chunk_topk_idxs); - min_idx(chunk_distances, p->chunk_size, chunk_topk_idxs, - k <= p->chunk_size ? k : p->chunk_size); - - i64 *out_rowids; - f32 *out_distances; - dethrone(k, topk_distances, topk_rowids, p->chunk_size, chunk_topk_idxs, - chunk_distances, chunkRowids, - - &out_rowids, &out_distances); - for (int i = 0; i < k; i++) { - topk_rowids[i] = out_rowids[i]; - topk_distances[i] = out_distances[i]; - } - sqlite3_free(out_rowids); - sqlite3_free(out_distances); - sqlite3_free(chunk_distances); - sqlite3_free(chunk_topk_idxs); - - sqlite3_blob_close(blobVectors); - } - - sqlite3_free(baseVectors); - rc = sqlite3_finalize(stmtChunks); - todo_assert(rc == SQLITE_OK); - - if (arrayRowidsIn) { - array_cleanup(arrayRowidsIn); - sqlite3_free(arrayRowidsIn); - } + i64 *topk_rowids = NULL; + f32 *topk_distances = NULL; + i64 k_used = 0; + rc = vec0Filter_knn_chunks_iter(p, stmtChunks, vector_column, vectorColumnIdx, + arrayRowidsIn, queryVector, k, &topk_rowids, + &topk_distances, &k_used); + if (rc != SQLITE_OK) { + goto cleanup; } - cleanup(queryVector); - knn_data->current_idx = 0; knn_data->k = k; knn_data->rowids = topk_rowids; knn_data->distances = topk_distances; + knn_data->k_used = k_used; pCur->knn_data = knn_data; - return SQLITE_OK; + pCur->query_plan = SQLITE_VEC0_QUERYPLAN_KNN; + rc = SQLITE_OK; + +cleanup: + sqlite3_finalize(stmtChunks); + array_cleanup(arrayRowidsIn); + sqlite3_free(arrayRowidsIn); + queryVectorCleanup(queryVector); + + return rc; } int vec0Filter_fullscan(vec0_vtab *p, vec0_cursor *pCur) { @@ -4484,8 +4745,9 @@ error: static int vec0Filter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { - vec0_cursor *pCur = (vec0_cursor *)pVtabCursor; vec0_vtab *p = (vec0_vtab *)pVtabCursor->pVtab; + vec0_cursor *pCur = (vec0_cursor *)pVtabCursor; + vec0CursorClear(pCur); if (strcmp(idxStr, VEC0_QUERY_PLAN_FULLSCAN) == 0) { return vec0Filter_fullscan(p, pCur); } else if (strncmp(idxStr, "knn:", 4) == 0) { @@ -4500,15 +4762,23 @@ static int vec0Filter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, static int vec0Rowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { vec0_cursor *pCur = (vec0_cursor *)cur; - if ((pCur->query_plan != SQLITE_VEC0_QUERYPLAN_POINT) || - (!pCur->point_data)) { - vtab_set_error( - cur->pVtab, - "Internal sqlite-vec error: exepcted point query plan in vec0Rowid"); + switch (pCur->query_plan) { + case SQLITE_VEC0_QUERYPLAN_FULLSCAN: { + *pRowid = sqlite3_column_int64(pCur->fullscan_data->rowids_stmt, 0); + return SQLITE_OK; + } + case SQLITE_VEC0_QUERYPLAN_POINT: { + *pRowid = pCur->point_data->rowid; + return SQLITE_OK; + } + case SQLITE_VEC0_QUERYPLAN_KNN: { + vtab_set_error(cur->pVtab, + "Internal sqlite-vec error: expected point query plan in " + "vec0Rowid, found %d", + pCur->query_plan); return SQLITE_ERROR; } - *pRowid = pCur->point_data->rowid; - return SQLITE_OK; + } } static int vec0Next(sqlite3_vtab_cursor *cur) { @@ -4560,8 +4830,9 @@ static int vec0Eof(sqlite3_vtab_cursor *cur) { if (!pCur->knn_data) { return 1; } - return (pCur->knn_data->current_idx >= pCur->knn_data->k) || - (pCur->knn_data->distances[pCur->knn_data->current_idx] == FLT_MAX); + // return (pCur->knn_data->current_idx >= pCur->knn_data->k) || + // (pCur->knn_data->distances[pCur->knn_data->current_idx] == FLT_MAX); + return (pCur->knn_data->current_idx >= pCur->knn_data->k_used); } case SQLITE_VEC0_QUERYPLAN_POINT: { if (!pCur->point_data) { @@ -4653,11 +4924,16 @@ static int vec0Column_knn(vec0_vtab *pVtab, vec0_cursor *pCur, if (vec0_column_idx_is_vector(pVtab, i)) { void *out; int sz; + int vector_idx = vec0_column_idx_to_vector_idx(pVtab, i); int rc = vec0_get_vector_data( - pVtab, pCur->knn_data->rowids[pCur->knn_data->current_idx], - vec0_column_idx_to_vector_idx(pVtab, i), &out, &sz); - todo_assert(rc == SQLITE_OK); + pVtab, pCur->knn_data->rowids[pCur->knn_data->current_idx], vector_idx, + &out, &sz); + if (rc != SQLITE_OK) { + return rc; + } sqlite3_result_blob(context, out, sz, sqlite3_free); + sqlite3_result_subtype(context, + pVtab->vector_columns[vector_idx].element_type); return SQLITE_OK; } @@ -4900,7 +5176,7 @@ int vec0Update_InsertNextAvailableStep( "validity blob size mismatch on " "%s.%s.%lld, expected %lld but received %lld.", p->schemaName, p->shadowChunksName, *chunk_rowid, - (i64) (p->chunk_size / CHAR_BIT), validitySize); + (i64)(p->chunk_size / CHAR_BIT), validitySize); rc = SQLITE_ERROR; goto cleanup; } @@ -5622,7 +5898,6 @@ static sqlite3_module vec0Module = { }; #pragma endregion -#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL static char *POINTER_NAME_STATIC_BLOB_DEF = "vec0-static_blob_def"; struct static_blob_definition { void *p; @@ -5632,6 +5907,8 @@ struct static_blob_definition { }; static void vec_static_blob_from_raw(sqlite3_context *context, int argc, sqlite3_value **argv) { + + assert(argc == 4); struct static_blob_definition *p; p = sqlite3_malloc(sizeof(*p)); if (!p) { @@ -5639,7 +5916,7 @@ static void vec_static_blob_from_raw(sqlite3_context *context, int argc, return; } memset(p, 0, sizeof(*p)); - p->p = sqlite3_value_int64(argv[0]); + p->p = (void *)sqlite3_value_int64(argv[0]); p->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32; p->dimensions = sqlite3_value_int64(argv[2]); p->nvectors = sqlite3_value_int64(argv[3]); @@ -5679,6 +5956,10 @@ struct vec_static_blobs_cursor { static int vec_static_blobsConnect(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_vtab **ppVtab, char **pzErr) { + UNUSED_PARAMETER(argc); + UNUSED_PARAMETER(argv); + UNUSED_PARAMETER(pzErr); + vec_static_blobs_vtab *pNew; #define VEC_STATIC_BLOBS_NAME 0 #define VEC_STATIC_BLOBS_DATA 1 @@ -5705,6 +5986,7 @@ static int vec_static_blobsDisconnect(sqlite3_vtab *pVtab) { static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, sqlite_int64 *pRowid) { + UNUSED_PARAMETER(pRowid); vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pVTab; // DELETE operation if (argc == 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) { @@ -5712,7 +5994,8 @@ static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc, } // INSERT operation else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) { - const char *key = sqlite3_value_text(argv[2 + VEC_STATIC_BLOBS_NAME]); + const char *key = + (const char *)sqlite3_value_text(argv[2 + VEC_STATIC_BLOBS_NAME]); int idx = -1; for (int i = 0; i < MAX_STATIC_BLOBS; i++) { if (!p->data->static_blobs[i].name) { @@ -5741,6 +6024,7 @@ static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc, static int vec_static_blobsOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { + UNUSED_PARAMETER(p); vec_static_blobs_cursor *pCur; pCur = sqlite3_malloc(sizeof(*pCur)); if (pCur == 0) @@ -5758,6 +6042,7 @@ static int vec_static_blobsClose(sqlite3_vtab_cursor *cur) { static int vec_static_blobsBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { + UNUSED_PARAMETER(pVTab); pIdxInfo->idxNum = 1; pIdxInfo->estimatedCost = (double)10; pIdxInfo->estimatedRows = 10; @@ -5768,6 +6053,10 @@ static int vec_static_blobsNext(sqlite3_vtab_cursor *cur); static int vec_static_blobsFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { + UNUSED_PARAMETER(idxNum); + UNUSED_PARAMETER(idxStr); + UNUSED_PARAMETER(argc); + UNUSED_PARAMETER(argv); vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)pVtabCursor; pCur->iRowid = -1; vec_static_blobsNext(pVtabCursor); @@ -5846,7 +6135,11 @@ static sqlite3_module vec_static_blobsModule = { /* xSavepoint */ 0, /* xRelease */ 0, /* xRollbackTo */ 0, - /* xShadowName */ 0}; + /* xShadowName */ 0, +#if SQLITE_VERSION_NUMBER >= 3044000 + /* xIntegrity */ 0 +#endif +}; #pragma endregion #pragma region vec_static_blob_entries() table function @@ -5872,6 +6165,9 @@ struct vec_static_blob_entries_cursor { static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_vtab **ppVtab, char **pzErr) { + UNUSED_PARAMETER(argc); + UNUSED_PARAMETER(argv); + UNUSED_PARAMETER(pzErr); vec_static_blob_data *blob_data = pAux; int idx = -1; for (int i = 0; i < MAX_STATIC_BLOBS; i++) { @@ -5905,7 +6201,7 @@ static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc, static int vec_static_blob_entriesCreate(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_vtab **ppVtab, char **pzErr) { - vec_static_blob_entriesConnect(db, pAux, argc, argv, ppVtab, pzErr); + return vec_static_blob_entriesConnect(db, pAux, argc, argv, ppVtab, pzErr); } static int vec_static_blob_entriesDisconnect(sqlite3_vtab *pVtab) { @@ -5916,6 +6212,7 @@ static int vec_static_blob_entriesDisconnect(sqlite3_vtab *pVtab) { static int vec_static_blob_entriesOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { + UNUSED_PARAMETER(p); vec_static_blob_entries_cursor *pCur; pCur = sqlite3_malloc(sizeof(*pCur)); if (pCur == 0) @@ -5936,7 +6233,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVTab; int iMatchTerm = -1; int iLimitTerm = -1; - int iRowidTerm = -1; // TODO point query + // int iRowidTerm = -1; // TODO point query int iKTerm = -1; for (int i = 0; i < pIdxInfo->nConstraint; i++) { @@ -5970,27 +6267,28 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, return SQLITE_ERROR; // limit or k, not both } if (pIdxInfo->nOrderBy < 1) { - SET_VTAB_ERROR("ORDER BY distance required"); + vtab_set_error(pVTab, "ORDER BY distance required"); return SQLITE_CONSTRAINT; } if (pIdxInfo->nOrderBy > 1) { // TODO error, orderByConsumed is all or nothing, only 1 order by allowed - SET_VTAB_ERROR("more than 1 ORDER BY clause provided"); + vtab_set_error(pVTab, "more than 1 ORDER BY clause provided"); return SQLITE_CONSTRAINT; } if (pIdxInfo->aOrderBy[0].iColumn != VEC_STATIC_BLOB_ENTRIES_DISTANCE) { - SET_VTAB_ERROR("ORDER BY must be on the distance column"); + vtab_set_error(pVTab, "ORDER BY must be on the distance column"); return SQLITE_CONSTRAINT; } if (pIdxInfo->aOrderBy[0].desc) { - SET_VTAB_ERROR("Only ascending in ORDER BY distance clause is supported, " + vtab_set_error(pVTab, + "Only ascending in ORDER BY distance clause is supported, " "DESC is not supported yet."); return SQLITE_CONSTRAINT; } pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_KNN; - pIdxInfo->estimatedCost = (double)10; // TODO vtab_value(?) as hint? - pIdxInfo->estimatedRows = 10; // TODO vtab_value(?) as hint? + pIdxInfo->estimatedCost = (double)10; + pIdxInfo->estimatedRows = 10; pIdxInfo->orderByConsumed = 1; pIdxInfo->aConstraintUsage[iMatchTerm].argvIndex = 1; @@ -6014,6 +6312,8 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { + UNUSED_PARAMETER(idxStr); + assert(argc >= 0 && argc <= 3); vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)pVtabCursor; vec_static_blob_entries_vtab *p = @@ -6021,12 +6321,12 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, if (idxNum == VEC_SBE__QUERYPLAN_KNN) { pCur->query_plan = VEC_SBE__QUERYPLAN_KNN; - struct vec0_query_knn_data *knn_data = - sqlite3_malloc(sizeof(struct vec0_query_knn_data)); + struct vec0_query_knn_data *knn_data; + knn_data = sqlite3_malloc(sizeof(*knn_data)); if (!knn_data) { return SQLITE_NOMEM; } - memset(knn_data, 0, sizeof(struct vec0_query_knn_data)); + memset(knn_data, 0, sizeof(*knn_data)); void *queryVector; size_t dimensions; @@ -6035,30 +6335,45 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, char *err; int rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, &cleanup, &err); - todo_assert(elementType == p->blob->element_type); - todo_assert(dimensions == p->blob->dimensions); + if (rc != SQLITE_OK) { + return SQLITE_ERROR; + } + if (elementType != p->blob->element_type) { + return SQLITE_ERROR; + } + if (dimensions != p->blob->dimensions) { + return SQLITE_ERROR; + } -#define min(a, b) (((a) < (b)) ? (a) : (b)) - - i64 k = min(sqlite3_value_int64(argv[1]), p->blob->nvectors); - todo_assert(k >= 0); + i64 k = min(sqlite3_value_int64(argv[1]), (i64)p->blob->nvectors); + if (k < 0) { + // TODO handle + return SQLITE_ERROR; + } if (k == 0) { knn_data->k = 0; pCur->knn_data = knn_data; return SQLITE_OK; } - i32 *topk_rowids = sqlite3_malloc(k * sizeof(i32)); - todo_assert(topk_rowids); + i64 *topk_rowids = sqlite3_malloc(k * sizeof(i64)); + if (!topk_rowids) { + // TODO handle + return SQLITE_ERROR; + } f32 *distances = sqlite3_malloc(p->blob->nvectors * sizeof(f32)); - todo_assert(distances); + if (!distances) { + // TODO handle + return SQLITE_ERROR; + } for (size_t i = 0; i < p->blob->nvectors; i++) { float *v = ((float *)p->blob->p) + (i * p->blob->dimensions); distances[i] = distance_l2_sqr_float(v, (float *)queryVector, &p->blob->dimensions); + // TODO other metrics } - min_idx(distances, k, topk_rowids, k); + // min_idx(distances, k, topk_rowids, k); // TODO knn_data->current_idx = 0; knn_data->distances = distances; knn_data->k = k; @@ -6146,7 +6461,7 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, static sqlite3_module vec_static_blob_entriesModule = { /* iVersion */ 3, - /* xCreate */ vec_static_blob_entriesCreate, + /* xCreate */ vec_static_blob_entriesCreate, // TODO rm? /* xConnect */ vec_static_blob_entriesConnect, /* xBestIndex */ vec_static_blob_entriesBestIndex, /* xDisconnect */ vec_static_blob_entriesDisconnect, @@ -6168,9 +6483,12 @@ static sqlite3_module vec_static_blob_entriesModule = { /* xSavepoint */ 0, /* xRelease */ 0, /* xRollbackTo */ 0, - /* xShadowName */ 0}; -#pragma endregion + /* xShadowName */ 0, +#if SQLITE_VERSION_NUMBER >= 3044000 + /* xIntegrity */ 0 #endif +}; +#pragma endregion int sqlite3_mmap_warm(sqlite3 *db, const char *zDb) { int rc = SQLITE_OK; @@ -6288,18 +6606,29 @@ __declspec(dllexport) #ifdef _WIN32 __declspec(dllexport) #endif -int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, - const sqlite3_api_routines *pApi) { + int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi) { SQLITE_EXTENSION_INIT2(pApi); int rc = SQLITE_OK; - #define DEFAULT_FLAGS (SQLITE_UTF8 | SQLITE_INNOCUOUS | SQLITE_DETERMINISTIC) + vec_static_blob_data *static_blob_data; + static_blob_data = sqlite3_malloc(sizeof(*static_blob_data)); + if (!static_blob_data) { + return SQLITE_NOMEM; + } + memset(static_blob_data, 0, sizeof(*static_blob_data)); - rc = sqlite3_create_function_v2(db, "vec_version", 0, DEFAULT_FLAGS, SQLITE_VEC_VERSION, _static_text_func, NULL, NULL, NULL); - if(rc != SQLITE_OK) { +#define DEFAULT_FLAGS (SQLITE_UTF8 | SQLITE_INNOCUOUS | SQLITE_DETERMINISTIC) + + rc = sqlite3_create_function_v2(db, "vec_version", 0, DEFAULT_FLAGS, + SQLITE_VEC_VERSION, _static_text_func, NULL, + NULL, NULL); + if (rc != SQLITE_OK) { return rc; } - rc = sqlite3_create_function_v2(db, "vec_debug", 0, DEFAULT_FLAGS, SQLITE_VEC_DEBUG_STRING, _static_text_func, NULL, NULL, NULL); - if(rc != SQLITE_OK) { + rc = sqlite3_create_function_v2(db, "vec_debug", 0, DEFAULT_FLAGS, + SQLITE_VEC_DEBUG_STRING, _static_text_func, + NULL, NULL, NULL); + if (rc != SQLITE_OK) { return rc; } static struct { @@ -6326,21 +6655,10 @@ int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, {"vec_quantize_i8", vec_quantize_i8, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, }, {"vec_quantize_i8", vec_quantize_i8, 3, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, }, {"vec_quantize_binary", vec_quantize_binary, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, }, - #ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL - {"vec_static_blob_from_raw", vec_static_blob_from_raw, 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, - #endif + {"vec_static_blob_from_raw", vec_static_blob_from_raw, 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE }, // clang-format on }; -#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL - vec_static_blob_data *static_blob_data; - static_blob_data = sqlite3_malloc(sizeof(*static_blob_data)); - if (!static_blob_data) { - return SQLITE_NOMEM; - } - memset(static_blob_data, 0, sizeof(*static_blob_data)); -#endif - static struct { char *name; const sqlite3_module *module; @@ -6354,11 +6672,10 @@ int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, // clang-format on }; - for (unsigned long i = 0; - i < sizeof(aFunc) / sizeof(aFunc[0]) && rc == SQLITE_OK; i++) { + for (unsigned long i = 0; i < countof(aFunc) && rc == SQLITE_OK; i++) { rc = sqlite3_create_function_v2(db, aFunc[i].zFName, aFunc[i].nArg, - aFunc[i].flags, NULL, aFunc[i].xFunc, - NULL, NULL, NULL); + aFunc[i].flags, NULL, aFunc[i].xFunc, NULL, + NULL, NULL); if (rc != SQLITE_OK) { *pzErrMsg = sqlite3_mprintf("Error creating function %s: %s", aFunc[i].zFName, sqlite3_errmsg(db)); @@ -6374,7 +6691,6 @@ int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, return rc; } } -#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule, static_blob_data, sqlite3_free); assert(rc == SQLITE_OK); @@ -6382,7 +6698,6 @@ int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, &vec_static_blob_entriesModule, static_blob_data, NULL); assert(rc == SQLITE_OK); -#endif return SQLITE_OK; } @@ -6390,8 +6705,8 @@ int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, #ifdef _WIN32 __declspec(dllexport) #endif -int sqlite3_vec_fs_read_init(sqlite3 *db, char **pzErrMsg, - const sqlite3_api_routines *pApi) { + int sqlite3_vec_fs_read_init(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi) { UNUSED_PARAMETER(pzErrMsg); SQLITE_EXTENSION_INIT2(pApi); int rc = SQLITE_OK; diff --git a/tests/correctness/build.py b/tests/correctness/build.py new file mode 100644 index 0000000..dbf710d --- /dev/null +++ b/tests/correctness/build.py @@ -0,0 +1,16 @@ +import numpy as np +import duckdb +db = duckdb.connect(":memory:") + +result = db.execute( +""" + select + -- _id, + -- title, + -- text as contents, + embedding::float[] as embeddings + from "hf://datasets/Supabase/dbpedia-openai-3-large-1M/dbpedia_openai_3_large_00.parquet" +""" +).fetchnumpy()['embeddings'] + +np.save("dbpedia_openai_3_large_00.npy", np.vstack(result)) diff --git a/tests/correctness/test-correctness.py b/tests/correctness/test-correctness.py new file mode 100644 index 0000000..cb01f8f --- /dev/null +++ b/tests/correctness/test-correctness.py @@ -0,0 +1,124 @@ +import numpy as np +import numpy.typing as npt +import time +import tqdm +import pytest + +def cosine_similarity( + vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32], do_norm: bool = True +) -> npt.NDArray[np.float32]: + sim = vec @ mat.T + if do_norm: + sim /= np.linalg.norm(vec) * np.linalg.norm(mat, axis=1) + return sim + +def distance_l2( + vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32] +) -> npt.NDArray[np.float32]: + return np.sqrt(np.sum((mat - vec) ** 2, axis=1)) + + +def topk( + vec: npt.NDArray[np.float32], + mat: npt.NDArray[np.float32], + k: int = 5, +) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.float32]]: + distances = distance_l2(vec, mat) + # Rather than sorting all similarities and taking the top K, it's faster to + # argpartition and then just sort the top K. + # The difference is O(N logN) vs O(N + k logk) + indices = np.argpartition(distances, kth=k)[:k] + top_indices = indices[np.argsort(distances[indices])] + return top_indices, distances[top_indices] + + + +vec = np.array([1.0, 2.0, 3.0], dtype=np.float32) +mat = np.array([ + [4.0, 5.0, 6.0], + [1.0, 2.0, 1.0], + [7.0, 8.0, 9.0] +], dtype=np.float32) +indices, distances = topk(vec, mat, k=2) +print(indices) +print(distances) + +import sqlite3 +import json +db = sqlite3.connect(":memory:") +db.enable_load_extension(True) +db.load_extension("../../dist/vec0") +db.execute("select load_extension('../../dist/vec0', 'sqlite3_vec_fs_read_init')") +db.enable_load_extension(False) + +results = db.execute( + ''' + select + key, + --value, + vec_distance_l2(:q, value) as distance + from json_each(:base) + order by distance + limit 2 + ''', + { + 'base': json.dumps(mat.tolist()), + 'q': '[1.0, 2.0, 3.0]' + }).fetchall() +a = [row[0] for row in results] +b = [row[1] for row in results] +print(a) +print(b) + + +#import sys; sys.exit() + +db.execute('PRAGMA page_size=16384') + +print("Loading into sqlite-vec vec0 table...") +t0 = time.time() +db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)") +db.execute('insert into v select rowid, vector from vec_npy_each(vec_npy_file("dbpedia_openai_3_large_00.npy"))') +print(time.time() - t0) + +print("loading numpy array...") +t0 = time.time() +base = np.load('dbpedia_openai_3_large_00.npy') +print(time.time() - t0) + +np.random.seed(1) +queries = base[np.random.choice(base.shape[0], 20, replace=False), :] + +np_durations = [] +vec_durations = [] +from random import randrange + +def test_all(): + for idx, query in tqdm.tqdm(enumerate(queries)): + #k = randrange(20, 1000) + #k = 500 + k = 10 + + t0 = time.time() + np_ids, np_distances = topk(query, base, k=k) + np_durations.append(time.time() - t0) + + t0 = time.time() + rows = db.execute('select rowid, distance from v where a match ? and k = ?', [query, k]).fetchall() + vec_durations.append(time.time() - t0) + + vec_ids = [row[0] for row in rows] + vec_distances = [row[1] for row in rows] + + assert vec_distances == np_distances.tolist() + #assert vec_ids == np_ids.tolist() + #if (vec_ids != np_ids).any(): + # print('idx', idx) + # print('query', query) + # print('np_ids', np_ids) + # print('np_distances', np_distances) + # print('vec_ids', vec_ids) + # print('vec_distances', vec_distances) + # raise Exception(idx) + + print('final', 'np' ,np.mean(np_durations), 'vec', np.mean(vec_durations)) diff --git a/tests/leak-fixtures/each.sql b/tests/leak-fixtures/each.sql new file mode 100644 index 0000000..c300ca1 --- /dev/null +++ b/tests/leak-fixtures/each.sql @@ -0,0 +1,16 @@ +.load dist/vec0 +.mode box +.header on +.eqp on +.echo on + +select sqlite_version(), vec_version(); + +select * from vec_each('[1,2,3]'); + +select * +from json_each('[ + [1,2,3,4], + [1,2,3,4] +]') +join vec_each(json_each.value); diff --git a/tests/leak-fixtures/knn.sql b/tests/leak-fixtures/knn.sql new file mode 100644 index 0000000..c0323bf --- /dev/null +++ b/tests/leak-fixtures/knn.sql @@ -0,0 +1,61 @@ +.load dist/vec0 +.mode box +.header on +.eqp on +.echo on + +select sqlite_version(), vec_version(); + +create virtual table v using vec0(a float[1], chunk_size=8); + +insert into v + select value, format('[%f]', value / 100.0) + from generate_series(1, 100); + +select + rowid, + vec_to_json(a) +from v +where a match '[.3]' + and k = 2; + +select + rowid, + vec_to_json(a) +from v +where a match '[.3]' + and k = 0; + + +select + rowid, + vec_to_json(a) +from v +where a match '[2.0]' + and k = 2 + and rowid in (1,2,3,4,5); + + + +with queries as ( + select + rowid as query_id, + json_array(value / 100.0) as value + from generate_series(24, 39) +) +select + query_id, + rowid, + distance, + vec_to_json(a) +from queries, v +where a match queries.value + and k =5; + + +select * +from v +where rowid in (1,2,3,4); + +drop table v; + diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 9a98913..293c667 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -113,11 +113,33 @@ FUNCTIONS = [ "vec_quantize_i8", "vec_quantize_i8", "vec_slice", + "vec_static_blob_from_raw", "vec_sub", "vec_to_json", "vec_version", ] -MODULES = ["vec0", "vec_each", "vec_npy_each"] +MODULES = [ + "vec0", + "vec_each", + "vec_npy_each", + "vec_static_blob_entries", + "vec_static_blobs", +] + + +@pytest.mark.skip(reason="TODO") +def test_vec_static_blob_from_raw(): + pass + + +@pytest.mark.skip(reason="TODO") +def test_vec_static_blobs(): + pass + + +@pytest.mark.skip(reason="TODO") +def test_vec_static_blob_entries(): + pass def test_funcs(): @@ -420,6 +442,11 @@ def test_vec_slice(): ): vec_slice(b"\xab\xab\xab\xab", 1, 0) + with _raises( + "slice 'start' index is equal to the 'end' index, vectors must have non-zero length" + ): + vec_slice(b"\xab\xab\xab\xab", 0, 0) + def test_vec_add(): vec_add = lambda *args, a="?", b="?": db.execute( @@ -775,6 +802,7 @@ def test_vec0_drops(): "t1_vector_chunks00", "t1_vector_chunks01", ] + db.execute("drop table t1") assert [ row["name"] @@ -1175,7 +1203,8 @@ def test_vec0_text_pk(): create virtual table t using vec0( t_id text primary key, aaa float[1], - bbb float8[1] + bbb float8[1], + chunk_size=8 ); """ ) @@ -1226,6 +1255,39 @@ def test_vec0_text_pk(): db.execute("select * from t") db.set_authorizer(None) + assert execute_all( + db, "select t_id, distance from t where aaa match ? and k = 3", ["[.01]"] + ) == [ + { + "t_id": "t_1", + "distance": 0.09000000357627869, + }, + { + "t_id": "t_2", + "distance": 0.1899999976158142, + }, + { + "t_id": "t_3", + "distance": 0.2900000214576721, + }, + ] + + if SUPPORTS_VTAB_IN: + assert execute_all( + db, + "select t_id, distance from t where aaa match ? and k = 3 and t_id in ('t_2', 't_3')", + ["[.01]"], + ) == [ + { + "t_id": "t_2", + "distance": 0.1899999976158142, + }, + { + "t_id": "t_3", + "distance": 0.2900000214576721, + }, + ] + def test_vec0_best_index(): db = connect(EXT_PATH) @@ -1251,15 +1313,15 @@ def test_vec0_best_index(): db.execute("select * from t where aaa MATCH ?") if SUPPORTS_VTAB_LIMIT: - with _raises("Only LIMIT or 'k =?' can be provided, not both"): - db.execute("select * from t where aaa MATCH ? and k = 10 limit 20") + with _raises("Only LIMIT or 'k =?' can be provided, not both"): + db.execute("select * from t where aaa MATCH ? and k = 10 limit 20") - with _raises( - "Only a single 'ORDER BY distance' clause is allowed on vec0 KNN queries" - ): - db.execute( - "select * from t where aaa MATCH NULL and k = 10 order by distance, distance" - ) + with _raises( + "Only a single 'ORDER BY distance' clause is allowed on vec0 KNN queries" + ): + db.execute( + "select * from t where aaa MATCH NULL and k = 10 order by distance, distance" + ) with _raises( "Only ascending in ORDER BY distance clause is supported, DESC is not supported yet." @@ -1679,6 +1741,214 @@ def test_vec0_create_errors(): db.set_authorizer(None) +def test_vec0_knn(): + db = connect(EXT_PATH) + db.execute( + """ + create virtual table v using vec0( + aaa float[8], + bbb int8[8], + ccc bit[8], + chunk_size=8 + ); + """ + ) + + with _raises( + 'Query vector on the "aaa" column is invalid: Input must have type BLOB (compact format) or TEXT (JSON), found NULL' + ): + db.execute("select * from v where aaa match NULL and k = 10") + + with _raises( + 'Query vector for the "aaa" column is expected to be of type float32, but a bit vector was provided.' + ): + db.execute("select * from v where aaa match vec_bit(X'AA') and k = 10") + + with _raises( + 'Dimension mismatch for inserted vector for the "aaa" column. Expected 8 dimensions but received 1.' + ): + db.execute("select * from v where aaa match vec_f32('[.1]') and k = 10") + + qaaa = json.dumps([0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]) + with _raises("k value in knn queries must be greater than or equal to 0."): + db.execute("select * from v where aaa match vec_f32(?) and k = -1", [qaaa]) + + assert ( + execute_all(db, "select * from v where aaa match vec_f32(?) and k = 0", [qaaa]) + == [] + ) + + # EVIDENCE-OF: V06942_23781 + db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_READ, "v_chunks", "chunk_id")) + with _raises( + "Error preparing stmtChunk: access to v_chunks.chunk_id is prohibited", + sqlite3.DatabaseError, + ): + db.execute("select * from v where aaa match vec_f32(?) and k = 5", [qaaa]) + db.set_authorizer(None) + + assert ( + execute_all(db, "select * from v where aaa match vec_f32(?) and k = 5", [qaaa]) + == [] + ) + + db.executemany( + """ + INSERT INTO v VALUES + (:id, :vector, vec_quantize_i8(:vector, 'unit') ,vec_quantize_binary(:vector)); + """, + [ + { + "id": i, + "vector": json.dumps( + [ + i * 0.01, + i * 0.01, + i * 0.01, + i * 0.01, + i * 0.01, + i * 0.01, + i * 0.01, + i * 0.01, + ] + ), + } + for i in range(24) + ], + ) + + assert execute_all( + db, "select rowid from v where aaa match vec_f32(?) and k = 9", [qaaa] + ) == [ + {"rowid": 1}, + {"rowid": 2}, # ordering of 2 and 0 here depends on if min_idx uses < or <= + {"rowid": 0}, # + {"rowid": 3}, + {"rowid": 4}, + {"rowid": 5}, + {"rowid": 6}, + {"rowid": 7}, + {"rowid": 8}, + ] + # TODO separate test, DELETE FROM WHERE rowid in (...) is fullscan that calls vec0Rowid. try on text PKs + db.execute("delete from v where rowid in (1, 0, 8, 9)") + assert execute_all( + db, "select rowid from v where aaa match vec_f32(?) and k = 9", [qaaa] + ) == [ + {"rowid": 2}, + {"rowid": 3}, + {"rowid": 4}, + {"rowid": 5}, + {"rowid": 6}, + {"rowid": 7}, + {"rowid": 10}, + {"rowid": 11}, + {"rowid": 12}, + ] + + # EVIDENCE-OF: V05271_22109 vec0 knn validates chunk size + db.commit() + db.execute("BEGIN") + db.execute("update v_chunks set validity = zeroblob(100)") + with _raises("chunk validity size doesn't match - expected 1, found 100"): + db.execute("select * from v where aaa match ? and k = 2", [qaaa]) + db.rollback() + + # EVIDENCE-OF: V02796_19635 vec0 knn validates rowids size + db.commit() + db.execute("BEGIN") + db.execute("update v_chunks set rowids = zeroblob(100)") + with _raises("chunk rowids size doesn't match - expected 64, found 100"): + db.execute("select * from v where aaa match ? and k = 2", [qaaa]) + db.rollback() + + # EVIDENCE-OF: V16465_00535 vec0 knn validates vector chunk size + db.commit() + db.execute("BEGIN") + db.execute("update v_vector_chunks00 set vectors = zeroblob(100)") + with _raises("vectors blob size doesn't match - expected 256, found 100"): + db.execute("select * from v where aaa match ? and k = 2", [qaaa]) + db.rollback() + + +import numpy.typing as npt + + +def np_distance_l2( + vec: npt.NDArray[np.float32], mat: npt.NDArray[np.float32] +) -> npt.NDArray[np.float32]: + return np.sqrt(np.sum((mat - vec) ** 2, axis=1)) + + +def np_topk( + vec: npt.NDArray[np.float32], + mat: npt.NDArray[np.float32], + k: int = 5, +) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.float32]]: + distances = np_distance_l2(vec, mat) + # Rather than sorting all similarities and taking the top K, it's faster to + # argpartition and then just sort the top K. + # The difference is O(N logN) vs O(N + k logk) + indices = np.argpartition(distances, kth=k)[:k] + top_indices = indices[np.argsort(distances[indices])] + return top_indices, distances[top_indices] + + +# import faiss +@pytest.mark.skip(reason="TODO") +def test_correctness_npy(): + db = connect(EXT_PATH) + np.random.seed(420 + 1 + 2) + mat = np.random.uniform(low=-1.0, high=1.0, size=(10000, 24)).astype(np.float32) + queries = np.random.uniform(low=-1.0, high=1.0, size=(1000, 24)).astype(np.float32) + + # sqlite-vec with vec0 + db.execute("create virtual table v using vec0(a float[24], chunk_size=8)") + for v in mat: + db.execute("insert into v(a) values (?)", [v]) + + # sqlite-vec with scalar functions + db.execute("create table t(a float[24])") + for v in mat: + db.execute("insert into t(a) values (?)", [v]) + + faiss_index = faiss.IndexFlatL2(24) + faiss_index.add(mat) + + k = 10000 - 1 + for idx, q in enumerate(queries): + print(idx) + result = execute_all( + db, + "select rowid - 1 as idx, distance from v where a match ? and k = ?", + [q, k], + ) + vec_vtab_rowids = [row["idx"] for row in result] + vec_vtab_distances = [row["distance"] for row in result] + + result = execute_all( + db, + "select rowid - 1 as idx, vec_distance_l2(a, ?) as distance from t order by 2 limit ?", + [q, k], + ) + vec_scalar_rowids = [row["idx"] for row in result] + vec_scalar_distances = [row["distance"] for row in result] + assert vec_scalar_rowids == vec_vtab_rowids + assert vec_scalar_distances == vec_vtab_distances + + faiss_distances, faiss_rowids = faiss_index.search(np.array([q]), k) + faiss_distances = np.sqrt(faiss_distances) + assert faiss_rowids[0].tolist() == vec_scalar_rowids + assert faiss_distances[0].tolist() == vec_scalar_distances + + assert faiss_distances[0].tolist() == vec_vtab_distances + assert faiss_rowids[0].tolist() == vec_vtab_rowids + + np_rowids, np_distances = np_topk(mat, q, k=k) + # assert vec_vtab_rowids == np_rowids.tolist() + # assert vec_vtab_distances == np_distances.tolist() + + def test_smoke(): db.execute("create virtual table vec_xyz using vec0( a float[2] )") assert execute_all( @@ -1833,20 +2103,15 @@ def test_vec0_stress_small_chunks(): "distance": 0.0, "rowid": 500, }, - { - "a": _f32([499 * 0.1] * 8), - "distance": 0.2828384041786194, - "rowid": 499, - }, { "a": _f32([501 * 0.1] * 8), "distance": 0.2828384041786194, "rowid": 501, }, { - "a": _f32([498 * 0.1] * 8), - "distance": 0.5656875967979431, - "rowid": 498, + "a": _f32([499 * 0.1] * 8), + "distance": 0.2828384041786194, + "rowid": 499, }, { "a": _f32([502 * 0.1] * 8), @@ -1854,15 +2119,20 @@ def test_vec0_stress_small_chunks(): "rowid": 502, }, { - "a": _f32([497 * 0.1] * 8), - "distance": 0.8485260009765625, - "rowid": 497, + "a": _f32([498 * 0.1] * 8), + "distance": 0.5656875967979431, + "rowid": 498, }, { "a": _f32([503 * 0.1] * 8), "distance": 0.8485260009765625, "rowid": 503, }, + { + "a": _f32([497 * 0.1] * 8), + "distance": 0.8485260009765625, + "rowid": 497, + }, { "a": _f32([496 * 0.1] * 8), "distance": 1.1313751935958862, diff --git a/tests/unittest.rs b/tests/unittest.rs index e95c6c3..4506b02 100644 --- a/tests/unittest.rs +++ b/tests/unittest.rs @@ -17,9 +17,55 @@ fn _min_idx(distances: Vec, k: i32) -> Vec { out } +fn _merge_sorted_lists( + a: &Vec, + a_rowids: &Vec, + b: &Vec, + b_rowids: &Vec, + b_top_idx: &Vec, + n: usize, +) -> (Vec, Vec) { + let mut out_used: i64 = 0; + let mut out: Vec = Vec::with_capacity(n); + let mut out_rowids: Vec = Vec::with_capacity(n); + unsafe { + merge_sorted_lists( + a.as_ptr().cast(), + a_rowids.as_ptr().cast(), + a.len() as i64, + b.as_ptr().cast(), + b_rowids.as_ptr().cast(), + b_top_idx.as_ptr().cast(), + b.len() as i64, + out.as_ptr().cast(), + out_rowids.as_ptr().cast(), + n as i64, + &mut out_used, + ); + out.set_len(out_used as usize); + out_rowids.set_len(out_used as usize); + } + + (out_rowids, out) +} + #[link(name = "sqlite-vec-internal")] extern "C" { fn min_idx(distances: *const f32, n: i32, out: *mut i32, k: i32) -> i32; + + fn merge_sorted_lists( + a: *const f32, + a_rowids: *const i64, + a_length: i64, + b: *const f32, + b_rowids: *const i64, + b_top_idx: *const i32, + b_length: i64, + out: *const f32, + out_rowids: *const i64, + out_length: i64, + out_used: *mut i64, + ); } #[cfg(test)] @@ -34,4 +80,85 @@ mod tests { assert_eq!(_min_idx(vec![1.0, 2.0, 3.0], 2), vec![0, 1]); assert_eq!(_min_idx(vec![3.0, 2.0, 1.0], 2), vec![2, 1]); } + + #[test] + fn test_merge_sorted_lists() { + let a = &vec![0.01, 0.02, 0.03]; + let a_rowids = &vec![1, 2, 3]; + + //let b = &vec![0.1, 0.2, 0.3, 0.4]; + //let b_rowids = &vec![4, 5, 6, 7]; + let b = &vec![0.4, 0.2, 0.3, 0.1]; + let b_rowids = &vec![7, 5, 6, 4]; + let b_top_idx = &vec![3, 1, 2, 0]; + + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 0), + (vec![], vec![]) + ); + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 1), + (vec![1], vec![0.01]) + ); + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 2), + (vec![1, 2], vec![0.01, 0.02]) + ); + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 3), + (vec![1, 2, 3], vec![0.01, 0.02, 0.03]) + ); + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 4), + (vec![1, 2, 3, 4], vec![0.01, 0.02, 0.03, 0.1]) + ); + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 5), + (vec![1, 2, 3, 4, 5], vec![0.01, 0.02, 0.03, 0.1, 0.2]) + ); + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 6), + ( + vec![1, 2, 3, 4, 5, 6], + vec![0.01, 0.02, 0.03, 0.1, 0.2, 0.3] + ) + ); + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 7), + ( + vec![1, 2, 3, 4, 5, 6, 7], + vec![0.01, 0.02, 0.03, 0.1, 0.2, 0.3, 0.4] + ) + ); + + assert_eq!( + _merge_sorted_lists(a, a_rowids, b, b_rowids, b_top_idx, 8), + ( + vec![1, 2, 3, 4, 5, 6, 7], + vec![0.01, 0.02, 0.03, 0.1, 0.2, 0.3, 0.4] + ) + ); + } + /* + #[test] + fn test_merge_sorted_lists_empty() { + let x = vec![0.1, 0.2, 0.3]; + let x_rowids = vec![666, 888, 777]; + assert_eq!( + _merge_sorted_lists(&x, &x_rowids, &vec![], &vec![], 3), + (vec![666, 888, 777], vec![0.1, 0.2, 0.3]) + ); + assert_eq!( + _merge_sorted_lists(&vec![], &vec![], &x, &x_rowids, 3), + (vec![666, 888, 777], vec![0.1, 0.2, 0.3]) + ); + assert_eq!( + _merge_sorted_lists(&vec![], &vec![], &x, &x_rowids, 4), + (vec![666, 888, 777], vec![0.1, 0.2, 0.3]) + ); + assert_eq!( + _merge_sorted_lists(&vec![], &vec![], &x, &x_rowids, 2), + (vec![666, 888], vec![0.1, 0.2]) + ); + }*/ }