diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 96da148..a39b78a 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -4,6 +4,7 @@ on: branches: - main - partition-by + - auxiliary permissions: contents: read jobs: diff --git a/TODO b/TODO index a487ddd..3914aba 100644 --- a/TODO +++ b/TODO @@ -3,3 +3,11 @@ - [ ] UPDATE on partition key values - remove previous row from chunk, insert into new one? - [ ] properly sqlite3_vtab_nochange / sqlite3_value_nochange handling + +# auxiliary columns + +- later: + - NOT NULL? + - perf: INSERT stmt should be cached on vec0_vtab + - perf: LEFT JOIN aux table to rowids query in vec0_cursor for rowid/point + stmts, to avoid N lookup queries diff --git a/sqlite-vec.c b/sqlite-vec.c index caa992e..062381a 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -1800,6 +1800,7 @@ enum Vec0TokenType { TOKEN_TYPE_DIGIT, TOKEN_TYPE_LBRACKET, TOKEN_TYPE_RBRACKET, + TOKEN_TYPE_PLUS, TOKEN_TYPE_EQ, }; struct Vec0Token { @@ -1827,6 +1828,12 @@ int vec0_token_next(char *start, char *end, struct Vec0Token *out) { if (is_whitespace(curr)) { ptr++; continue; + } else if (curr == '+') { + ptr++; + out->start = ptr; + out->end = ptr; + out->token_type = TOKEN_TYPE_PLUS; + return VEC0_TOKEN_RESULT_SOME; } else if (curr == '[') { ptr++; out->start = ptr; @@ -2011,6 +2018,76 @@ int vec0_parse_partition_key_definition(const char *source, int source_length, return SQLITE_OK; } +/** + * @brief Parse an argv[i] entry of a vec0 virtual table definition, and see if + * it's an auxiliar column definition, ie `+[name] [type]` like `+contents text` + * + * @param source: argv[i] source string + * @param source_length: length of the source string + * @param out_column_name: If it is a partition key, the output column name. Same lifetime + * as source, points to specific char * + * @param out_column_name_length: Length of out_column_name in bytes + * @param out_column_type: SQLITE_TEXT, SQLITE_INTEGER, SQLITE_FLOAT, or SQLITE_BLOB. + * @return int: SQLITE_EMPTY if not an aux column, SQLITE_OK if it is. + */ +int vec0_parse_auxiliary_column_definition(const char *source, int source_length, + char **out_column_name, + int *out_column_name_length, + int *out_column_type) { + struct Vec0Scanner scanner; + struct Vec0Token token; + char *column_name; + int column_name_length; + int column_type; + vec0_scanner_init(&scanner, source, source_length); + + // Check first token is '+', which denotes aux columns + int rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_PLUS) { + return SQLITE_EMPTY; + } + + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_EMPTY; + } + + column_name = token.start; + column_name_length = token.end - token.start; + + // Check the next token matches "text" or "integer", as column type + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME && + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_EMPTY; + } + if (sqlite3_strnicmp(token.start, "text", token.end - token.start) == 0) { + column_type = SQLITE_TEXT; + } else if (sqlite3_strnicmp(token.start, "int", token.end - token.start) == + 0 || + sqlite3_strnicmp(token.start, "integer", + token.end - token.start) == 0) { + column_type = SQLITE_INTEGER; + } else if (sqlite3_strnicmp(token.start, "float", token.end - token.start) == + 0 || + sqlite3_strnicmp(token.start, "double", + token.end - token.start) == 0) { + column_type = SQLITE_FLOAT; + } else if (sqlite3_strnicmp(token.start, "blob", token.end - token.start) ==0) { + column_type = SQLITE_BLOB; + } else { + return SQLITE_EMPTY; + } + + *out_column_name = column_name; + *out_column_name_length = column_name_length; + *out_column_type = column_type; + + return SQLITE_OK; +} + /** * @brief Parse an argv[i] entry of a vec0 virtual table definition, and see if * it's a PRIMARY KEY definition. @@ -2108,6 +2185,12 @@ struct Vec0PartitionColumnDefinition { int name_length; }; +struct Vec0AuxiliaryColumnDefinition { + int type; + char * name; + int name_length; +}; + size_t vector_byte_size(enum VectorElementType element_type, size_t dimensions) { switch (element_type) { @@ -3260,6 +3343,8 @@ static sqlite3_module vec_npy_eachModule = { "vectors BLOB NOT NULL" \ ");" +#define VEC0_SHADOW_AUXILIARY_NAME "\"%w\".\"%w_auxiliary\"" + #define VEC_INTERAL_ERROR "Internal sqlite-vec error: " #define REPORT_URL "https://github.com/asg017/sqlite-vec/issues/new" @@ -3267,6 +3352,8 @@ typedef struct vec0_vtab vec0_vtab; #define VEC0_MAX_VECTOR_COLUMNS 16 #define VEC0_MAX_PARTITION_COLUMNS 4 +#define VEC0_MAX_AUXILIARY_COLUMNS 16 + #define SQLITE_VEC_VEC0_MAX_DIMENSIONS 8192 typedef enum { @@ -3276,7 +3363,10 @@ typedef enum { // partition key column, ie "user_id integer partition key" SQLITE_VEC0_USER_COLUMN_KIND_PARTITION = 2, - // TODO: metadata + metadata filters + // + SQLITE_VEC0_USER_COLUMN_KIND_AUXILIARY = 3, + + // TODO: metadata filters } vec0_user_column_kind; struct vec0_vtab { @@ -3295,6 +3385,9 @@ struct vec0_vtab { // number of defined PARTITION KEY columns. int numPartitionColumns; + // number of defined auxiliary columns + int numAuxiliaryColumns; + // Name of the schema the table exists on. // Must be freed with sqlite3_free() @@ -3314,10 +3407,9 @@ struct vec0_vtab { // contains enum vec0_user_column_kind values for up to // numVectorColumns + numPartitionColumns entries - uint8_t user_column_kinds[VEC0_MAX_VECTOR_COLUMNS + VEC0_MAX_PARTITION_COLUMNS]; - - uint8_t user_column_idxs[VEC0_MAX_VECTOR_COLUMNS + VEC0_MAX_PARTITION_COLUMNS]; + vec0_user_column_kind user_column_kinds[VEC0_MAX_VECTOR_COLUMNS + VEC0_MAX_PARTITION_COLUMNS + VEC0_MAX_AUXILIARY_COLUMNS]; + uint8_t user_column_idxs[VEC0_MAX_VECTOR_COLUMNS + VEC0_MAX_PARTITION_COLUMNS + VEC0_MAX_AUXILIARY_COLUMNS]; // Name of all the vector chunk shadow tables. // Ex '_vector_chunks00' @@ -3327,6 +3419,7 @@ struct vec0_vtab { struct VectorColumnDefinition vector_columns[VEC0_MAX_VECTOR_COLUMNS]; struct Vec0PartitionColumnDefinition paritition_columns[VEC0_MAX_PARTITION_COLUMNS]; + struct Vec0AuxiliaryColumnDefinition auxiliary_columns[VEC0_MAX_AUXILIARY_COLUMNS]; int chunk_size; @@ -3432,7 +3525,7 @@ void vec0_free(vec0_vtab *p) { } int vec0_num_defined_user_columns(vec0_vtab *p) { - return p->numVectorColumns + p->numPartitionColumns; + return p->numVectorColumns + p->numPartitionColumns + p->numAuxiliaryColumns; } /** @@ -3495,6 +3588,25 @@ int vec0_column_idx_to_partition_idx(vec0_vtab *pVtab, int column_idx) { return pVtab->user_column_idxs[column_idx - VEC0_COLUMN_USERN_START]; } +/** + * Returns 1 if the given column-based index is a auxiliary column, + * 0 otherwise. + */ +int vec0_column_idx_is_auxiliary(vec0_vtab *pVtab, int column_idx) { + return column_idx >= VEC0_COLUMN_USERN_START && + column_idx <= (VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(pVtab) - 1) && + pVtab->user_column_kinds[column_idx - VEC0_COLUMN_USERN_START] == SQLITE_VEC0_USER_COLUMN_KIND_AUXILIARY; +} + +/** + * Returns the auxiliary column index of the given user column index. + * ONLY call if validated with vec0_column_idx_to_partition_idx before + */ +int vec0_column_idx_to_auxiliary_idx(vec0_vtab *pVtab, int column_idx) { + UNUSED_PARAMETER(pVtab); + return pVtab->user_column_idxs[column_idx - VEC0_COLUMN_USERN_START]; +} + /** * @brief Retrieve the chunk_id, chunk_offset, and possible "id" value * of a vec0_vtab row with the provided rowid @@ -3771,6 +3883,45 @@ int vec0_get_partition_value_for_rowid(vec0_vtab *pVtab, i64 rowid, int partitio } +/** + * @brief Get the value of an auxiliary column for the given rowid + * + * @param pVtab vec0_vtab + * @param rowid the rowid of the row to lookup + * @param auxiliary_idx aux index of the column we care about + * @param outValue Output sqlite3_value to store + * @return int SQLITE_OK on success, error code otherwise + */ +int vec0_get_auxiliary_value_for_rowid(vec0_vtab *pVtab, i64 rowid, int auxiliary_idx, sqlite3_value ** outValue) { + int rc; + sqlite3_stmt * stmt = NULL; + char * zSql = sqlite3_mprintf("SELECT value%02d FROM " VEC0_SHADOW_AUXILIARY_NAME " WHERE rowid = ?", auxiliary_idx, pVtab->schemaName, pVtab->tableName); + if(!zSql) { + return SQLITE_NOMEM; + } + rc = sqlite3_prepare_v2(pVtab->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if(rc != SQLITE_OK) { + return rc; + } + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + if(rc != SQLITE_ROW) { + rc = SQLITE_ERROR; + goto done; + } + *outValue = sqlite3_value_dup(sqlite3_column_value(stmt, 0)); + if(!*outValue) { + rc = SQLITE_NOMEM; + goto done; + } + rc = SQLITE_OK; + + done: + sqlite3_finalize(stmt); + return rc; +} + int vec0_get_latest_chunk_rowid(vec0_vtab *p, i64 *chunk_rowid, sqlite3_value ** partitionKeyValues) { int rc; const char *zSql; @@ -4247,6 +4398,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, int chunk_size = -1; int numVectorColumns = 0; int numPartitionColumns = 0; + int numAuxiliaryColumns = 0; int user_column_idx = 0; // track if a "primary key" column is defined @@ -4257,6 +4409,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, for (int i = 3; i < argc; i++) { struct VectorColumnDefinition vecColumn; struct Vec0PartitionColumnDefinition partitionColumn; + struct Vec0AuxiliaryColumnDefinition auxColumn; char *cName = NULL; int cNameLength; int cType; @@ -4339,6 +4492,33 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, continue; } + // Scenario #4: Constructor argument is a auxiliary column definition, ie `+contents text` + rc = vec0_parse_auxiliary_column_definition(argv[i], strlen(argv[i]), &cName, + &cNameLength, &cType); + if(rc == SQLITE_OK) { + if (numAuxiliaryColumns >= VEC0_MAX_AUXILIARY_COLUMNS) { + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "More than %d auxiliary columns were provided", + VEC0_MAX_AUXILIARY_COLUMNS); + goto error; + } + auxColumn.type = cType; + auxColumn.name_length = cNameLength; + auxColumn.name = sqlite3_mprintf("%.*s", cNameLength, cName); + if(!auxColumn.name) { + rc = SQLITE_NOMEM; + goto error; + } + + pNew->user_column_kinds[user_column_idx] = SQLITE_VEC0_USER_COLUMN_KIND_AUXILIARY; + pNew->user_column_idxs[user_column_idx] = numAuxiliaryColumns; + memcpy(&pNew->auxiliary_columns[numAuxiliaryColumns], &auxColumn, sizeof(auxColumn)); + numAuxiliaryColumns++; + user_column_idx++; + continue; + } + // Scenario #4: Constructor argument is a table-level option, ie `chunk_size` char *key; @@ -4406,7 +4586,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } else { sqlite3_str_appendall(createStr, "rowid, "); } - for (int i = 0; i < numVectorColumns + numPartitionColumns; i++) { + for (int i = 0; i < numVectorColumns + numPartitionColumns + numAuxiliaryColumns; i++) { switch(pNew->user_column_kinds[i]) { case SQLITE_VEC0_USER_COLUMN_KIND_VECTOR: { int vector_idx = pNew->user_column_idxs[i]; @@ -4422,6 +4602,13 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, pNew->paritition_columns[partition_idx].name); break; } + case SQLITE_VEC0_USER_COLUMN_KIND_AUXILIARY: { + int auxiliary_idx = pNew->user_column_idxs[i]; + sqlite3_str_appendf(createStr, "\"%.*w\", ", + pNew->auxiliary_columns[auxiliary_idx].name_length, + pNew->auxiliary_columns[auxiliary_idx].name); + break; + } } } @@ -4465,6 +4652,8 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } pNew->numVectorColumns = numVectorColumns; pNew->numPartitionColumns = numPartitionColumns; + pNew->numAuxiliaryColumns = numAuxiliaryColumns; + for (int i = 0; i < pNew->numVectorColumns; i++) { pNew->shadowVectorChunksNames[i] = sqlite3_mprintf("%s_vector_chunks%02d", tableName, i); @@ -4551,6 +4740,30 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } sqlite3_finalize(stmt); } + + if(pNew->numAuxiliaryColumns > 0) { + sqlite3_stmt * stmt; + sqlite3_str * s = sqlite3_str_new(NULL); + sqlite3_str_appendf(s, "CREATE TABLE " VEC0_SHADOW_AUXILIARY_NAME "( rowid integer PRIMARY KEY ", pNew->schemaName, pNew->tableName); + for(int i = 0; i < pNew->numAuxiliaryColumns; i++) { + sqlite3_str_appendf(s, ", value%02d", i); + } + sqlite3_str_appendall(s, ")"); + char *zSql = sqlite3_str_finish(s); + if(!zSql) { + goto error; + } + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + sqlite3_finalize(stmt); + *pzErr = sqlite3_mprintf( + "Could not create auxiliary shadow table: %s", + sqlite3_errmsg(db)); + + goto error; + } + sqlite3_finalize(stmt); + } } *ppVtab = (sqlite3_vtab *)pNew; @@ -4621,6 +4834,18 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { } sqlite3_finalize(stmt); } + + if(p->numAuxiliaryColumns > 0) { + zSql = sqlite3_mprintf("DROP TABLE " VEC0_SHADOW_AUXILIARY_NAME, p->schemaName, p->tableName); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free((void *)zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + rc = SQLITE_ERROR; + goto done; + } + sqlite3_finalize(stmt); + } + stmt = NULL; rc = SQLITE_OK; @@ -4697,6 +4922,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { int iRowidTerm = -1; int iKTerm = -1; int iRowidInTerm = -1; + int hasAuxConstraint = 0; #ifdef SQLITE_VEC_DEBUG printf("pIdxInfo->nOrderBy=%d, pIdxInfo->nConstraint=%d\n", pIdxInfo->nOrderBy, pIdxInfo->nConstraint); @@ -4751,6 +4977,11 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_k_idx(p)) { iKTerm = i; } + if( + (op != SQLITE_INDEX_CONSTRAINT_LIMIT && op != SQLITE_INDEX_CONSTRAINT_OFFSET) + && vec0_column_idx_is_auxiliary(p, iColumn)) { + hasAuxConstraint = 1; + } } sqlite3_str *idxStr = sqlite3_str_new(NULL); @@ -4793,6 +5024,13 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } } + if(hasAuxConstraint) { + // IMP: V25623_09693 + vtab_set_error(pVTab, "An illegal WHERE constraint was provided on a vec0 auxiliary column in a KNN query."); + rc = SQLITE_ERROR; + goto done; + } + sqlite3_str_appendchar(idxStr, 1, VEC0_QUERY_PLAN_KNN); int argvIndex = 1; @@ -5865,6 +6103,18 @@ static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur, int rc = vec0_get_partition_value_for_rowid(pVtab, rowid, partition_idx, &v); if(rc == SQLITE_OK) { sqlite3_result_value(context, v); + sqlite3_value_free(v); + }else { + sqlite3_result_error_code(context, rc); + } + } + else if(vec0_column_idx_is_auxiliary(pVtab, i)) { + int auxiliary_idx = vec0_column_idx_to_auxiliary_idx(pVtab, i); + sqlite3_value * v; + int rc = vec0_get_auxiliary_value_for_rowid(pVtab, rowid, auxiliary_idx, &v); + if(rc == SQLITE_OK) { + sqlite3_result_value(context, v); + sqlite3_value_free(v); }else { sqlite3_result_error_code(context, rc); } @@ -5910,6 +6160,22 @@ static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur, int rc = vec0_get_partition_value_for_rowid(pVtab, rowid, partition_idx, &v); if(rc == SQLITE_OK) { sqlite3_result_value(context, v); + sqlite3_value_free(v); + }else { + sqlite3_result_error_code(context, rc); + } + } + else if(vec0_column_idx_is_auxiliary(pVtab, i)) { + if(sqlite3_vtab_nochange(context)) { + return SQLITE_OK; + } + i64 rowid = pCur->point_data->rowid; + int auxiliary_idx = vec0_column_idx_to_auxiliary_idx(pVtab, i); + sqlite3_value * v; + int rc = vec0_get_auxiliary_value_for_rowid(pVtab, rowid, auxiliary_idx, &v); + if(rc == SQLITE_OK) { + sqlite3_result_value(context, v); + sqlite3_value_free(v); }else { sqlite3_result_error_code(context, rc); } @@ -5956,6 +6222,19 @@ static int vec0Column_knn(vec0_vtab *pVtab, vec0_cursor *pCur, int rc = vec0_get_partition_value_for_rowid(pVtab, rowid, partition_idx, &v); if(rc == SQLITE_OK) { sqlite3_result_value(context, v); + sqlite3_value_free(v); + }else { + sqlite3_result_error_code(context, rc); + } + } + else if(vec0_column_idx_is_auxiliary(pVtab, i)) { + int auxiliary_idx = vec0_column_idx_to_auxiliary_idx(pVtab, i); + i64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx]; + sqlite3_value * v; + int rc = vec0_get_auxiliary_value_for_rowid(pVtab, rowid, auxiliary_idx, &v); + if(rc == SQLITE_OK) { + sqlite3_result_value(context, v); + sqlite3_value_free(v); }else { sqlite3_result_error_code(context, rc); } @@ -6434,6 +6713,67 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, } } + if(p->numAuxiliaryColumns > 0) { + sqlite3_stmt *stmt; + sqlite3_str * s = sqlite3_str_new(NULL); + sqlite3_str_appendf(s, "INSERT INTO " VEC0_SHADOW_AUXILIARY_NAME "(", p->schemaName, p->tableName); + for(int i = 0; i < p->numAuxiliaryColumns; i++) { + if(i!=0) { + sqlite3_str_appendchar(s, 1, ','); + } + sqlite3_str_appendf(s, "value%02d", i); + } + sqlite3_str_appendall(s, ") VALUES ("); + for(int i = 0; i < p->numAuxiliaryColumns; i++) { + if(i!=0) { + sqlite3_str_appendchar(s, 1, ','); + } + sqlite3_str_appendchar(s, 1, '?'); + } + sqlite3_str_appendall(s, ")"); + char * zSql = sqlite3_str_finish(s); + // TODO double check error handling ehre + if(!zSql) { + rc = SQLITE_NOMEM; + goto cleanup; + } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + if(rc != SQLITE_OK) { + goto cleanup; + } + + for (int i = 0; i < vec0_num_defined_user_columns(p); i++) { + if(p->user_column_kinds[i] != SQLITE_VEC0_USER_COLUMN_KIND_AUXILIARY) { + continue; + } + int auxiliary_key_idx = p->user_column_idxs[i]; + sqlite3_value * v = argv[2+VEC0_COLUMN_USERN_START + i]; + int v_type = sqlite3_value_type(v); + if(v_type != SQLITE_NULL && (v_type != p->auxiliary_columns[auxiliary_key_idx].type)) { + sqlite3_finalize(stmt); + rc = SQLITE_ERROR; + vtab_set_error( + pVTab, + "Auxiliary column type mismatch: The auxiliary column %.*s has type %s, but %s was provided.", + p->auxiliary_columns[auxiliary_key_idx].name_length, + p->auxiliary_columns[auxiliary_key_idx].name, + type_name(p->auxiliary_columns[auxiliary_key_idx].type), + type_name(v_type) + ); + goto cleanup; + } + sqlite3_bind_value(stmt, 1 + auxiliary_key_idx, v); + } + + rc = sqlite3_step(stmt); + if(rc != SQLITE_DONE) { + sqlite3_finalize(stmt); + rc = SQLITE_ERROR; + goto cleanup; + } + sqlite3_finalize(stmt); + } + // read all the inserted vectors into vectorDatas, validate their lengths. for (int i = 0; i < vec0_num_defined_user_columns(p); i++) { if(p->user_column_kinds[i] != SQLITE_VEC0_USER_COLUMN_KIND_VECTOR) { @@ -6634,6 +6974,34 @@ cleanup: return rc; } +int vec0Update_Delete_DeleteAux(vec0_vtab *p, i64 rowid) { + int rc; + sqlite3_stmt *stmt = NULL; + + char *zSql = + sqlite3_mprintf("DELETE FROM " VEC0_SHADOW_AUXILIARY_NAME " WHERE rowid = ?", + p->schemaName, p->tableName); + if (!zSql) { + return SQLITE_NOMEM; + } + + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + goto cleanup; + } + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + if (rc != SQLITE_DONE) { + goto cleanup; + } + rc = SQLITE_OK; + +cleanup: + sqlite3_finalize(stmt); + return rc; +} + int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) { vec0_vtab *p = (vec0_vtab *)pVTab; int rc; @@ -6679,6 +7047,36 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) { return rc; } + // 6. delete any auxiliary rows + if(p->numAuxiliaryColumns > 0) { + rc = vec0Update_Delete_DeleteAux(p, rowid); + if (rc != SQLITE_OK) { + return rc; + } + } + + return SQLITE_OK; +} + +int vec0Update_UpdateAuxColumn(vec0_vtab *p, int auxiliary_column_idx, sqlite3_value * value, i64 rowid) { + int rc; + sqlite3_stmt *stmt; + const char * zSql = sqlite3_mprintf("UPDATE " VEC0_SHADOW_AUXILIARY_NAME " SET value%02d = ? WHERE rowid = ?", p->schemaName, p->tableName, auxiliary_column_idx); + if(!zSql) { + return SQLITE_NOMEM; + } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + if(rc != SQLITE_OK) { + return rc; + } + sqlite3_bind_value(stmt, 1, value); + sqlite3_bind_int64(stmt, 2, rowid); + rc = sqlite3_step(stmt); + if(rc != SQLITE_DONE) { + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); return SQLITE_OK; } @@ -6806,7 +7204,23 @@ int vec0Update_Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv) { return SQLITE_ERROR; } - // 3) iterate over all new vectors, update the vectors + // 3) handle auxiliary column updates + for (int i = 0; i < vec0_num_defined_user_columns(p); i++) { + if(p->user_column_kinds[i] != SQLITE_VEC0_USER_COLUMN_KIND_AUXILIARY) { + continue; + } + int auxiliary_column_idx = p->user_column_idxs[i]; + sqlite3_value * value = argv[2+VEC0_COLUMN_USERN_START + i]; + if(sqlite3_value_nochange(value)) { + continue; + } + rc = vec0Update_UpdateAuxColumn(p, auxiliary_column_idx, value, rowid); + if(rc != SQLITE_OK) { + return SQLITE_ERROR; + } + } + + // 4) iterate over all new vectors, update the vectors for (int i = 0; i < vec0_num_defined_user_columns(p); i++) { if(p->user_column_kinds[i] != SQLITE_VEC0_USER_COLUMN_KIND_VECTOR) { continue; @@ -6857,7 +7271,7 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, } static int vec0ShadowName(const char *zName) { - static const char *azName[] = {"rowids", "chunks", "vector_chunks"}; + static const char *azName[] = {"rowids", "chunks", "auxiliary", "vector_chunks"}; for (size_t i = 0; i < sizeof(azName) / sizeof(azName[0]); i++) { if (sqlite3_stricmp(zName, azName[i]) == 0) diff --git a/test.sql b/test.sql index 7434207..e9a64a8 100644 --- a/test.sql +++ b/test.sql @@ -1,9 +1,30 @@ + .load dist/vec0 .echo on .bail on .mode qbox +create virtual table vec_chunks using vec0( + chunk_id integer primary key, + contents_embedding float[1], + +contents text +); +insert into vec_chunks(chunk_id, contents_embedding, contents) values + (1, '[1]', 'alex'), + (2, '[2]', 'brian'), + (3, '[3]', 'craig'), + (4, '[4]', 'dylan'); + +select * from vec_chunks; + +select chunk_id, contents, distance +from vec_chunks +where contents_embedding match '[5]' +and k = 3; + +.exit + create virtual table v using vec0(a float[1]); select count(*) from v_chunks; insert into v(a) values ('[1.11]'); diff --git a/tests/__snapshots__/test-auxiliary.ambr b/tests/__snapshots__/test-auxiliary.ambr new file mode 100644 index 0000000..eb84f0f --- /dev/null +++ b/tests/__snapshots__/test-auxiliary.ambr @@ -0,0 +1,642 @@ +# serializer version: 1 +# name: test_constructor_limit[max 16 auxiliary columns] + dict({ + 'error': 'OperationalError', + 'message': 'vec0 constructor error: More than 16 auxiliary columns were provided', + }) +# --- +# name: test_deletes + OrderedDict({ + 'sql': 'select rowid, * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x00\x00\x80?', + 'name': 'alex', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x00\x00\x00@', + 'name': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x00\x00@@', + 'name': 'craig', + }), + ]), + }) +# --- +# name: test_deletes.1 + dict({ + 'v_auxiliary': OrderedDict({ + 'sql': 'select * from v_auxiliary', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'value00': 'alex', + }), + OrderedDict({ + 'rowid': 2, + 'value00': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'value00': 'craig', + }), + ]), + }), + 'v_chunks': OrderedDict({ + 'sql': 'select * from v_chunks', + 'rows': list([ + OrderedDict({ + 'chunk_id': 1, + 'size': 8, + 'validity': b'\x07', + 'rowids': b'\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + 'v_rowids': OrderedDict({ + 'sql': 'select * from v_rowids', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 0, + }), + OrderedDict({ + 'rowid': 2, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 1, + }), + OrderedDict({ + 'rowid': 3, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 2, + }), + ]), + }), + 'v_vector_chunks00': OrderedDict({ + 'sql': 'select * from v_vector_chunks00', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vectors': b'\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + }) +# --- +# name: test_deletes.2 + OrderedDict({ + 'sql': 'delete from v where rowid = 1', + 'rows': list([ + ]), + }) +# --- +# name: test_deletes.3 + OrderedDict({ + 'sql': 'select rowid, * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 2, + 'vector': b'\x00\x00\x00@', + 'name': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x00\x00@@', + 'name': 'craig', + }), + ]), + }) +# --- +# name: test_deletes.4 + dict({ + 'v_auxiliary': OrderedDict({ + 'sql': 'select * from v_auxiliary', + 'rows': list([ + OrderedDict({ + 'rowid': 2, + 'value00': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'value00': 'craig', + }), + ]), + }), + 'v_chunks': OrderedDict({ + 'sql': 'select * from v_chunks', + 'rows': list([ + OrderedDict({ + 'chunk_id': 1, + 'size': 8, + 'validity': b'\x06', + 'rowids': b'\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + 'v_rowids': OrderedDict({ + 'sql': 'select * from v_rowids', + 'rows': list([ + OrderedDict({ + 'rowid': 2, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 1, + }), + OrderedDict({ + 'rowid': 3, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 2, + }), + ]), + }), + 'v_vector_chunks00': OrderedDict({ + 'sql': 'select * from v_vector_chunks00', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vectors': b'\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + }) +# --- +# name: test_knn + OrderedDict({ + 'sql': 'select * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x00\x00\x80?', + 'name': 'alex', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x00\x00\x00@', + 'name': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x00\x00@@', + 'name': 'craig', + }), + ]), + }) +# --- +# name: test_knn[illegal KNN w/ aux] + dict({ + 'error': 'OperationalError', + 'message': 'An illegal WHERE constraint was provided on a vec0 auxiliary column in a KNN query.', + }) +# --- +# name: test_knn[legal KNN w/ aux] + OrderedDict({ + 'sql': "select *, distance from v where vector match '[5]' and k = 10", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 'vector': b'\x00\x00@@', + 'name': 'craig', + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x00\x00\x00@', + 'name': 'brian', + 'distance': 3.0, + }), + OrderedDict({ + 'rowid': 1, + 'vector': b'\x00\x00\x80?', + 'name': 'alex', + 'distance': 4.0, + }), + ]), + }) +# --- +# name: test_normal.1 + OrderedDict({ + 'sql': 'select * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'a': b'\x11\x11\x11\x11', + 'name': 'alex', + }), + OrderedDict({ + 'rowid': 2, + 'a': b'""""', + 'name': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'a': b'3333', + 'name': 'craig', + }), + ]), + }) +# --- +# name: test_normal.2 + dict({ + 'v_auxiliary': OrderedDict({ + 'sql': 'select * from v_auxiliary', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'value00': 'alex', + }), + OrderedDict({ + 'rowid': 2, + 'value00': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'value00': 'craig', + }), + ]), + }), + 'v_chunks': OrderedDict({ + 'sql': 'select * from v_chunks', + 'rows': list([ + OrderedDict({ + 'chunk_id': 1, + 'size': 8, + 'validity': b'\x07', + 'rowids': b'\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + 'v_rowids': OrderedDict({ + 'sql': 'select * from v_rowids', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 0, + }), + OrderedDict({ + 'rowid': 2, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 1, + }), + OrderedDict({ + 'rowid': 3, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 2, + }), + ]), + }), + 'v_vector_chunks00': OrderedDict({ + 'sql': 'select * from v_vector_chunks00', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vectors': b'\x11\x11\x11\x11""""3333\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + }) +# --- +# name: test_normal[sqlite_master post drop] + OrderedDict({ + 'sql': 'select * from sqlite_master order by name', + 'rows': list([ + OrderedDict({ + 'type': 'table', + 'name': 'sqlite_sequence', + 'tbl_name': 'sqlite_sequence', + 'rootpage': 3, + 'sql': 'CREATE TABLE sqlite_sequence(name,seq)', + }), + ]), + }) +# --- +# name: test_normal[sqlite_master] + OrderedDict({ + 'sql': 'select * from sqlite_master order by name', + 'rows': list([ + OrderedDict({ + 'type': 'index', + 'name': 'sqlite_autoindex_v_vector_chunks00_1', + 'tbl_name': 'v_vector_chunks00', + 'rootpage': 6, + 'sql': None, + }), + OrderedDict({ + 'type': 'table', + 'name': 'sqlite_sequence', + 'tbl_name': 'sqlite_sequence', + 'rootpage': 3, + 'sql': 'CREATE TABLE sqlite_sequence(name,seq)', + }), + OrderedDict({ + 'type': 'table', + 'name': 'v', + 'tbl_name': 'v', + 'rootpage': 0, + 'sql': 'CREATE VIRTUAL TABLE v using vec0(a float[1], +name text, chunk_size=8)', + }), + OrderedDict({ + 'type': 'table', + 'name': 'v_auxiliary', + 'tbl_name': 'v_auxiliary', + 'rootpage': 7, + 'sql': 'CREATE TABLE "v_auxiliary"( rowid integer PRIMARY KEY , value00)', + }), + OrderedDict({ + 'type': 'table', + 'name': 'v_chunks', + 'tbl_name': 'v_chunks', + 'rootpage': 2, + 'sql': 'CREATE TABLE "v_chunks"(chunk_id INTEGER PRIMARY KEY AUTOINCREMENT,size INTEGER NOT NULL,validity BLOB NOT NULL,rowids BLOB NOT NULL)', + }), + OrderedDict({ + 'type': 'table', + 'name': 'v_rowids', + 'tbl_name': 'v_rowids', + 'rootpage': 4, + 'sql': 'CREATE TABLE "v_rowids"(rowid INTEGER PRIMARY KEY AUTOINCREMENT,id,chunk_id INTEGER,chunk_offset INTEGER)', + }), + OrderedDict({ + 'type': 'table', + 'name': 'v_vector_chunks00', + 'tbl_name': 'v_vector_chunks00', + 'rootpage': 5, + 'sql': 'CREATE TABLE "v_vector_chunks00"(rowid PRIMARY KEY,vectors BLOB NOT NULL)', + }), + ]), + }) +# --- +# name: test_types + OrderedDict({ + 'sql': 'select * from v', + 'rows': list([ + ]), + }) +# --- +# name: test_types.1 + OrderedDict({ + 'sql': 'insert into v(vector, aux_int, aux_float, aux_text, aux_blob) values (?, ?, ?, ?, ?)', + 'rows': list([ + ]), + }) +# --- +# name: test_types.2 + OrderedDict({ + 'sql': 'select * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x11\x11\x11\x11', + 'aux_int': 1, + 'aux_float': 1.22, + 'aux_text': 'text', + 'aux_blob': b'blob', + }), + ]), + }) +# --- +# name: test_types.3 + dict({ + 'error': 'OperationalError', + 'message': 'Auxiliary column type mismatch: The auxiliary column aux_int has type INTEGER, but TEXT was provided.', + }) +# --- +# name: test_types.4 + dict({ + 'error': 'OperationalError', + 'message': 'Auxiliary column type mismatch: The auxiliary column aux_float has type FLOAT, but TEXT was provided.', + }) +# --- +# name: test_types.5 + dict({ + 'error': 'OperationalError', + 'message': 'Auxiliary column type mismatch: The auxiliary column aux_text has type TEXT, but INTEGER was provided.', + }) +# --- +# name: test_types.6 + dict({ + 'error': 'OperationalError', + 'message': 'Auxiliary column type mismatch: The auxiliary column aux_blob has type BLOB, but INTEGER was provided.', + }) +# --- +# name: test_types.7 + OrderedDict({ + 'sql': 'insert into v(vector, aux_int, aux_float, aux_text, aux_blob) values (?, ?, ?, ?, ?)', + 'rows': list([ + ]), + }) +# --- +# name: test_types.8 + OrderedDict({ + 'sql': 'select * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x11\x11\x11\x11', + 'aux_int': 1, + 'aux_float': 1.22, + 'aux_text': 'text', + 'aux_blob': b'blob', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x11\x11\x11\x11', + 'aux_int': None, + 'aux_float': None, + 'aux_text': None, + 'aux_blob': None, + }), + ]), + }) +# --- +# name: test_updates + OrderedDict({ + 'sql': 'select rowid, * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x00\x00\x80?', + 'name': 'alex', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x00\x00\x00@', + 'name': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x00\x00@@', + 'name': 'craig', + }), + ]), + }) +# --- +# name: test_updates.1 + dict({ + 'v_auxiliary': OrderedDict({ + 'sql': 'select * from v_auxiliary', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'value00': 'alex', + }), + OrderedDict({ + 'rowid': 2, + 'value00': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'value00': 'craig', + }), + ]), + }), + 'v_chunks': OrderedDict({ + 'sql': 'select * from v_chunks', + 'rows': list([ + OrderedDict({ + 'chunk_id': 1, + 'size': 8, + 'validity': b'\x07', + 'rowids': b'\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + 'v_rowids': OrderedDict({ + 'sql': 'select * from v_rowids', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 0, + }), + OrderedDict({ + 'rowid': 2, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 1, + }), + OrderedDict({ + 'rowid': 3, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 2, + }), + ]), + }), + 'v_vector_chunks00': OrderedDict({ + 'sql': 'select * from v_vector_chunks00', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vectors': b'\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + }) +# --- +# name: test_updates.2 + OrderedDict({ + 'sql': "update v set name = 'ALEX' where rowid = 1", + 'rows': list([ + ]), + }) +# --- +# name: test_updates.3 + OrderedDict({ + 'sql': 'select rowid, * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x00\x00\x80?', + 'name': 'ALEX', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x00\x00\x00@', + 'name': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'vector': b'\x00\x00@@', + 'name': 'craig', + }), + ]), + }) +# --- +# name: test_updates.4 + dict({ + 'v_auxiliary': OrderedDict({ + 'sql': 'select * from v_auxiliary', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'value00': 'ALEX', + }), + OrderedDict({ + 'rowid': 2, + 'value00': 'brian', + }), + OrderedDict({ + 'rowid': 3, + 'value00': 'craig', + }), + ]), + }), + 'v_chunks': OrderedDict({ + 'sql': 'select * from v_chunks', + 'rows': list([ + OrderedDict({ + 'chunk_id': 1, + 'size': 8, + 'validity': b'\x07', + 'rowids': b'\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + 'v_rowids': OrderedDict({ + 'sql': 'select * from v_rowids', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 0, + }), + OrderedDict({ + 'rowid': 2, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 1, + }), + OrderedDict({ + 'rowid': 3, + 'id': None, + 'chunk_id': 1, + 'chunk_offset': 2, + }), + ]), + }), + 'v_vector_chunks00': OrderedDict({ + 'sql': 'select * from v_vector_chunks00', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vectors': b'\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + }), + ]), + }), + }) +# --- diff --git a/tests/test-auxiliary.py b/tests/test-auxiliary.py new file mode 100644 index 0000000..9f03cbd --- /dev/null +++ b/tests/test-auxiliary.py @@ -0,0 +1,155 @@ +import sqlite3 +from collections import OrderedDict + + +def test_constructor_limit(db, snapshot): + assert exec( + db, + f""" + create virtual table v using vec0( + {",".join([f"+aux{x} integer" for x in range(17)])} + v float[1] + ) + """, + ) == snapshot(name="max 16 auxiliary columns") + + +def test_normal(db, snapshot): + db.execute( + "create virtual table v using vec0(a float[1], +name text, chunk_size=8)" + ) + assert exec(db, "select * from sqlite_master order by name") == snapshot( + name="sqlite_master" + ) + + db.execute("insert into v(a, name) values (?, ?)", [b"\x11\x11\x11\x11", "alex"]) + db.execute("insert into v(a, name) values (?, ?)", [b"\x22\x22\x22\x22", "brian"]) + db.execute("insert into v(a, name) values (?, ?)", [b"\x33\x33\x33\x33", "craig"]) + + assert exec(db, "select * from v") == snapshot() + assert vec0_shadow_table_contents(db, "v") == snapshot() + + db.execute("drop table v;") + assert exec(db, "select * from sqlite_master order by name") == snapshot( + name="sqlite_master post drop" + ) + + +def test_types(db, snapshot): + db.execute( + """ + create virtual table v using vec0( + vector float[1], + +aux_int integer, + +aux_float float, + +aux_text text, + +aux_blob blob + ) + """ + ) + assert exec(db, "select * from v") == snapshot() + INSERT = "insert into v(vector, aux_int, aux_float, aux_text, aux_blob) values (?, ?, ?, ?, ?)" + + assert ( + exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1.22, "text", b"blob"]) == snapshot() + ) + assert exec(db, "select * from v") == snapshot() + + # bad types + assert ( + exec(db, INSERT, [b"\x11\x11\x11\x11", "not int", 1.2, "text", b"blob"]) + == snapshot() + ) + assert ( + exec(db, INSERT, [b"\x11\x11\x11\x11", 1, "not float", "text", b"blob"]) + == snapshot() + ) + assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1.2, 1, b"blob"]) == snapshot() + assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1.2, "text", 1]) == snapshot() + + # NULLs are totally chill + assert exec(db, INSERT, [b"\x11\x11\x11\x11", None, None, None, None]) == snapshot() + assert exec(db, "select * from v") == snapshot() + + +def test_updates(db, snapshot): + db.execute( + "create virtual table v using vec0(vector float[1], +name text, chunk_size=8)" + ) + db.executemany( + "insert into v(vector, name) values (?, ?)", + [("[1]", "alex"), ("[2]", "brian"), ("[3]", "craig")], + ) + assert exec(db, "select rowid, * from v") == snapshot() + assert vec0_shadow_table_contents(db, "v") == snapshot() + + assert exec(db, "update v set name = 'ALEX' where rowid = 1") == snapshot() + assert exec(db, "select rowid, * from v") == snapshot() + assert vec0_shadow_table_contents(db, "v") == snapshot() + + +def test_deletes(db, snapshot): + db.execute( + "create virtual table v using vec0(vector float[1], +name text, chunk_size=8)" + ) + db.executemany( + "insert into v(vector, name) values (?, ?)", + [("[1]", "alex"), ("[2]", "brian"), ("[3]", "craig")], + ) + assert exec(db, "select rowid, * from v") == snapshot() + assert vec0_shadow_table_contents(db, "v") == snapshot() + + assert exec(db, "delete from v where rowid = 1") == snapshot() + assert exec(db, "select rowid, * from v") == snapshot() + assert vec0_shadow_table_contents(db, "v") == snapshot() + + +def test_knn(db, snapshot): + db.execute("create virtual table v using vec0(vector float[1], +name text)") + db.executemany( + "insert into v(vector, name) values (?, ?)", + [("[1]", "alex"), ("[2]", "brian"), ("[3]", "craig")], + ) + assert exec(db, "select * from v") == snapshot() + assert exec( + db, "select *, distance from v where vector match '[5]' and k = 10" + ) == snapshot(name="legal KNN w/ aux") + + # EVIDENCE-OF: V25623_09693 No aux constraint allowed on KNN queries + assert exec( + db, + "select *, distance from v where vector match '[5]' and k = 10 and name = 'alex'", + ) == snapshot(name="illegal KNN w/ aux") + + +def exec(db, sql, parameters=[]): + try: + rows = db.execute(sql, parameters).fetchall() + except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: + return { + "error": e.__class__.__name__, + "message": str(e), + } + a = [] + for row in rows: + o = OrderedDict() + for k in row.keys(): + o[k] = row[k] + a.append(o) + result = OrderedDict() + result["sql"] = sql + result["rows"] = a + return result + + +def vec0_shadow_table_contents(db, v): + shadow_tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where name like ? order by 1", [f"{v}_%"] + ).fetchall() + ] + o = {} + for shadow_table in shadow_tables: + o[shadow_table] = exec(db, f"select * from {shadow_table}") + return o