diff --git a/sqlite-vec.c b/sqlite-vec.c index e7783ae..49d6fcf 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -344,6 +344,14 @@ struct Array { void *z; }; +/** + * @brief Initial an array with the given element size and capacity. + * + * @param array + * @param element_size + * @param init_capacity + * @return SQLITE_OK on success, error code on failure. Only error is SQLITE_NOMEM + */ int array_init(struct Array *array, size_t element_size, size_t init_capacity) { void *z = sqlite3_malloc(element_size * init_capacity); if (!z) { @@ -389,6 +397,7 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector, size_t *dimensions, fvec_cleanup *cleanup, char **pzErr) { int value_type = sqlite3_value_type(value); + if (value_type == SQLITE_BLOB) { const void *blob = sqlite3_value_blob(value); int bytes = sqlite3_value_bytes(value); @@ -415,7 +424,9 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector, struct Array x; int rc = array_init(&x, sizeof(f32), ceil(source_len / 2.0)); - todo_assert(rc == SQLITE_OK); + if(rc != SQLITE_OK) { + return rc; + } // advance leading whitespace to first '[' while (i < source_len) { @@ -924,7 +935,11 @@ static void vec_quantize_binary(sqlite3_context *context, int argc, if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { u8 *out = sqlite3_malloc(dimensions / CHAR_BIT); - todo_assert(out); + if(!out) { + cleanup(vector); + sqlite3_result_error_code(context, SQLITE_NOMEM); + return; + } for (size_t i = 0; i < dimensions; i++) { int res = ((f32 *)vector)[i] > 0.0; out[i / 8] |= (res << (i % 8)); @@ -933,7 +948,11 @@ static void vec_quantize_binary(sqlite3_context *context, int argc, sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { u8 *out = sqlite3_malloc(dimensions / CHAR_BIT); - todo_assert(out); + if(!out) { + cleanup(vector); + sqlite3_result_error_code(context, SQLITE_NOMEM); + return; + } for (size_t i = 0; i < dimensions; i++) { int res = ((i8 *)vector)[i] > 0; out[i / 8] |= (res << (i % 8)); @@ -941,7 +960,8 @@ static void vec_quantize_binary(sqlite3_context *context, int argc, sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); } else { - todo("wut"); + sqlite3_result_error(context, "Can only binary quantize float or int8 vectors", -1); + return; } } @@ -1236,7 +1256,12 @@ static void vec_normalize(sqlite3_context *context, int argc, } f32 *out = sqlite3_malloc(dimensions * sizeof(f32)); - todo_assert(out); + if(!out) { + cleanup(vector); + sqlite3_result_error_code(context, SQLITE_NOMEM); + return; + } + f32 *v = (f32 *)vector; f32 norm = 0; @@ -1250,6 +1275,7 @@ static void vec_normalize(sqlite3_context *context, int argc, sqlite3_result_blob(context, out, dimensions * sizeof(f32), sqlite3_free); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); + cleanup(vector); } static void _static_text_func(sqlite3_context *context, int argc, @@ -2704,7 +2730,7 @@ int vec0_get_chunk_position(vec0_vtab *p, i64 rowid, i64 *chunk_id, } /** - * @brief Adds a new chunk for the vec0 table, and the cooresponding vector + * @brief Adds a new chunk for the vec0 table, and the corresponding vector * chunks. * * Inserts a new row into the _chunks table, with blank data, and uses that new @@ -2726,21 +2752,32 @@ int vec0_new_chunk(vec0_vtab *p, i64 *chunk_rowid) { "(size, validity, rowids) " "VALUES (?, ?, ?);", p->schemaName, p->tableName); - todo_assert(zSql); + if(!zSql) { + return SQLITE_NOMEM; + } rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); - todo_assert(rc == SQLITE_OK); + if(rc != SQLITE_OK) { + return rc; + } #ifdef SQLITE_VEC_THREADSAFE sqlite3_mutex_enter(sqlite3_db_mutex(p->db)); #endif rc = sqlite3_bind_int64(stmt, 1, p->chunk_size); // size - todo_assert(rc == SQLITE_OK); + if(rc != SQLITE_OK) { + #ifdef SQLITE_VEC_THREADSAFE + sqlite3_mutex_leave(sqlite3_db_mutex(p->db)); + #endif + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } rc = sqlite3_bind_zeroblob(stmt, 2, p->chunk_size / CHAR_BIT); // validity bitmap todo_assert(rc == SQLITE_OK); rc = sqlite3_bind_zeroblob(stmt, 3, p->chunk_size * sizeof(i64)); // rowids todo_assert(rc == SQLITE_OK); + rc = sqlite3_step(stmt); todo_assert(rc == SQLITE_DONE); rowid = sqlite3_last_insert_rowid(p->db);