diff --git a/sqlite-vec.c b/sqlite-vec.c index 0dba749..54b0ba4 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -736,7 +736,7 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector, if (source[offset] == ',') { offset++; continue; - } // TODO multiple commas in a row without digits? + } if (source[offset] == ']') goto done; break; @@ -2006,7 +2006,7 @@ int parse_vector_column(const char *source, int source_length, struct VectorColumnDefinition *outColumn) { // parses a vector column definition like so: // "abc float[123]", "abc_123 bit[1234]", eetc. - // TODO: more precise error messages here. + // https://github.com/asg017/sqlite-vec/issues/46 int rc; struct Vec0Scanner scanner; struct Vec0Token token; @@ -2989,7 +2989,7 @@ static int vec_npy_eachColumnBuffer(vec_npy_each_cursor *pCur, context, &((unsigned char *) pCur->vector)[pCur->iRowid * pCur->nDimensions * sizeof(f32)], - pCur->nDimensions * sizeof(f32), SQLITE_STATIC); + pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT); break; } @@ -3361,7 +3361,6 @@ cleanup: int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, i64 rowid, sqlite3_value **out) { // PERF: different strategy than get_chunk_position? - // TODO: test / evidence-of return vec0_get_chunk_position((vec0_vtab *)pVtab, rowid, out, NULL, NULL); } @@ -3461,7 +3460,6 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx, "vectors", chunk_id, 0, &vectorBlob); if (rc != SQLITE_OK) { - // TODO evidence-of vtab_set_error(&pVtab->base, "Could not fetch vector data for %lld, opening blob failed", rowid); @@ -4490,7 +4488,11 @@ void bitmap_and_inplace(u8 *base, u8 *other, i32 n) { } void bitmap_set(u8 *bitmap, i32 position, int value) { - bitmap[position / CHAR_BIT] |= value << (position % CHAR_BIT); + if (value) { + bitmap[position / CHAR_BIT] |= 1 << (position % CHAR_BIT); + } else { + bitmap[position / CHAR_BIT] &= ~(1 << (position % CHAR_BIT)); + } } int bitmap_get(u8 *bitmap, i32 position) { @@ -4502,6 +4504,11 @@ void bitmap_clear(u8 *bitmap, i32 n) { memset(bitmap, 0, n / CHAR_BIT); } +void bitmap_fill(u8 *bitmap, i32 n) { + assert((n % 8) == 0); + memset(bitmap, 0xFF, n / CHAR_BIT); +} + /** * @brief Finds the minimum k items in distances, and writes the indicies to * out. @@ -4885,7 +4892,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, if (dimensions != vector_column->dimensions) { vtab_set_error( &p->base, - "Dimension mismatch for inserted vector for the \"%.*s\" column. " + "Dimension mismatch for query vector for the \"%.*s\" column. " "Expected %d dimensions but received %d.", vector_column->name_length, vector_column->name, vector_column->dimensions, dimensions); @@ -5728,8 +5735,7 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, // a write-able blob of the validity column for the given chunk. Used to mark // validity bit sqlite3_blob *blobChunksValidity = NULL; - // buffer for the valididty column for the given chunk. TODO maybe not needed - // here? + // buffer for the valididty column for the given chunk. Maybe not needed here? const unsigned char *bufferChunksValidity = NULL; int numReadVectors = 0; @@ -5951,9 +5957,9 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite_int64 rowid) { return rc; } - // 3. zero out rowid in chunks.rowids TODO + // 3. zero out rowid in chunks.rowids https://github.com/asg017/sqlite-vec/issues/54 - // 4. zero out any data in vector chunks tables TODO + // 4. zero out any data in vector chunks tables https://github.com/asg017/sqlite-vec/issues/54 // 5. delete from _rowids table rc = vec0Update_Delete_DeleteRowids(p, rowid); @@ -5975,9 +5981,7 @@ int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset, enum VectorElementType elementType; void *vector; vector_cleanup cleanup = vector_cleanup_noop; - // TODO: Can't update non f32, bc subtypes are stripped from UPDATEs. - // Need to 1) create a less strict vector_from_value, or 2) wait - // for this to resolve: https://sqlite.org/forum/forumpost/65317ce9c6 + // https://github.com/asg017/sqlite-vec/issues/53 rc = vector_from_value(valueVector, &vector, &dimensions, &elementType, &cleanup, &pzError); if (rc != SQLITE_OK) { @@ -6449,12 +6453,36 @@ typedef enum { VEC_SBE__QUERYPLAN_KNN = 2 } vec_sbe_query_plan; +struct sbe_query_knn_data { + i64 k; + i64 k_used; + // Array of rowids of size k. Must be freed with sqlite3_free(). + i32 *rowids; + // Array of distances of size k. Must be freed with sqlite3_free(). + f32 *distances; + i64 current_idx; +}; +void sbe_query_knn_data_clear(struct sbe_query_knn_data *knn_data) { + if (!knn_data) + return; + + if (knn_data->rowids) { + sqlite3_free(knn_data->rowids); + knn_data->rowids = NULL; + } + if (knn_data->distances) { + sqlite3_free(knn_data->distances); + knn_data->distances = NULL; + } +} + + typedef struct vec_static_blob_entries_cursor vec_static_blob_entries_cursor; struct vec_static_blob_entries_cursor { sqlite3_vtab_cursor base; sqlite3_int64 iRowid; vec_sbe_query_plan query_plan; - struct vec0_query_knn_data *knn_data; + struct sbe_query_knn_data *knn_data; }; static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc, @@ -6519,6 +6547,7 @@ static int vec_static_blob_entriesOpen(sqlite3_vtab *p, static int vec_static_blob_entriesClose(sqlite3_vtab_cursor *cur) { vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; + sqlite3_free(pCur->knn_data); sqlite3_free(pCur); return SQLITE_OK; } @@ -6528,7 +6557,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVTab; int iMatchTerm = -1; int iLimitTerm = -1; - // int iRowidTerm = -1; // TODO point query + // int iRowidTerm = -1; // https://github.com/asg017/sqlite-vec/issues/47 int iKTerm = -1; for (int i = 0; i < pIdxInfo->nConstraint; i++) { @@ -6540,7 +6569,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, if (op == SQLITE_INDEX_CONSTRAINT_MATCH && iColumn == VEC_STATIC_BLOB_ENTRIES_VECTOR) { if (iMatchTerm > -1) { - // TODO only 1 match operator at a time + // https://github.com/asg017/sqlite-vec/issues/51 return SQLITE_ERROR; } iMatchTerm = i; @@ -6555,7 +6584,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, } if (iMatchTerm >= 0) { if (iLimitTerm < 0 && iKTerm < 0) { - // TODO: error, match on vector1 should require a limit for KNN + // https://github.com/asg017/sqlite-vec/issues/51 return SQLITE_ERROR; } if (iLimitTerm >= 0 && iKTerm >= 0) { @@ -6566,7 +6595,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, return SQLITE_CONSTRAINT; } if (pIdxInfo->nOrderBy > 1) { - // TODO error, orderByConsumed is all or nothing, only 1 order by allowed + // https://github.com/asg017/sqlite-vec/issues/51 vtab_set_error(pVTab, "more than 1 ORDER BY clause provided"); return SQLITE_CONSTRAINT; } @@ -6615,8 +6644,9 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, (vec_static_blob_entries_vtab *)pCur->base.pVtab; if (idxNum == VEC_SBE__QUERYPLAN_KNN) { + assert(argc == 2); pCur->query_plan = VEC_SBE__QUERYPLAN_KNN; - struct vec0_query_knn_data *knn_data; + struct sbe_query_knn_data *knn_data; knn_data = sqlite3_malloc(sizeof(*knn_data)); if (!knn_data) { return SQLITE_NOMEM; @@ -6642,7 +6672,7 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, i64 k = min(sqlite3_value_int64(argv[1]), (i64)p->blob->nvectors); if (k < 0) { - // TODO handle + // HANDLE https://github.com/asg017/sqlite-vec/issues/55 return SQLITE_ERROR; } if (k == 0) { @@ -6651,24 +6681,38 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, return SQLITE_OK; } - i64 *topk_rowids = sqlite3_malloc(k * sizeof(i64)); + size_t bsize = (p->blob->nvectors + 7) & ~7; + + + i32 *topk_rowids = sqlite3_malloc(k * sizeof(i32)); if (!topk_rowids) { - // TODO handle + // HANDLE https://github.com/asg017/sqlite-vec/issues/55 return SQLITE_ERROR; } - f32 *distances = sqlite3_malloc(p->blob->nvectors * sizeof(f32)); + f32 *distances = sqlite3_malloc(bsize * sizeof(f32)); if (!distances) { - // TODO handle + // HANDLE https://github.com/asg017/sqlite-vec/issues/55 return SQLITE_ERROR; } for (size_t i = 0; i < p->blob->nvectors; i++) { + // https://github.com/asg017/sqlite-vec/issues/52 float *v = ((float *)p->blob->p) + (i * p->blob->dimensions); distances[i] = distance_l2_sqr_float(v, (float *)queryVector, &p->blob->dimensions); - // TODO other metrics } - // min_idx(distances, k, topk_rowids, k); // TODO + u8 * candidates = bitmap_new(bsize); + assert(candidates); + + u8 * taken = bitmap_new(bsize); + assert(taken); + + bitmap_fill(candidates, bsize); + for(size_t i = bsize; i >= p->blob->nvectors; i--) { + bitmap_set(candidates, i, 0); + } + i32 k_used = 0; + min_idx(distances, bsize, candidates, topk_rowids, k, taken, &k_used); knn_data->current_idx = 0; knn_data->distances = distances; knn_data->k = k; @@ -6686,8 +6730,18 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, static int vec_static_blob_entriesRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; - *pRowid = pCur->iRowid; - return SQLITE_OK; + switch (pCur->query_plan) { + case VEC_SBE__QUERYPLAN_FULLSCAN: { + *pRowid = pCur->iRowid; + return SQLITE_OK; + } + case VEC_SBE__QUERYPLAN_KNN: { + i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx]; + *pRowid = (sqlite3_int64) rowid; + return SQLITE_OK; + } + } + } static int vec_static_blob_entriesNext(sqlite3_vtab_cursor *cur) { @@ -6732,7 +6786,7 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, context, ((unsigned char *)p->blob->p) + (pCur->iRowid * p->blob->dimensions * sizeof(float)), - p->blob->dimensions * sizeof(float), SQLITE_STATIC); + p->blob->dimensions * sizeof(float), SQLITE_TRANSIENT); sqlite3_result_subtype(context, p->blob->element_type); break; } @@ -6742,11 +6796,10 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, switch (i) { case VEC_STATIC_BLOB_ENTRIES_VECTOR: { i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx]; - sqlite3_result_blob(context, ((unsigned char *)p->blob->p) + (rowid * p->blob->dimensions * sizeof(float)), - p->blob->dimensions * sizeof(float), SQLITE_STATIC); + p->blob->dimensions * sizeof(float), SQLITE_TRANSIENT); sqlite3_result_subtype(context, p->blob->element_type); break; } @@ -6758,7 +6811,7 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, static sqlite3_module vec_static_blob_entriesModule = { /* iVersion */ 3, - /* xCreate */ vec_static_blob_entriesCreate, // TODO rm? + /* xCreate */ vec_static_blob_entriesCreate, // handle rm? https://github.com/asg017/sqlite-vec/issues/55 /* xConnect */ vec_static_blob_entriesConnect, /* xBestIndex */ vec_static_blob_entriesBestIndex, /* xDisconnect */ vec_static_blob_entriesDisconnect, @@ -6787,91 +6840,6 @@ static sqlite3_module vec_static_blob_entriesModule = { }; #pragma endregion -int sqlite3_mmap_warm(sqlite3 *db, const char *zDb) { - int rc = SQLITE_OK; - char *zSql = 0; - int pgsz = 0; - unsigned int nTotal = 0; - - if (0 == sqlite3_get_autocommit(db)) - return SQLITE_MISUSE; - - /* Open a read-only transaction on the file in question */ - zSql = sqlite3_mprintf("BEGIN; SELECT * FROM %s%q%ssqlite_schema", - (zDb ? "'" : ""), (zDb ? zDb : ""), (zDb ? "'." : "")); - if (zSql == 0) - return SQLITE_NOMEM; - rc = sqlite3_exec(db, zSql, 0, 0, 0); - sqlite3_free(zSql); - - /* Find the SQLite page size of the file */ - if (rc == SQLITE_OK) { - zSql = sqlite3_mprintf("PRAGMA %s%q%spage_size", (zDb ? "'" : ""), - (zDb ? zDb : ""), (zDb ? "'." : "")); - if (zSql == 0) { - rc = SQLITE_NOMEM; - } else { - sqlite3_stmt *pPgsz = 0; - rc = sqlite3_prepare_v2(db, zSql, -1, &pPgsz, 0); - sqlite3_free(zSql); - if (rc == SQLITE_OK) { - if (sqlite3_step(pPgsz) == SQLITE_ROW) { - pgsz = sqlite3_column_int(pPgsz, 0); - } - rc = sqlite3_finalize(pPgsz); - } - if (rc == SQLITE_OK && pgsz == 0) { - rc = SQLITE_ERROR; - } - } - } - - /* Touch each mmap'd page of the file */ - if (rc == SQLITE_OK) { - int rc2; - sqlite3_file *pFd = 0; - rc = sqlite3_file_control(db, zDb, SQLITE_FCNTL_FILE_POINTER, &pFd); - if (rc == SQLITE_OK && pFd->pMethods && pFd->pMethods->iVersion >= 3) { - i64 iPg = 1; - sqlite3_io_methods const *p = pFd->pMethods; - while (1) { - unsigned char *pMap; - rc = p->xFetch(pFd, pgsz * iPg, pgsz, (void **)&pMap); - if (rc != SQLITE_OK || pMap == 0) - break; - - nTotal += (unsigned int)pMap[0]; - nTotal += (unsigned int)pMap[pgsz - 1]; - - rc = p->xUnfetch(pFd, pgsz * iPg, (void *)pMap); - if (rc != SQLITE_OK) - break; - iPg++; - } - sqlite3_log(SQLITE_OK, - "sqlite3_mmap_warm_cache: Warmed up %d pages of %s", - iPg == 1 ? 0 : iPg, sqlite3_db_filename(db, zDb)); - } - - rc2 = sqlite3_exec(db, "END", 0, 0, 0); - if (rc == SQLITE_OK) - rc = rc2; - } - - (void)nTotal; - return rc; -} - -#ifdef _WIN32 -__declspec(dllexport) -#endif - int sqlite3_vec_warm_mmap(sqlite3 *db, char **pzErrMsg, - const sqlite3_api_routines *pApi) { - UNUSED_PARAMETER(pzErrMsg); - SQLITE_EXTENSION_INIT2(pApi); - return sqlite3_mmap_warm(db, NULL); -} - #ifdef SQLITE_VEC_ENABLE_AVX #define SQLITE_VEC_DEBUG_BUILD_AVX "avx" #else @@ -6907,12 +6875,6 @@ __declspec(dllexport) const sqlite3_api_routines *pApi) { SQLITE_EXTENSION_INIT2(pApi); int rc = SQLITE_OK; - vec_static_blob_data *static_blob_data; - static_blob_data = sqlite3_malloc(sizeof(*static_blob_data)); - if (!static_blob_data) { - return SQLITE_NOMEM; - } - memset(static_blob_data, 0, sizeof(*static_blob_data)); #define DEFAULT_FLAGS (SQLITE_UTF8 | SQLITE_INNOCUOUS | SQLITE_DETERMINISTIC) @@ -6953,7 +6915,6 @@ __declspec(dllexport) {"vec_int8", vec_int8, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, }, {"vec_quantize_int8", vec_quantize_int8, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, }, {"vec_quantize_binary", vec_quantize_binary, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, }, - {"vec_static_blob_from_raw", vec_static_blob_from_raw, 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE }, // clang-format on }; @@ -6989,13 +6950,7 @@ __declspec(dllexport) return rc; } } - rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule, - static_blob_data, sqlite3_free); - assert(rc == SQLITE_OK); - rc = sqlite3_create_module_v2(db, "vec_static_blob_entries", - &vec_static_blob_entriesModule, - static_blob_data, NULL); - assert(rc == SQLITE_OK); + return SQLITE_OK; } @@ -7014,23 +6969,33 @@ __declspec(dllexport) } #endif -#ifdef SQLITE_VEC_ENABLE_TRACE_ENTRYPOINT -int trace(unsigned int x, void *p1, void *p2, void *p3) { - if (x == SQLITE_TRACE_STMT) { - sqlite3_stmt *stmt = (sqlite3_stmt *)p2; - char *zSql = sqlite3_expanded_sql(stmt); - printf("%s\n", zSql); - } -} #ifdef _WIN32 __declspec(dllexport) #endif - int trace_debug(sqlite3 *db, char **pzErrMsg, - const sqlite3_api_routines *pApi) { + int sqlite3_vec_static_blobs_init(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi) { UNUSED_PARAMETER(pzErrMsg); SQLITE_EXTENSION_INIT2(pApi); - sqlite3_trace_v2(db, SQLITE_TRACE_STMT, trace, NULL); - return SQLITE_OK; + + int rc = SQLITE_OK; + vec_static_blob_data *static_blob_data; + static_blob_data = sqlite3_malloc(sizeof(*static_blob_data)); + if (!static_blob_data) { + return SQLITE_NOMEM; + } + memset(static_blob_data, 0, sizeof(*static_blob_data)); + + rc = sqlite3_create_function_v2(db, "vec_static_blob_from_raw", 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, + NULL, vec_static_blob_from_raw, NULL, NULL, NULL); + if(rc != SQLITE_OK) return rc; + + rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule, + static_blob_data, sqlite3_free); + if(rc != SQLITE_OK) return rc; + rc = sqlite3_create_module_v2(db, "vec_static_blob_entries", + &vec_static_blob_entriesModule, + static_blob_data, NULL); + if(rc != SQLITE_OK) return rc; + return rc; } -#endif diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 025d9ce..83e0fd5 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -80,8 +80,6 @@ def connect(ext, path=":memory:", extra_entrypoint=None): db = connect(EXT_PATH) -# db.load_extension(EXT_PATH, entrypoint="trace_debug") - def explain_query_plan(sql): return db.execute("explain query plan " + sql).fetchone()["detail"] @@ -113,7 +111,6 @@ FUNCTIONS = [ "vec_quantize_binary", "vec_quantize_int8", "vec_slice", - "vec_static_blob_from_raw", "vec_sub", "vec_to_json", "vec_type", @@ -123,24 +120,150 @@ MODULES = [ "vec0", "vec_each", "vec_npy_each", - "vec_static_blob_entries", - "vec_static_blobs", + #"vec_static_blob_entries", + #"vec_static_blobs", ] -@pytest.mark.skip(reason="TODO") -def test_vec_static_blob_from_raw(): - pass + +def register_numpy(db, name: str, array): + ptr = array.__array_interface__["data"][0] + nvectors, dimensions = array.__array_interface__["shape"] + element_type = array.__array_interface__["typestr"] + + assert element_type == "\x9a\x99\x99>", + }, + { + "vector": b"fff?\xcd\xccL?", + }, + ] + assert execute_all(db, "select rowid, (vector) from z") == [ + { + "rowid": 0, + "vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=", + }, + { + "rowid": 1, + "vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>", + }, + { + "rowid": 2, + "vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>", + }, + { + "rowid": 3, + "vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>", + }, + { + "rowid": 4, + "vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?", + }, + ] + assert execute_all( + db, + "select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;", + [np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)], + ) == [ + { + "rowid": 2, + "v": "[0.300000,0.300000,0.300000,0.300000]", + }, + { + "rowid": 3, + "v": "[0.400000,0.400000,0.400000,0.400000]", + }, + { + "rowid": 1, + "v": "[0.200000,0.200000,0.200000,0.200000]", + }, + ] + assert execute_all( + db, + "select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;", + [np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)], + ) == [ + { + "rowid": 4, + "v": "[0.500000,0.500000,0.500000,0.500000]", + }, + { + "rowid": 3, + "v": "[0.400000,0.400000,0.400000,0.400000]", + }, + { + "rowid": 2, + "v": "[0.300000,0.300000,0.300000,0.300000]", + }, + ] def test_funcs(): @@ -1872,7 +1995,7 @@ def test_vec0_knn(): db.execute("select * from v where aaa match vec_bit(X'AA') and k = 10") with _raises( - 'Dimension mismatch for inserted vector for the "aaa" column. Expected 8 dimensions but received 1.' + 'Dimension mismatch for query vector for the "aaa" column. Expected 8 dimensions but received 1.' ): db.execute("select * from v where aaa match vec_f32('[.1]') and k = 10") @@ -2120,36 +2243,42 @@ def test_smoke(): db.execute("insert into vec_xyz(rowid, a) select 2, X'0000000000000040'") chunk = db.execute("select * from vec_xyz_chunks").fetchone() - assert chunk[ - "rowids" - ] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + bytearray( - int(1024 * 8) - 8 * 2 + assert ( + chunk["rowids"] + == b"\x01\x00\x00\x00\x00\x00\x00\x00" + + b"\x02\x00\x00\x00\x00\x00\x00\x00" + + bytearray(int(1024 * 8) - 8 * 2) ) assert chunk["chunk_id"] == 1 assert chunk["validity"] == b"\x03" + bytearray(int(1024 / 8) - 1) vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone() assert vchunk["rowid"] == 1 - assert vchunk[ - "vectors" - ] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + bytearray( - int(1024 * 4 * 2) - (2 * 4 * 2) + assert ( + vchunk["vectors"] + == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + + b"\x00\x00\x00\x00\x00\x00\x00\x40" + + bytearray(int(1024 * 4 * 2) - (2 * 4 * 2)) ) db.execute("insert into vec_xyz(rowid, a) select 3, X'00000000000080bf'") chunk = db.execute("select * from vec_xyz_chunks").fetchone() assert chunk["chunk_id"] == 1 assert chunk["validity"] == b"\x07" + bytearray(int(1024 / 8) - 1) - assert chunk[ - "rowids" - ] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + b"\x03\x00\x00\x00\x00\x00\x00\x00" + bytearray( - int(1024 * 8) - 8 * 3 + assert ( + chunk["rowids"] + == b"\x01\x00\x00\x00\x00\x00\x00\x00" + + b"\x02\x00\x00\x00\x00\x00\x00\x00" + + b"\x03\x00\x00\x00\x00\x00\x00\x00" + + bytearray(int(1024 * 8) - 8 * 3) ) vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone() assert vchunk["rowid"] == 1 - assert vchunk[ - "vectors" - ] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + b"\x00\x00\x00\x00\x00\x00\x80\xbf" + bytearray( - int(1024 * 4 * 2) - (2 * 4 * 3) + assert ( + vchunk["vectors"] + == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + + b"\x00\x00\x00\x00\x00\x00\x00\x40" + + b"\x00\x00\x00\x00\x00\x00\x80\xbf" + + bytearray(int(1024 * 4 * 2) - (2 * 4 * 3)) ) # db.execute("select * from vec_xyz") @@ -2192,66 +2321,63 @@ def test_vec0_stress_small_chunks(): {"rowid": 994, "a": _f32([99.4] * 8)}, {"rowid": 993, "a": _f32([99.3] * 8)}, ] - assert ( - execute_all( - db, - """ + assert execute_all( + db, + """ select rowid, a, distance from vec_small where a match ? and k = 9 order by distance """, - [_f32([50.0] * 8)], - ) - == [ - { - "a": _f32([500 * 0.1] * 8), - "distance": 0.0, - "rowid": 500, - }, - { - "a": _f32([501 * 0.1] * 8), - "distance": 0.2828384041786194, - "rowid": 501, - }, - { - "a": _f32([499 * 0.1] * 8), - "distance": 0.2828384041786194, - "rowid": 499, - }, - { - "a": _f32([502 * 0.1] * 8), - "distance": 0.5656875967979431, - "rowid": 502, - }, - { - "a": _f32([498 * 0.1] * 8), - "distance": 0.5656875967979431, - "rowid": 498, - }, - { - "a": _f32([503 * 0.1] * 8), - "distance": 0.8485260009765625, - "rowid": 503, - }, - { - "a": _f32([497 * 0.1] * 8), - "distance": 0.8485260009765625, - "rowid": 497, - }, - { - "a": _f32([496 * 0.1] * 8), - "distance": 1.1313751935958862, - "rowid": 496, - }, - { - "a": _f32([504 * 0.1] * 8), - "distance": 1.1313751935958862, - "rowid": 504, - }, - ] - ) + [_f32([50.0] * 8)], + ) == [ + { + "a": _f32([500 * 0.1] * 8), + "distance": 0.0, + "rowid": 500, + }, + { + "a": _f32([501 * 0.1] * 8), + "distance": 0.2828384041786194, + "rowid": 501, + }, + { + "a": _f32([499 * 0.1] * 8), + "distance": 0.2828384041786194, + "rowid": 499, + }, + { + "a": _f32([502 * 0.1] * 8), + "distance": 0.5656875967979431, + "rowid": 502, + }, + { + "a": _f32([498 * 0.1] * 8), + "distance": 0.5656875967979431, + "rowid": 498, + }, + { + "a": _f32([503 * 0.1] * 8), + "distance": 0.8485260009765625, + "rowid": 503, + }, + { + "a": _f32([497 * 0.1] * 8), + "distance": 0.8485260009765625, + "rowid": 497, + }, + { + "a": _f32([496 * 0.1] * 8), + "distance": 1.1313751935958862, + "rowid": 496, + }, + { + "a": _f32([504 * 0.1] * 8), + "distance": 1.1313751935958862, + "rowid": 504, + }, + ] def test_vec0_distance_metric(): diff --git a/tmp-static.py b/tmp-static.py index 979c553..a3b5f37 100644 --- a/tmp-static.py +++ b/tmp-static.py @@ -5,6 +5,7 @@ db = sqlite3.connect(":memory:") db.enable_load_extension(True) db.load_extension("./dist/vec0") +db.execute("select load_extension('./dist/vec0', 'sqlite3_vec_raw_init')") db.enable_load_extension(False) x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)