small todo progress

This commit is contained in:
Alex Garcia 2024-05-12 00:16:10 -07:00
parent 9ecafe18e0
commit ab45c39f31

View file

@ -344,6 +344,14 @@ struct Array {
void *z;
};
/**
* @brief Initial an array with the given element size and capacity.
*
* @param array
* @param element_size
* @param init_capacity
* @return SQLITE_OK on success, error code on failure. Only error is SQLITE_NOMEM
*/
int array_init(struct Array *array, size_t element_size, size_t init_capacity) {
void *z = sqlite3_malloc(element_size * init_capacity);
if (!z) {
@ -389,6 +397,7 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector,
size_t *dimensions, fvec_cleanup *cleanup,
char **pzErr) {
int value_type = sqlite3_value_type(value);
if (value_type == SQLITE_BLOB) {
const void *blob = sqlite3_value_blob(value);
int bytes = sqlite3_value_bytes(value);
@ -415,7 +424,9 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector,
struct Array x;
int rc = array_init(&x, sizeof(f32), ceil(source_len / 2.0));
todo_assert(rc == SQLITE_OK);
if(rc != SQLITE_OK) {
return rc;
}
// advance leading whitespace to first '['
while (i < source_len) {
@ -924,7 +935,11 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
u8 *out = sqlite3_malloc(dimensions / CHAR_BIT);
todo_assert(out);
if(!out) {
cleanup(vector);
sqlite3_result_error_code(context, SQLITE_NOMEM);
return;
}
for (size_t i = 0; i < dimensions; i++) {
int res = ((f32 *)vector)[i] > 0.0;
out[i / 8] |= (res << (i % 8));
@ -933,7 +948,11 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);
} else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
u8 *out = sqlite3_malloc(dimensions / CHAR_BIT);
todo_assert(out);
if(!out) {
cleanup(vector);
sqlite3_result_error_code(context, SQLITE_NOMEM);
return;
}
for (size_t i = 0; i < dimensions; i++) {
int res = ((i8 *)vector)[i] > 0;
out[i / 8] |= (res << (i % 8));
@ -941,7 +960,8 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);
} else {
todo("wut");
sqlite3_result_error(context, "Can only binary quantize float or int8 vectors", -1);
return;
}
}
@ -1236,7 +1256,12 @@ static void vec_normalize(sqlite3_context *context, int argc,
}
f32 *out = sqlite3_malloc(dimensions * sizeof(f32));
todo_assert(out);
if(!out) {
cleanup(vector);
sqlite3_result_error_code(context, SQLITE_NOMEM);
return;
}
f32 *v = (f32 *)vector;
f32 norm = 0;
@ -1250,6 +1275,7 @@ static void vec_normalize(sqlite3_context *context, int argc,
sqlite3_result_blob(context, out, dimensions * sizeof(f32), sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32);
cleanup(vector);
}
static void _static_text_func(sqlite3_context *context, int argc,
@ -2704,7 +2730,7 @@ int vec0_get_chunk_position(vec0_vtab *p, i64 rowid, i64 *chunk_id,
}
/**
* @brief Adds a new chunk for the vec0 table, and the cooresponding vector
* @brief Adds a new chunk for the vec0 table, and the corresponding vector
* chunks.
*
* Inserts a new row into the _chunks table, with blank data, and uses that new
@ -2726,21 +2752,32 @@ int vec0_new_chunk(vec0_vtab *p, i64 *chunk_rowid) {
"(size, validity, rowids) "
"VALUES (?, ?, ?);",
p->schemaName, p->tableName);
todo_assert(zSql);
if(!zSql) {
return SQLITE_NOMEM;
}
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
sqlite3_free(zSql);
todo_assert(rc == SQLITE_OK);
if(rc != SQLITE_OK) {
return rc;
}
#ifdef SQLITE_VEC_THREADSAFE
sqlite3_mutex_enter(sqlite3_db_mutex(p->db));
#endif
rc = sqlite3_bind_int64(stmt, 1, p->chunk_size); // size
todo_assert(rc == SQLITE_OK);
if(rc != SQLITE_OK) {
#ifdef SQLITE_VEC_THREADSAFE
sqlite3_mutex_leave(sqlite3_db_mutex(p->db));
#endif
sqlite3_finalize(stmt);
return SQLITE_ERROR;
}
rc = sqlite3_bind_zeroblob(stmt, 2,
p->chunk_size / CHAR_BIT); // validity bitmap
todo_assert(rc == SQLITE_OK);
rc = sqlite3_bind_zeroblob(stmt, 3,
p->chunk_size * sizeof(i64)); // rowids
todo_assert(rc == SQLITE_OK);
rc = sqlite3_step(stmt);
todo_assert(rc == SQLITE_DONE);
rowid = sqlite3_last_insert_rowid(p->db);