init pass vec_min vec_max vec_avg

This commit is contained in:
Alex Garcia 2024-06-12 00:10:00 -07:00
parent 6875f7649c
commit 44dcb3b391
3 changed files with 333 additions and 0 deletions

View file

@ -1444,6 +1444,256 @@ static void vec_normalize(sqlite3_context *context, int argc,
cleanup(vector);
}
typedef struct vec_agg_context vec_agg_context;
struct vec_agg_context {
enum VectorElementType elementType;
void * vector;
size_t dimensions;
int n;
};
size_t vector_size(enum VectorElementType elementType, size_t dimensions) {
switch (elementType) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
return dimensions * sizeof(f32);
case SQLITE_VEC_ELEMENT_TYPE_INT8:
return dimensions * sizeof(i8);
case SQLITE_VEC_ELEMENT_TYPE_BIT:
return dimensions / CHAR_BIT;
}
}
static void vec_minStep(sqlite3_context *context, int argc, sqlite3_value **argv){
todo_assert(argc==1);
void *vector;
size_t dimensions;
vector_cleanup cleanup;
char *err;
enum VectorElementType elementType;
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
&cleanup, &err);
if (rc != SQLITE_OK) {
sqlite3_result_error(context, err, -1);
sqlite3_free(err);
return;
}
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) && (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) {
sqlite3_result_error(
context, "only float32 or int8 vectors are supported in vec_min", -1);
goto finish;
}
vec_agg_context *p;
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
if(!p) {
sqlite3_result_error_nomem(context);
goto finish;
}
if(p->n) {
p->n++;
if(p->elementType != elementType) {
sqlite3_result_error(context, "vec_min(): vector type mismatch.", -1);
goto finish;
}
if(p->dimensions != dimensions) {
sqlite3_result_error(context, "vec_min(): vector dimensions do not match.", -1);
goto finish;
}
for(size_t i = 0; i < dimensions; i++) {
if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
if( ((f32*)vector)[i] < ((f32*)p->vector)[i]) {
((f32*)p->vector)[i] = ((f32*)vector)[i];
}
}
else if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
if( ((i8*)vector)[i] < ((i8*)p->vector)[i]) {
((i8*)p->vector)[i] = ((i8*)vector)[i];
}
}
}
}else {
size_t sz = vector_size(elementType, dimensions);
p->dimensions = dimensions;
p->elementType = elementType;
p->vector = sqlite3_malloc(sz);
if(!p->vector) {
sqlite3_result_error_nomem(context);
goto finish;
}
memset(p->vector, 0, sz);
memcpy(p->vector, vector, sz);
p->n = 1;
}
finish:
cleanup(vector);
}
static void vec_minFinal(sqlite3_context *context){
vec_agg_context *p;
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
if(!p) {
sqlite3_result_error_nomem(context);
return;
}
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
sqlite3_result_subtype(context, p->elementType);
}
static void vec_maxStep(sqlite3_context *context, int argc, sqlite3_value **argv){
todo_assert(argc==1);
void *vector;
size_t dimensions;
vector_cleanup cleanup;
char *err;
enum VectorElementType elementType;
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
&cleanup, &err);
if (rc != SQLITE_OK) {
sqlite3_result_error(context, err, -1);
sqlite3_free(err);
return;
}
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) && (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) {
sqlite3_result_error(
context, "only float32 or int8 vectors are supported in vec_max", -1);
goto finish;
}
vec_agg_context *p;
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
if(!p) {
sqlite3_result_error_nomem(context);
goto finish;
}
if(p->n) {
p->n++;
if(p->elementType != elementType) {
sqlite3_result_error(context, "vec_max(): vector type mismatch.", -1);
goto finish;
}
if(p->dimensions != dimensions) {
sqlite3_result_error(context, "vec_max(): vector dimensions do not match.", -1);
goto finish;
}
for(size_t i = 0; i < dimensions; i++) {
if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
if( ((f32*)vector)[i] > ((f32*)p->vector)[i]) {
((f32*)p->vector)[i] = ((f32*)vector)[i];
}
}
else if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
if( ((i8*)vector)[i] > ((i8*)p->vector)[i]) {
((i8*)p->vector)[i] = ((i8*)vector)[i];
}
}
}
}else {
size_t sz = vector_size(elementType, dimensions);
p->dimensions = dimensions;
p->elementType = elementType;
p->vector = sqlite3_malloc(sz);
if(!p->vector) {
sqlite3_result_error_nomem(context);
goto finish;
}
memset(p->vector, 0, sz);
memcpy(p->vector, vector, sz);
p->n = 1;
}
finish:
cleanup(vector);
}
static void vec_maxFinal(sqlite3_context *context){
vec_agg_context *p;
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
if(!p) {
sqlite3_result_error_nomem(context);
return;
}
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
sqlite3_result_subtype(context, p->elementType);
}
static void vec_avgStep(sqlite3_context *context, int argc, sqlite3_value **argv){
todo_assert(argc==1);
void *vector;
size_t dimensions;
vector_cleanup cleanup;
char *err;
enum VectorElementType elementType;
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
&cleanup, &err);
if (rc != SQLITE_OK) {
sqlite3_result_error(context, err, -1);
sqlite3_free(err);
return;
}
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32)) {
sqlite3_result_error(
context, "only float32 vectors are supported in vec_avg", -1);
goto finish;
}
vec_agg_context *p;
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
if(!p) {
sqlite3_result_error_nomem(context);
goto finish;
}
if(p->n) {
p->n++;
if(p->elementType != elementType) {
sqlite3_result_error(context, "vec_avg(): vector type mismatch.", -1);
goto finish;
}
if(p->dimensions != dimensions) {
sqlite3_result_error(context, "vec_avg(): vector dimensions do not match.", -1);
goto finish;
}
for(size_t i = 0; i < dimensions; i++) {
((f32*)p->vector)[i] += ((f32*)vector)[i];
}
}else {
size_t sz = vector_size(elementType, dimensions);
p->dimensions = dimensions;
p->elementType = elementType;
p->vector = sqlite3_malloc(sz);
if(!p->vector) {
sqlite3_result_error_nomem(context);
goto finish;
}
memset(p->vector, 0, sz);
memcpy(p->vector, vector, sz);
p->n = 1;
}
finish:
cleanup(vector);
}
static void vec_avgFinal(sqlite3_context *context){
vec_agg_context *p;
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
if(!p) {
sqlite3_result_error_nomem(context);
return;
}
for(size_t i = 0; i < p->dimensions; i++) {
((f32*)p->vector)[i] = ((f32*)p->vector)[i] / ((float)p->n);
}
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
sqlite3_result_subtype(context, p->elementType);
}
static void _static_text_func(sqlite3_context *context, int argc,
sqlite3_value **argv) {
UNUSED_PARAMETER(argc);
@ -5726,6 +5976,20 @@ __declspec(dllexport)
#endif
// clang-format on
};
static const struct {
char *zFName;
void (*xStep)(sqlite3_context *, int, sqlite3_value **);
void (*xFinal)(sqlite3_context *);
int nArg;
int flags;
void *p;
} aAggregateFunc[] = {
// clang-format off
{"vec_min", vec_minStep, vec_minFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
{"vec_max", vec_maxStep, vec_maxFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
{"vec_avg", vec_avgStep, vec_avgFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
// clang-format on
};
#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL
vec_static_blob_data * static_blob_data;
@ -5758,6 +6022,17 @@ __declspec(dllexport)
return rc;
}
}
for (unsigned long i = 0;
i < sizeof(aAggregateFunc) / sizeof(aAggregateFunc[0]) && rc == SQLITE_OK; i++) {
rc = sqlite3_create_function_v2(db, aAggregateFunc[i].zFName, aAggregateFunc[i].nArg,
aAggregateFunc[i].flags, aAggregateFunc[i].p, NULL,
aAggregateFunc[i].xStep, aAggregateFunc[i].xFinal, NULL);
if (rc != SQLITE_OK) {
*pzErrMsg = sqlite3_mprintf("Error creating function %s: %s",
aAggregateFunc[i].zFName, sqlite3_errmsg(db));
return rc;
}
}
for (unsigned long i = 0; i < countof(aMod) && rc == SQLITE_OK; i++) {
rc = sqlite3_create_module_v2(db, aMod[i].name, aMod[i].module, NULL, NULL);