diff --git a/sqlite-vec.c b/sqlite-vec.c index ee45a42..d63ce53 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -467,6 +467,33 @@ void array_cleanup(struct Array *array) { array->z = NULL; } +char *vector_subtype_name(int subtype) { + switch (subtype) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + return "float32"; + case SQLITE_VEC_ELEMENT_TYPE_INT8: + return "int8"; + case SQLITE_VEC_ELEMENT_TYPE_BIT: + return "bit"; + } + return ""; +} +char *type_name(int type) { + switch (type) { + case SQLITE_INTEGER: + return "INTEGER"; + case SQLITE_BLOB: + return "BLOB"; + case SQLITE_TEXT: + return "TEXT"; + case SQLITE_FLOAT: + return "FLOAT"; + case SQLITE_NULL: + return "NULL"; + } + return ""; +} + typedef void (*fvec_cleanup)(f32 *vector); void fvec_cleanup_noop(f32 *_) { UNUSED_PARAMETER(_); } @@ -586,7 +613,7 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector, } *pzErr = sqlite3_mprintf( - "Input must have type BLOB (compact format) or TEXT (JSON)"); + "Input must have type BLOB (compact format) or TEXT (JSON), found %s", type_name(value_type)); return SQLITE_ERROR; } @@ -612,7 +639,7 @@ static int bitvec_from_value(sqlite3_value *value, u8 **vector, static int int8_vec_from_value(sqlite3_value *value, i8 **vector, size_t *dimensions, vector_cleanup *cleanup, - char **pzErr) { + char **pzErr) { int value_type = sqlite3_value_type(value); if (value_type == SQLITE_BLOB) { const void *blob = sqlite3_value_blob(value); @@ -772,17 +799,7 @@ int vector_from_value(sqlite3_value *value, void **vector, size_t *dimensions, return SQLITE_ERROR; } -char *vector_subtype_name(int subtype) { - switch (subtype) { - case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: - return "float32"; - case SQLITE_VEC_ELEMENT_TYPE_INT8: - return "int8"; - case SQLITE_VEC_ELEMENT_TYPE_BIT: - return "bit"; - } - return ""; -} + int ensure_vector_match(sqlite3_value *aValue, sqlite3_value *bValue, void **a, void **b, enum VectorElementType *element_type, size_t *dimensions, vector_cleanup *outACleanup, @@ -1624,7 +1641,6 @@ int parse_primary_key_definition(const char *source, int source_length, char **out_column_name, int *out_column_name_length, int *out_column_type) { - // TODO return SQLITE_ERROR on PK parse errors. struct Vec0Scanner scanner; struct Vec0Token token; char *column_name; @@ -1699,17 +1715,22 @@ struct VectorColumnDefinition { enum Vec0DistanceMetrics distance_metric; }; -size_t vector_column_byte_size(struct VectorColumnDefinition column) { - switch (column.element_type) { +size_t vector_byte_size(enum VectorElementType element_type, + size_t dimensions) { + switch (element_type) { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: - return column.dimensions * sizeof(f32); + return dimensions * sizeof(f32); case SQLITE_VEC_ELEMENT_TYPE_INT8: - return column.dimensions * sizeof(i8); + return dimensions * sizeof(i8); case SQLITE_VEC_ELEMENT_TYPE_BIT: - return column.dimensions / CHAR_BIT; + return dimensions / CHAR_BIT; } } +inline size_t vector_column_byte_size(struct VectorColumnDefinition column) { + return vector_byte_size(column.element_type, column.dimensions); +} + /** * @brief Parse an vec0 vtab argv[i] column definition and see if * it's a vector column defintion, ex `contents_embedding float[768]`. @@ -1790,7 +1811,6 @@ int parse_vector_column(const char *source, int source_length, } // any other tokens left should be column-level options , ex `key=value` - // TODO make sure options are defined only once. // ex `distance_metric=L2 distance_metric=cosine` should error while (1) { // should be EOF or identifer (option key) @@ -2030,7 +2050,7 @@ static sqlite3_module vec_eachModule = { /* xRelease */ 0, /* xRollbackTo */ 0, /* xShadowName */ 0, -#if SQLITE_VERSION_NUMBER >= 3440000 +#if SQLITE_VERSION_NUMBER >= 3044000 /* xIntegrity */ 0 #endif }; @@ -2156,6 +2176,7 @@ int npy_scanner_next(struct NpyScanner *scanner, struct NpyToken *out) { return rc; } +#define NPY_PARSE_ERROR "Error parsing numpy array: " int parse_npy_header(sqlite3_vtab *pVTab, const unsigned char *header, size_t headerLength, enum VectorElementType *out_element_type, @@ -2169,13 +2190,14 @@ int parse_npy_header(sqlite3_vtab *pVTab, const unsigned char *header, if (npy_scanner_next(&scanner, &token) != VEC0_TOKEN_RESULT_SOME && token.token_type != NPY_TOKEN_TYPE_LBRACE) { - vtab_set_error(pVTab, "numpy header did not start with '{'"); + vtab_set_error(pVTab, + NPY_PARSE_ERROR "numpy header did not start with '{'"); return SQLITE_ERROR; } while (1) { rc = npy_scanner_next(&scanner, &token); if (rc != VEC0_TOKEN_RESULT_SOME) { - vtab_set_error(pVTab, "expected key in numpy header"); + vtab_set_error(pVTab, NPY_PARSE_ERROR "expected key in numpy header"); return SQLITE_ERROR; } @@ -2183,128 +2205,114 @@ int parse_npy_header(sqlite3_vtab *pVTab, const unsigned char *header, break; } if (token.token_type != NPY_TOKEN_TYPE_STRING) { - vtab_set_error(pVTab, "expected a string as key in numpy header"); + vtab_set_error(pVTab, NPY_PARSE_ERROR + "expected a string as key in numpy header"); return SQLITE_ERROR; } unsigned char *key = token.start; - // TODO use this in strncmp()? - // int keyLength = token.end - token.start; rc = npy_scanner_next(&scanner, &token); if ((rc != VEC0_TOKEN_RESULT_SOME) || (token.token_type != NPY_TOKEN_TYPE_COLON)) { - vtab_set_error(pVTab, "expected a ':' after key in numpy header"); + vtab_set_error(pVTab, NPY_PARSE_ERROR + "expected a ':' after key in numpy header"); return SQLITE_ERROR; } - // TODO: strcmp safe? if (strncmp((char *)key, "'descr'", strlen("'descr'")) == 0) { rc = npy_scanner_next(&scanner, &token); - todo_assert(rc == VEC0_TOKEN_RESULT_SOME); - todo_assert(token.token_type == NPY_TOKEN_TYPE_STRING); - todo_assert(strncmp((char *)token.start, "'maxChunks = 1024; + pCur->chunksBufferSize = + (vector_byte_size(element_type, numDimensions)) * pCur->maxChunks; + pCur->chunksBuffer = sqlite3_malloc(pCur->chunksBufferSize); + if (pCur->chunksBufferSize && !pCur->chunksBuffer) { + return SQLITE_NOMEM; + } + + pCur->currentChunkSize = + fread(pCur->chunksBuffer, vector_byte_size(element_type, numDimensions), + pCur->maxChunks, file); + + pCur->currentChunkIndex = 0; + pCur->elementType = element_type; + pCur->nElements = numElements; + pCur->nDimensions = numDimensions; + pCur->input_type = VEC_NPY_EACH_INPUT_FILE; + + pCur->eof = pCur->currentChunkSize == 0; + pCur->file = file; + return SQLITE_OK; +} + +int parse_npy_buffer(sqlite3_vtab *pVTab, const unsigned char *buffer, + int bufferLength, void **data, size_t *numElements, + size_t *numDimensions, + enum VectorElementType *element_type) { + + if (bufferLength < 10) { + // IMP: V03312_20150 + vtab_set_error(pVTab, "numpy array too short"); + return SQLITE_ERROR; + } + if (memcmp(NPY_MAGIC, buffer, sizeof(NPY_MAGIC)) != 0) { + // V11954_28792 + vtab_set_error(pVTab, "numpy array does not contain the 'magic' header"); + return SQLITE_ERROR; + } + + u8 major = buffer[6]; + u8 minor = buffer[7]; + uint16_t headerLength = 0; + memcpy(&headerLength, &buffer[8], sizeof(uint16_t)); + + i32 totalHeaderLength = sizeof(NPY_MAGIC) + sizeof(major) + sizeof(minor) + + sizeof(headerLength) + headerLength; + i32 dataSize = bufferLength - totalHeaderLength; + + if (dataSize < 0) { + vtab_set_error(pVTab, "numpy array header length is invalid"); + return SQLITE_ERROR; + } + + const unsigned char *header = &buffer[10]; + int fortran_order; + + int rc = parse_npy_header(pVTab, header, headerLength, element_type, + &fortran_order, numElements, numDimensions); + if (rc != SQLITE_OK) { + return rc; + } + + i32 expectedDataSize = + (*numElements * vector_byte_size(*element_type, *numDimensions)); + if (expectedDataSize != dataSize) { + vtab_set_error(pVTab, + "numpy array error: Expected a data size of %d, found %d", + expectedDataSize, dataSize); + return SQLITE_ERROR; + } + + *data = (void *)&buffer[totalHeaderLength]; + return SQLITE_OK; +} + static int vec_npy_eachConnect(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_vtab **ppVtab, char **pzErr) { UNUSED_PARAMETER(pAux); UNUSED_PARAMETER(argc); UNUSED_PARAMETER(argv); - UNUSED_PARAMETER(pzErr); // TODO use + UNUSED_PARAMETER(pzErr); vec_npy_each_vtab *pNew; int rc; @@ -2405,12 +2557,11 @@ static int vec_npy_eachClose(sqlite3_vtab_cursor *cur) { fclose(pCur->file); pCur->file = NULL; } - if (pCur->fileBuffer) { - sqlite3_free(pCur->fileBuffer); - pCur->fileBuffer = NULL; + if (pCur->chunksBuffer) { + sqlite3_free(pCur->chunksBuffer); + pCur->chunksBuffer = NULL; } if (pCur->vector) { - // sqlite3_free(pCur->vector); pCur->vector = NULL; } sqlite3_free(pCur); @@ -2452,104 +2603,48 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, UNUSED_PARAMETER(idxNum); UNUSED_PARAMETER(idxStr); assert(argc == 1); + int rc; + vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)pVtabCursor; if (pCur->file) { fclose(pCur->file); pCur->file = NULL; } - if (pCur->fileBuffer) { - sqlite3_free(pCur->fileBuffer); - pCur->fileBuffer = NULL; + if (pCur->chunksBuffer) { + sqlite3_free(pCur->chunksBuffer); + pCur->chunksBuffer = NULL; } if (pCur->vector) { - // sqlite3_free(pCur->vector); TODO don't need to free this?? pCur->vector = NULL; } struct VecNpyFile *f = NULL; if ((f = sqlite3_value_pointer(argv[0], SQLITE_VEC_NPY_FILE_NAME))) { - int n; FILE *file = fopen(f->path, "r"); - todo_assert(file); - - fseek(file, 0, SEEK_END); - long fileSize = ftell(file); - - fseek(file, 0L, SEEK_SET); - - unsigned char header[10]; - n = fread(&header, sizeof(unsigned char), 10, file); - todo_assert(n == 10); - - for (size_t i = 0; i < countof(NPY_MAGIC); i++) { - todo_assert(NPY_MAGIC[i] == header[i]); - } - u8 major = header[6]; - u8 minor = header[7]; - - uint16_t headerLength = 0; - memcpy(&headerLength, &header[8], sizeof(uint16_t)); - - size_t totalHeaderLength = sizeof(NPY_MAGIC) + sizeof(major) + - sizeof(minor) + sizeof(headerLength) + - headerLength; - size_t dataSize = fileSize - totalHeaderLength; - todo_assert(dataSize > 0); - - unsigned char *headerX = sqlite3_malloc(headerLength); - todo_assert(headerX); - n = fread(headerX, sizeof(char), headerLength, file); - todo_assert(n == headerLength); - - int fortran_order; - enum VectorElementType element_type; - size_t numElements; - size_t numDimensions; - int rc = parse_npy_header(pVtabCursor->pVtab, headerX, headerLength, - &element_type, &fortran_order, &numElements, - &numDimensions); - sqlite3_free(headerX); - todo_assert(rc == SQLITE_OK); - - int element_size = 0; - if (element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { - element_size = sizeof(f32); - } else { - todo("non-f32 numpy array"); + if (!file) { + vtab_set_error(pVtabCursor->pVtab, "Could not open numpy file"); + return SQLITE_ERROR; } - todo_assert((numElements * numDimensions * element_size) == dataSize); - - pCur->bufferIndex = 0; - pCur->bufferLength = 1024; - pCur->elementSize = element_size; - pCur->elementType = element_type; - pCur->nElements = numElements; - pCur->nDimensions = numDimensions; - pCur->fileBufferSize = numDimensions * element_size * pCur->bufferLength; - pCur->fileBuffer = sqlite3_malloc(pCur->fileBufferSize); - todo_assert(pCur->fileBuffer); - pCur->input_type = VEC_NPY_EACH_INPUT_FILE; - n = fread(pCur->fileBuffer, 1, pCur->fileBufferSize, file); - todo_assert((size_t)n == pCur->fileBufferSize); // TODO may be smaller - - pCur->eof = 0; - pCur->file = file; + rc = parse_npy_file(pVtabCursor->pVtab, file, pCur); + if (rc != SQLITE_OK) { + fclose(file); + return rc; + } } else { const unsigned char *input = sqlite3_value_blob(argv[0]); int inputLength = sqlite3_value_bytes(argv[0]); - int rc; void *data; size_t numElements; size_t numDimensions; enum VectorElementType element_type; - rc = parse_npy(pVtabCursor->pVtab, input, inputLength, &data, &numElements, - &numDimensions, &element_type); + rc = parse_npy_buffer(pVtabCursor->pVtab, input, inputLength, &data, + &numElements, &numDimensions, &element_type); if (rc != SQLITE_OK) { return rc; } @@ -2574,7 +2669,7 @@ static int vec_npy_eachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { static int vec_npy_eachEof(sqlite3_vtab_cursor *cur) { vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) { - return (size_t)pCur->iRowid >= pCur->nElements; + return (!pCur->nElements) || (size_t)pCur->iRowid >= pCur->nElements; } return pCur->eof; } @@ -2582,55 +2677,44 @@ static int vec_npy_eachEof(sqlite3_vtab_cursor *cur) { static int vec_npy_eachNext(sqlite3_vtab_cursor *cur) { vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; pCur->iRowid++; - if (pCur->input_type == VEC_NPY_EACH_INPUT_FILE) { - pCur->bufferIndex++; - if (pCur->bufferIndex >= pCur->bufferLength) { - int n = fread(pCur->fileBuffer, 1, pCur->fileBufferSize, pCur->file); - if (!n) { - pCur->eof = 1; - } - pCur->bufferIndex = 0; - pCur->bufferLength = n / pCur->nDimensions / pCur->elementSize; + if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) { + return SQLITE_OK; + } + + // else: input is a file + pCur->currentChunkIndex++; + if (pCur->currentChunkIndex >= pCur->currentChunkSize) { + pCur->currentChunkSize = + fread(pCur->chunksBuffer, + vector_byte_size(pCur->elementType, pCur->nDimensions), + pCur->maxChunks, pCur->file); + if (!pCur->currentChunkSize) { + pCur->eof = 1; } + pCur->currentChunkIndex = 0; } return SQLITE_OK; } -static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur, - sqlite3_context *context, int i) { - vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; +static int vec_npy_eachColumnBuffer(vec_npy_each_cursor *pCur, + sqlite3_context *context, int i) { switch (i) { case VEC_NPY_EACH_COLUMN_VECTOR: { - if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) { - switch (pCur->elementType) { - case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { - sqlite3_result_blob( - context, - &pCur->vector[pCur->iRowid * pCur->nDimensions * sizeof(f32)], - pCur->nDimensions * sizeof(f32), SQLITE_STATIC); - break; - } - case SQLITE_VEC_ELEMENT_TYPE_INT8: - case SQLITE_VEC_ELEMENT_TYPE_BIT: { - todo("bit array npy column"); - break; - } - } - } else if (pCur->input_type == VEC_NPY_EACH_INPUT_FILE) { - switch (pCur->elementType) { - case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { - sqlite3_result_blob(context, - &pCur->fileBuffer[pCur->bufferIndex * - pCur->nDimensions * sizeof(f32)], - pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT); - break; - } - case SQLITE_VEC_ELEMENT_TYPE_INT8: - case SQLITE_VEC_ELEMENT_TYPE_BIT: { - todo("bit array npy column"); - break; - } - } + sqlite3_result_subtype(context, pCur->elementType); + switch (pCur->elementType) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + sqlite3_result_blob( + context, + &pCur->vector[pCur->iRowid * pCur->nDimensions * sizeof(f32)], + pCur->nDimensions * sizeof(f32), SQLITE_STATIC); + + break; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + todo("bit array npy column"); + break; + } } break; @@ -2638,6 +2722,39 @@ static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur, } return SQLITE_OK; } +static int vec_npy_eachColumnFile(vec_npy_each_cursor *pCur, + sqlite3_context *context, int i) { + switch (i) { + case VEC_NPY_EACH_COLUMN_VECTOR: { + switch (pCur->elementType) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + sqlite3_result_blob(context, + &pCur->chunksBuffer[pCur->currentChunkIndex * + pCur->nDimensions * sizeof(f32)], + pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT); + break; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + todo("bit array npy column"); + break; + } + } + break; + } + } + return SQLITE_OK; +} +static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur, + sqlite3_context *context, int i) { + vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; + switch (pCur->input_type) { + case VEC_NPY_EACH_INPUT_BUFFER: + return vec_npy_eachColumnBuffer(pCur, context, i); + case VEC_NPY_EACH_INPUT_FILE: + return vec_npy_eachColumnFile(pCur, context, i); + } +} static sqlite3_module vec_npy_eachModule = { /* iVersion */ 0, @@ -2664,7 +2781,7 @@ static sqlite3_module vec_npy_eachModule = { /* xRelease */ 0, /* xRollbackTo */ 0, /* xShadowName */ 0, -#if SQLITE_VERSION_NUMBER >= 3440000 +#if SQLITE_VERSION_NUMBER >= 3044000 /* xIntegrity */ 0, #endif }; @@ -2814,7 +2931,7 @@ struct vec0_vtab { * Result columns: * 0: chunk_id (i64) * 1: chunk_offset (i64) - * SQL: "SELECT chunk_id, chunk_offset FROM _rowids WHERE rowid = ?"" + * SQL: "SELECT id, chunk_id, chunk_offset FROM _rowids WHERE rowid = ?"" * * Must be cleaned up with sqlite3_finalize(). */ @@ -2914,19 +3031,23 @@ int vec0_column_idx_to_vector_idx(vec0_vtab *pVtab, int column_idx) { */ int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, i64 rowid, sqlite3_value **out) { - // TODO different stmt than stmtRowidsGetChunkPosition? - // TODO return rc instead - sqlite3_reset(pVtab->stmtRowidsGetChunkPosition); - sqlite3_clear_bindings(pVtab->stmtRowidsGetChunkPosition); + int rc; + // PERF: different stmt than stmtRowidsGetChunkPosition? + // TODO: test / evidence-of sqlite3_bind_int64(pVtab->stmtRowidsGetChunkPosition, 1, rowid); - int rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition); + rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition); if (rc == SQLITE_ROW) { return SQLITE_ERROR; } sqlite3_value *value = sqlite3_column_value(pVtab->stmtRowidsGetChunkPosition, 0); *out = sqlite3_value_dup(value); - return SQLITE_OK; + rc = SQLITE_OK; + + cleanup: + sqlite3_reset(pVtab->stmtRowidsGetChunkPosition); + sqlite3_clear_bindings(pVtab->stmtRowidsGetChunkPosition); + return rc; } // TODO make sure callees use the return value of this function @@ -3015,16 +3136,29 @@ cleanup: int vec0_get_chunk_position(vec0_vtab *p, i64 rowid, i64 *chunk_id, i64 *chunk_offset) { int rc; - sqlite3_reset(p->stmtRowidsGetChunkPosition); - sqlite3_clear_bindings(p->stmtRowidsGetChunkPosition); sqlite3_bind_int64(p->stmtRowidsGetChunkPosition, 1, rowid); + rc = sqlite3_step(p->stmtRowidsGetChunkPosition); - todo_assert(rc == SQLITE_ROW); + if (rc != SQLITE_ROW) { + vtab_set_error(&p->base, "Could not find chunk position for %lld", rowid); + goto cleanup; + } *chunk_id = sqlite3_column_int64(p->stmtRowidsGetChunkPosition, 1); *chunk_offset = sqlite3_column_int64(p->stmtRowidsGetChunkPosition, 2); + rc = sqlite3_step(p->stmtRowidsGetChunkPosition); - todo_assert(rc == SQLITE_DONE); - return SQLITE_OK; + if (rc != SQLITE_DONE) { + vtab_set_error(&p->base, "Could not find chunk position for %lld", rowid); + goto cleanup; + } + + rc = SQLITE_OK; + +cleanup: + sqlite3_reset(p->stmtRowidsGetChunkPosition); + sqlite3_clear_bindings(p->stmtRowidsGetChunkPosition); + + return rc; } /** @@ -3818,6 +3952,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { // forward delcaration bc vec0Filter uses it static int vec0Next(sqlite3_vtab_cursor *cur); +// TODO: Ya this shit is slow void dethrone(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size, i32 *chunk_top_idx, f32 *chunk_distances, i64 *chunk_rowids, @@ -3845,6 +3980,7 @@ void dethrone(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size, } } +/* // TODO is this better?? from vec_expo experiment void dethrone2(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size, i32 *chunk_top_idx, f32 *chunk_distances, i64 *chunk_rowids, @@ -3871,8 +4007,7 @@ void dethrone2(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size, } } } - -// TODO: Ya this shit is slow +*/ /** * @brief Finds the minimum k items in distances, and writes the indicies to @@ -4336,6 +4471,10 @@ static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur, } // TODO only have 1st vector data if (vec0_column_idx_is_vector(pVtab, i)) { + if(sqlite3_vtab_nochange(context)) { + sqlite3_result_null(context); + return SQLITE_OK; + } int vector_idx = vec0_column_idx_to_vector_idx(pVtab, i); sqlite3_result_blob( context, pCur->point_data->vectors[vector_idx], @@ -4930,7 +5069,7 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, if (rc != SQLITE_OK) { // IMP: V06519_23358 vtab_set_error( - pVTab, "Inserted vector for the \"%.*s\" column is invalid: %s", + pVTab, "Inserted vector for the \"%.*s\" column is invalid: %z", p->vector_columns[i].name_length, p->vector_columns[i].name, pzError); rc = SQLITE_ERROR; goto cleanup; @@ -4974,7 +5113,6 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, } // Cannot insert a value in the hidden "k" column if (sqlite3_value_type(argv[2 + vec0_column_k_idx(p)]) != SQLITE_NULL) { - // TODO cleanups // IMP: V11875_28713 vtab_set_error(pVTab, "A value was provided for the hidden \"k\" column."); rc = SQLITE_ERROR; @@ -5017,54 +5155,214 @@ cleanup: return rc; } +int vec0Update_Delete_ClearValidity(vec0_vtab *p, i64 chunk_id, + u64 chunk_offset) { + int rc, brc; + sqlite3_blob *blobChunksValidity = NULL; + char unsigned bx; + int validityOffset = chunk_offset / CHAR_BIT; + + // 2. ensure chunks.validity bit is 1, then set to 0 + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "validity", + chunk_id, 1, &blobChunksValidity); + if (rc != SQLITE_OK) { + // IMP: V26002_10073 + vtab_set_error(&p->base, "could not open validity blob for %s.%s.%lld", + p->schemaName, p->shadowChunksName, chunk_id); + return SQLITE_ERROR; + } + // will skip the sqlite3_blob_bytes(blobChunksValidity) check for now, + // the read below would catch it + + rc = sqlite3_blob_read(blobChunksValidity, &bx, sizeof(bx), validityOffset); + if (rc != SQLITE_OK) { + // IMP: V21193_05263 + vtab_set_error( + &p->base, "could not read validity blob for %s.%s.%lld at %d", + p->schemaName, p->shadowChunksName, chunk_id, validityOffset); + goto cleanup; + } + if (!(bx >> (chunk_offset % CHAR_BIT))) { + // IMP: V21193_05263 + rc = SQLITE_ERROR; + vtab_set_error( + &p->base, + "vec0 deletion error: validity bit is not set for %s.%s.%lld at %d", + p->schemaName, p->shadowChunksName, chunk_id, validityOffset); + goto cleanup; + } + char unsigned mask = ~(1 << (chunk_offset % CHAR_BIT)); + char result = bx & mask; + rc = sqlite3_blob_write(blobChunksValidity, &result, sizeof(bx), + validityOffset); + if (rc != SQLITE_OK) { + vtab_set_error( + &p->base, "could not write to validity blob for %s.%s.%lld at %d", + p->schemaName, p->shadowChunksName, chunk_id, validityOffset); + goto cleanup; + } + +cleanup: + + brc = sqlite3_blob_close(blobChunksValidity); + if (rc != SQLITE_OK) + return rc; + if (brc != SQLITE_OK) { + vtab_set_error(&p->base, + "vec0 deletion error: Error commiting validity blob " + "transaction on %s.%s.%lld at %d", + p->schemaName, p->shadowChunksName, chunk_id, + validityOffset); + return brc; + } + return SQLITE_OK; +} + +int vec0Update_Delete_DeleteRowids(vec0_vtab *p, i64 rowid) { + int rc; + sqlite3_stmt *stmt = NULL; + + char *zSql = + sqlite3_mprintf("DELETE FROM " VEC0_SHADOW_ROWIDS_NAME " WHERE rowid = ?", + p->schemaName, p->tableName); + if (!zSql) { + return SQLITE_NOMEM; + } + + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if(rc != SQLITE_OK ) { + goto cleanup; + } + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + if(rc != SQLITE_DONE) { + goto cleanup; + } + rc = SQLITE_OK; + + cleanup: + sqlite3_finalize(stmt); + return rc; +} + int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite_int64 rowid) { vec0_vtab *p = (vec0_vtab *)pVTab; int rc; i64 chunk_id; i64 chunk_offset; - sqlite3_blob *blobChunksValidity = NULL; + + // 1. Find chunk position for given rowid + // 2. Ensure that validity bit for position is 1, then set to 0 + // 3. Zero out rowid in chunks.rowid + // 4. Zero out vector data in all vector column chunks + // 5. Delete value in _rowids table // 1. get chunk_id and chunk_offset from _rowids + // TODO how to make this fail without failing the point query rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset); - todo_assert(rc == SQLITE_OK); + if (rc != SQLITE_OK) { + return rc; + } - // 2. ensure chunks.validity bit is 1, then set to 0 - rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "validity", - chunk_id, 1, &blobChunksValidity); - todo_assert(rc == SQLITE_OK); - char unsigned bx; - rc = sqlite3_blob_read(blobChunksValidity, &bx, sizeof(bx), - chunk_offset / CHAR_BIT); - todo_assert(rc == SQLITE_OK); - todo_assert(bx >> (chunk_offset % CHAR_BIT)); - char unsigned mask = ~(1 << (chunk_offset % CHAR_BIT)); - char result = bx & mask; - rc = sqlite3_blob_write(blobChunksValidity, &result, sizeof(bx), - chunk_offset / CHAR_BIT); - todo_assert(rc == SQLITE_OK); - sqlite3_blob_close(blobChunksValidity); + rc = vec0Update_Delete_ClearValidity(p, chunk_id, chunk_offset); + if (rc != SQLITE_OK) { + return rc; + } // 3. zero out rowid in chunks.rowids TODO // 4. zero out any data in vector chunks tables TODO // 5. delete from _rowids table - char *zSql = - sqlite3_mprintf("DELETE FROM " VEC0_SHADOW_ROWIDS_NAME " WHERE rowid = ?", - p->schemaName, p->tableName); - todo_assert(zSql); - sqlite3_stmt *stmt; - rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); - sqlite3_free(zSql); - todo_assert(rc == SQLITE_OK); - sqlite3_bind_int64(stmt, 1, rowid); - rc = sqlite3_step(stmt); - todo_assert(SQLITE_DONE); - sqlite3_finalize(stmt); + rc = vec0Update_Delete_DeleteRowids(p, rowid); + if (rc != SQLITE_OK) { + return rc; + } return SQLITE_OK; } +int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset, + int i, sqlite3_value *valueVector) { + int rc; + + sqlite3_blob *blobVectors = NULL; + + char *pzError; + size_t dimensions; + enum VectorElementType elementType; + void *vector; + vector_cleanup cleanup = vector_cleanup_noop; + // TODO: Can't update non f32, bc subtypes are stripped from UPDATEs. + // Need to 1) create a less strict vector_from_value, or 2) wait + // for this to resolve: https://sqlite.org/forum/forumpost/65317ce9c6 + rc = vector_from_value(valueVector, &vector, &dimensions, &elementType, + &cleanup, &pzError); + if (rc != SQLITE_OK) { + // IMP: V15203_32042 + vtab_set_error( + &p->base, "Updated vector for the \"%.*s\" column is invalid: %z", + p->vector_columns[i].name_length, p->vector_columns[i].name, pzError); + rc = SQLITE_ERROR; + goto cleanup; + } + if (elementType != p->vector_columns[i].element_type) { + // IMP: V03643_20481 + vtab_set_error( + &p->base, + "Updated vector for the \"%.*s\" column is expected to be of type " + "%s, but a %s vector was provided.", + p->vector_columns[i].name_length, p->vector_columns[i].name, + vector_subtype_name(p->vector_columns[i].element_type), + vector_subtype_name(elementType)); + rc = SQLITE_ERROR; + goto cleanup; + } + if (dimensions != p->vector_columns[i].dimensions) { + // IMP: V25739_09810 + vtab_set_error( + &p->base, + "Dimension mismatch for new updated vector for the \"%.*s\" column. " + "Expected %d dimensions but received %d.", + p->vector_columns[i].name_length, p->vector_columns[i].name, + p->vector_columns[i].dimensions, dimensions); + rc = SQLITE_ERROR; + goto cleanup; + } + + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i], + "vectors", chunk_id, 1, &blobVectors); + if (rc != SQLITE_OK) { + vtab_set_error(&p->base, "Could not open vectors blob for %s.%s.%lld", + p->schemaName, p->shadowVectorChunksNames[i], chunk_id); + goto cleanup; + } + rc = vec0_write_vector_to_vector_blob(blobVectors, chunk_offset, vector, + p->vector_columns[i].dimensions, + p->vector_columns[i].element_type); + if (rc != SQLITE_OK) { + vtab_set_error(&p->base, "Could not write to vectors blob for %s.%s.%lld", + p->schemaName, p->shadowVectorChunksNames[i], chunk_id); + goto cleanup; + } + +cleanup: + cleanup(vector); + int brc = sqlite3_blob_close(blobVectors); + if (rc != SQLITE_OK) { + return rc; + } + if (brc != SQLITE_OK) { + vtab_set_error( + &p->base, + "Could not commit blob transaction for vectors blob for %s.%s.%lld", + p->schemaName, p->shadowVectorChunksNames[i], chunk_id); + return brc; + } + return SQLITE_OK; +} + int vec0Update_UpdateOnRowid(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv) { UNUSED_PARAMETER(argc); @@ -5076,46 +5374,33 @@ int vec0Update_UpdateOnRowid(sqlite3_vtab *pVTab, int argc, // 1. get chunk_id and chunk_offset from _rowids rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset); - todo_assert(rc == SQLITE_OK); + if(rc != SQLITE_OK) { + return rc; + } // 2) iterate over all new vectors, update the vectors - // read all the inserted vectors into vectorDatas, validate their lengths. for (int i = 0; i < p->numVectorColumns; i++) { sqlite3_value *valueVector = argv[2 + VEC0_COLUMN_VECTORN_START + i]; - size_t dimensions; - void *vector = (void *)sqlite3_value_blob(valueVector); - switch (p->vector_columns[i].element_type) { - case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: - dimensions = sqlite3_value_bytes(valueVector) / sizeof(f32); - break; - case SQLITE_VEC_ELEMENT_TYPE_INT8: - dimensions = sqlite3_value_bytes(valueVector) * sizeof(i8); - break; - case SQLITE_VEC_ELEMENT_TYPE_BIT: - dimensions = sqlite3_value_bytes(valueVector) * CHAR_BIT; - break; - } - if (dimensions != p->vector_columns[i].dimensions) { - SET_VTAB_ERROR("TODO vector length dont make sense."); - sqlite3_free(pVTab->zErrMsg); - pVTab->zErrMsg = - sqlite3_mprintf("Vector length mismatch on '%s' column: Expected %d " - "dimensions, found %d", - p->vector_columns[i].name, - p->vector_columns[i].dimensions, dimensions); - return SQLITE_ERROR; + // in vec0Column, we check sqlite3_vtab_nochange() on vector columns. + // If the vector column isn't being changed, we return NULL; + // That's not great, that means vector columns can never be NULLABLE + // (bc we cant distinguish if an updated vector is truly NULL or nochange). + // Also it means that if someone tries to run `UPDATE v SET X = NULL`, + // we can't effectively detect and raise an error. + // A better solution would be to use a custom result_type for "empty", + // but subtypes don't appear to survive xColumn -> xUpdate, it's always 0. + // So for now, we'll just use NULL and warn people to not SET X = NULL + // in the docs. + if(sqlite3_value_type(valueVector) == SQLITE_NULL) { + continue; } - sqlite3_blob *blobVectors; - rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i], - "vectors", chunk_id, 1, &blobVectors); - todo_assert(rc == SQLITE_OK); - rc = vec0_write_vector_to_vector_blob(blobVectors, chunk_offset, vector, - p->vector_columns[i].dimensions, - p->vector_columns[i].element_type); - todo_assert(rc == SQLITE_OK); - sqlite3_blob_close(blobVectors); + rc = vec0Update_UpdateVectorColumn(p, chunk_id, chunk_offset, i, + valueVector); + if(rc != SQLITE_OK){ + return SQLITE_ERROR; + } } return SQLITE_OK; @@ -5139,7 +5424,8 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, return vec0Update_UpdateOnRowid(pVTab, argc, argv); } - SET_VTAB_ERROR("UPDATE operation on rowids with vec0 is not supported."); + vtab_set_error(pVTab, + "UPDATE operation on rowids with vec0 is not supported."); return SQLITE_ERROR; } // unknown operation @@ -5150,7 +5436,6 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, } static int vec0ShadowName(const char *zName) { - // TODO multiple vector_chunk tables static const char *azName[] = {"rowids", "chunks", "vector_chunks"}; for (size_t i = 0; i < sizeof(azName) / sizeof(azName[0]); i++) { @@ -5185,7 +5470,7 @@ static sqlite3_module vec0Module = { /* xRelease */ 0, /* xRollbackTo */ 0, /* xShadowName */ vec0ShadowName, -#if SQLITE_VERSION_NUMBER >= 3440000 +#if SQLITE_VERSION_NUMBER >= 3044000 /* xIntegrity */ 0, // TODO #endif }; @@ -5203,7 +5488,10 @@ static void vec_static_blob_from_raw(sqlite3_context *context, int argc, sqlite3_value **argv) { struct static_blob_definition *p; p = sqlite3_malloc(sizeof(*p)); - todo_assert(p); + if(!p) { + sqlite3_result_error_nomem(context); + return; + } p->p = sqlite3_value_int64(argv[0]); p->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32; p->dimensions = sqlite3_value_int64(argv[2]); @@ -5891,7 +6179,9 @@ int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, #ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL vec_static_blob_data *static_blob_data; static_blob_data = sqlite3_malloc(sizeof(*static_blob_data)); - todo_assert(static_blob_data); + if(!static_blob_data) { + return SQLITE_NOMEM; + } memset(static_blob_data, 0, sizeof(*static_blob_data)); #endif diff --git a/tests/test-loadable.py b/tests/test-loadable.py index b18f679..1e2d306 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -14,8 +14,8 @@ from math import isclose EXT_PATH = "./dist/vec0" -SUPPORTS_SUBTYPE = sqlite3.version_info[1] > 38 -SUPPORTS_DROP_COLUMN = sqlite3.version_info[1] >= 35 +SUPPORTS_SUBTYPE = sqlite3.sqlite_version_info[1] > 38 +SUPPORTS_DROP_COLUMN = sqlite3.sqlite_version_info[1] >= 35 def bitmap_full(n: int) -> bytearray: @@ -39,11 +39,19 @@ def _f32(list): return struct.pack("%sf" % len(list), *list) +def _i64(list): + return struct.pack("%sL" % len(list), *list) + + def _int8(list): return struct.pack("%sb" % len(list), *list) -def connect(ext, path=":memory:"): +def bitmap(bitstring): + return bytes([int(bitstring, 2)]) + + +def connect(ext, path=":memory:", extra_entrypoint=None): db = sqlite3.connect(path) db.execute( @@ -54,6 +62,9 @@ def connect(ext, path=":memory:"): db.enable_load_extension(True) db.load_extension(ext) + if extra_entrypoint: + db.execute("select load_extension(?, ?)", [ext, extra_entrypoint]) + db.execute( "create temp table loaded_functions as select name from pragma_function_list where name not in (select name from base_functions) order by name" ) @@ -497,7 +508,7 @@ def test_vec0(): pass -def test_vec0_updates(): +def test_vec0_inserts(): db = connect(EXT_PATH) db.execute( """ @@ -527,18 +538,18 @@ def test_vec0_updates(): "ccc": bitmap_full(128), } ] - db.execute( - "update t set aaa = ? where rowid = ?", - [np.full((128,), 0.00011, dtype="float32"), 1], - ) - assert execute_all(db, "select * from t") == [ - { - "rowid": 1, - "aaa": _f32([0.00011] * 128), - "bbb": _int8([4] * 128), - "ccc": bitmap_full(128), - } - ] + #db.execute( + # "update t set aaa = ? where rowid = ?", + # [np.full((128,), 0.00011, dtype="float32"), 1], + #) + #assert execute_all(db, "select * from t") == [ + # { + # "rowid": 1, + # "aaa": _f32([0.00011] * 128), + # "bbb": _int8([4] * 128), + # "ccc": bitmap_full(128), + # } + #] db.execute("create virtual table t1 using vec0(aaa float[4], chunk_size=8)") db.execute( @@ -688,7 +699,7 @@ def test_vec0_updates(): db.execute("insert into txt_pk(txt_id, aaa) values ('b', '[2,2,2,2]')") -def test_vec0_update_insert_errors2(): +def test_vec0_insert_errors2(): db = connect(EXT_PATH) db.execute("create virtual table t1 using vec0(aaa float[4], chunk_size=8)") db.execute( @@ -772,7 +783,140 @@ def test_vec0_drops(): ] == [] -def test_vec0_update_deletes(): +def test_vec0_delete(): + db = connect(EXT_PATH) + db.execute("create virtual table t1 using vec0(aaa float[4], chunk_size=8)") + db.execute( + """ + insert into t1(aaa) values + ('[1,1,1,1]'), + ('[2,1,1,1]'), + ('[3,1,1,1]'), + ('[4,1,1,1]'), + ('[5,1,1,1]'), + ('[6,1,1,1]') + """ + ) + assert execute_all(db, "select * from t1_rowids") == [ + { + "chunk_id": 1, + "chunk_offset": 0, + "id": None, + "rowid": 1, + }, + { + "chunk_id": 1, + "chunk_offset": 1, + "id": None, + "rowid": 2, + }, + { + "chunk_id": 1, + "chunk_offset": 2, + "id": None, + "rowid": 3, + }, + { + "chunk_id": 1, + "chunk_offset": 3, + "id": None, + "rowid": 4, + }, + { + "chunk_id": 1, + "chunk_offset": 4, + "id": None, + "rowid": 5, + }, + { + "chunk_id": 1, + "chunk_offset": 5, + "id": None, + "rowid": 6, + }, + ] + assert execute_all(db, "select * from t1_chunks") == [ + { + "chunk_id": 1, + "rowids": _i64([1, 2, 3, 4, 5, 6, 0, 0]), + "size": 8, + "validity": bitmap("00111111"), + } + ] + assert execute_all(db, "select * from t1_vector_chunks00") == [ + { + "rowid": 1, + "vectors": _f32([1, 1, 1, 1]) + + _f32([2, 1, 1, 1]) + + _f32([3, 1, 1, 1]) + + _f32([4, 1, 1, 1]) + + _f32([5, 1, 1, 1]) + + _f32([6, 1, 1, 1]) + + _f32([0, 0, 0, 0]) + + _f32([0, 0, 0, 0]), + } + ] + + db.execute("DELETE FROM t1 WHERE rowid = 1") + assert execute_all(db, "select * from t1_rowids") == [ + { + "chunk_id": 1, + "chunk_offset": 1, + "id": None, + "rowid": 2, + }, + { + "chunk_id": 1, + "chunk_offset": 2, + "id": None, + "rowid": 3, + }, + { + "chunk_id": 1, + "chunk_offset": 3, + "id": None, + "rowid": 4, + }, + { + "chunk_id": 1, + "chunk_offset": 4, + "id": None, + "rowid": 5, + }, + { + "chunk_id": 1, + "chunk_offset": 5, + "id": None, + "rowid": 6, + }, + ] + # TODO finish delete support + # assert execute_all(db, "select * from t1_chunks") == [ + # { + # 'chunk_id': 1, + # 'rowids': _i64([0,2,3,4,5,6,0,0]), + # 'size': 8, + # 'validity': bitmap("00111110"), + # } + # ] + # assert execute_all(db, "select * from t1_vector_chunks00") == [ + # { + # 'rowid': 1, + # 'vectors': _f32([0,0,0,0]) + # +_f32([2,1,1,1]) + # +_f32([3,1,1,1]) + # +_f32([4,1,1,1]) + # +_f32([5,1,1,1]) + # +_f32([6,1,1,1]) + # +_f32([0,0,0,0]) + # +_f32([0,0,0,0]) + # } + # ] + + # TODO test with text primary keys + + +def test_vec0_delete_errors(): db = connect(EXT_PATH) db.execute("create virtual table t1 using vec0(aaa float[4], chunk_size=8)") db.execute( @@ -791,9 +935,36 @@ def test_vec0_update_deletes(): # db.execute("begin") # db.execute("DELETE FROM t1_rowids WHERE rowid = 1") # with _raises("XXX"): - # db.execute("DELETE FROM t1 where rowid = 1") + # db.execute("DELETE FROM t1 where rowid = 1") # db.rollback() + # EVIDENCE-OF: V26002_10073 vec0 DELETE error on reading validity blob + if SUPPORTS_DROP_COLUMN: + db.commit() + db.execute("begin") + db.execute("ALTER TABLE t1_chunks DROP COLUMN validity") + with _raises("could not open validity blob for main.t1_chunks.1"): + db.execute("delete from t1 where rowid = 1") + db.rollback() + + # EVIDENCE-OF: V21193_05263 vec0 DELETE verifies that the validity bit is 1 before clearing + db.commit() + db.execute("begin") + db.execute("UPDATE t1_chunks SET validity = zeroblob(1)") + with _raises( + "vec0 deletion error: validity bit is not set for main.t1_chunks.1 at 0" + ): + db.execute("delete from t1 where rowid = 1") + db.rollback() + + # EVIDENCE-OF: V21193_05263 vec0 DELETE raises error on validity blob error + db.commit() + db.execute("begin") + db.execute("UPDATE t1_chunks SET validity = zeroblob(0)") + with _raises("could not read validity blob for main.t1_chunks.1 at 0"): + db.execute("delete from t1 where rowid = 1") + db.rollback() + if False: # TODO with _raises("XXX"): db.execute("DELETE FROM t1 WHERE rowid = 999") @@ -806,6 +977,158 @@ def test_vec0_update_deletes(): db.rollback() +def test_vec0_updates(): + db = connect(EXT_PATH) + db.execute( + """ + create virtual table t3 using vec0( + aaa float[8], + bbb int8[8], + ccc bit[8] + ); + """ + ) + db.execute( + """ + INSERT INTO t3 VALUES + (1, :x, vec_quantize_i8(:x, 'unit') ,vec_quantize_binary(:x)), + (2, :y, vec_quantize_i8(:y, 'unit') ,vec_quantize_binary(:y)), + (3, :z, vec_quantize_i8(:z, 'unit') ,vec_quantize_binary(:z)); + """, + { + "x": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + "y": "[-.2, .2, .2, .2, .2, .2, -.2, .2]", + "z": "[.3, .3, .3, .3, .3, .3, .3, .3]" + }, + ) + assert execute_all(db, "select * from t3") == [ + { + "rowid": 1, + "aaa": _f32([0.1, 0.1, 0.1, 0.1, -0.1, -0.1, -0.1, -0.1]), + "bbb": _int8([12, 12, 12, 12, -13, -13, -13, -13]), + "ccc": bitmap("00001111"), + }, + { + "rowid": 2, + "aaa": _f32([-0.2, 0.2, 0.2, 0.2, 0.2, 0.2, -0.2, 0.2]), + "bbb": _int8([-26, 24, 24, 24, 24, 24, -26, 24]), + "ccc": bitmap("10111110"), + }, + { + "rowid": 3, + "aaa": _f32([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]), + "bbb": _int8([37, 37, 37, 37, 37, 37, 37, 37, ]), + "ccc": bitmap("11111111"), + }, + ] + + db.execute("UPDATE t3 SET aaa = ? WHERE rowid = 1", ['[.9,.9,.9,.9,.9,.9,.9,.9]']) + assert execute_all(db, "select * from t3") == [ + { + "rowid": 1, + "aaa": _f32([0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]), + "bbb": _int8([12, 12, 12, 12, -13, -13, -13, -13]), + "ccc": bitmap("00001111"), + }, + { + "rowid": 2, + "aaa": _f32([-0.2, 0.2, 0.2, 0.2, 0.2, 0.2, -0.2, 0.2]), + "bbb": _int8([-26, 24, 24, 24, 24, 24, -26, 24]), + "ccc": bitmap("10111110"), + }, + { + "rowid": 3, + "aaa": _f32([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]), + "bbb": _int8([37, 37, 37, 37, 37, 37, 37, 37, ]), + "ccc": bitmap("11111111"), + }, + ] + + # EVIDENCE-OF: V15203_32042 vec0 UPDATE validates vector + with _raises('Updated vector for the "aaa" column is invalid: invalid float32 vector BLOB length. Must be divisible by 4, found 1'): + db.execute("UPDATE t3 SET aaa = X'AB' WHERE rowid = 1") + + # EVIDENCE-OF: V25739_09810 vec0 UPDATE validates dimension length + with _raises('Dimension mismatch for new updated vector for the "aaa" column. Expected 8 dimensions but received 1.'): + db.execute("UPDATE t3 SET aaa = vec_bit(X'AABBCCDD') WHERE rowid = 1") + + # EVIDENCE-OF: V03643_20481 vec0 UPDATE validates vector column type + with _raises('Updated vector for the "bbb" column is expected to be of type int8, but a float32 vector was provided.'): + db.execute("UPDATE t3 SET bbb = X'ABABABAB' WHERE rowid = 1") + + db.execute("CREATE VIRTUAL TABLE t2 USING vec0(a float[2], b float[2])") + db.execute("INSERT INTO t2(rowid, a, b) VALUES (1, '[.1, .1]', '[.2, .2]')") + assert execute_all(db, "select * from t2") == [{ + 'rowid': 1, + 'a': _f32([.1, .1]), + 'b': _f32([.2, .2]), + }] + # sanity check: the 1st column UPDATE "works", but since the 2nd one fails, + # then aaa should remain unchanged. + with _raises('Dimension mismatch for new updated vector for the "b" column. Expected 2 dimensions but received 3.'): + db.execute("UPDATE t2 SET a = '[.11, .11]', b = '[.22, .22, .22]' WHERE rowid = 1") + assert execute_all(db, "select * from t2") == [{ + 'rowid': 1, + 'a': _f32([.1, .1]), + 'b': _f32([.2, .2]), + }] + # TODO: set UPDATEs on int8/bit columns + + # db.execute("UPDATE t3 SET ccc = vec_bit(?) WHERE rowid = 3", [bitmap('01010101')]) + # assert execute_all(db, "select * from t3") == [ + # { + # "rowid": 1, + # "aaa": _f32([0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]), + # "bbb": _int8([12, 12, 12, 12, -13, -13, -13, -13]), + # "ccc": bitmap("00001111"), + # }, + # { + # "rowid": 2, + # "aaa": _f32([-0.2, 0.2, 0.2, 0.2, 0.2, 0.2, -0.2, 0.2]), + # "bbb": _int8([-26, 24, 24, 24, 24, 24, -26, 24]), + # "ccc": bitmap("10111110"), + # }, + # { + # "rowid": 3, + # "aaa": _f32([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]), + # "bbb": _int8([37, 37, 37, 37, 37, 37, 37, 37, ]), + # "ccc": bitmap("11111111"), + # }, + # ] + + +def test_vec0_text_pk(): + db = connect(EXT_PATH) + db.execute( + """ + create virtual table t using vec0( + t_id text primary key, + aaa float[8], + bbb float8[8] + ); + """ + ) + db.executemany("INSERT INTO t VALUES (:t_id, :aaa, :bbb)", + [ + { + "t_id": "t_1", + "aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + "bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + }, + { + "t_id": "t_2", + "aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + "bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + }, + { + "t_id": "t_3", + "aaa": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + "bbb": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", + }, + ], + ) + assert execute_all(db, "select * from t") == [] + def authorizer_deny_on(operation, x1, x2=None): def _auth(op, p1, p2, p3, p4): if op == operation and p1 == x1 and p2 == x2: @@ -879,6 +1202,8 @@ def test_vec_npy_each(): }, ] + assert vec_npy_each(to_npy(np.array([], dtype=np.float32))) == [] + def test_vec_npy_each_errors(): vec_npy_each = lambda *args: execute_all( @@ -921,12 +1246,144 @@ def test_vec_npy_each_errors(): b"\x93NUMPY\x01\x00v\x00{'descr' False \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@" ) + with _raises("expected a string value after 'descr' key"): + vec_npy_each( + b"\x93NUMPY\x01\x00v\x00{'descr': \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@" + ) + + with _raises("Only '