From 44dcb3b3914633b1308ffde7de140072b60797b9 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Wed, 12 Jun 2024 00:10:00 -0700 Subject: [PATCH] init pass vec_min vec_max vec_avg --- sqlite-vec.c | 275 +++++++++++++++++++++++++++++++++++++++++ test.sql | 21 ++++ tests/test-loadable.py | 37 ++++++ 3 files changed, 333 insertions(+) create mode 100644 test.sql diff --git a/sqlite-vec.c b/sqlite-vec.c index e318907..eca44ef 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -1444,6 +1444,256 @@ static void vec_normalize(sqlite3_context *context, int argc, cleanup(vector); } + +typedef struct vec_agg_context vec_agg_context; +struct vec_agg_context { + enum VectorElementType elementType; + void * vector; + size_t dimensions; + int n; +}; + +size_t vector_size(enum VectorElementType elementType, size_t dimensions) { + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + return dimensions * sizeof(f32); + case SQLITE_VEC_ELEMENT_TYPE_INT8: + return dimensions * sizeof(i8); + case SQLITE_VEC_ELEMENT_TYPE_BIT: + return dimensions / CHAR_BIT; + } +} + +static void vec_minStep(sqlite3_context *context, int argc, sqlite3_value **argv){ + todo_assert(argc==1); + void *vector; + size_t dimensions; + vector_cleanup cleanup; + char *err; + enum VectorElementType elementType; + + int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType, + &cleanup, &err); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, err, -1); + sqlite3_free(err); + return; + } + + if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) && (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) { + sqlite3_result_error( + context, "only float32 or int8 vectors are supported in vec_min", -1); + goto finish; + } + vec_agg_context *p; + p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p)); + if(!p) { + sqlite3_result_error_nomem(context); + goto finish; + } + if(p->n) { + p->n++; + if(p->elementType != elementType) { + sqlite3_result_error(context, "vec_min(): vector type mismatch.", -1); + goto finish; + } + if(p->dimensions != dimensions) { + sqlite3_result_error(context, "vec_min(): vector dimensions do not match.", -1); + goto finish; + } + for(size_t i = 0; i < dimensions; i++) { + if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + if( ((f32*)vector)[i] < ((f32*)p->vector)[i]) { + ((f32*)p->vector)[i] = ((f32*)vector)[i]; + } + } + else if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { + if( ((i8*)vector)[i] < ((i8*)p->vector)[i]) { + ((i8*)p->vector)[i] = ((i8*)vector)[i]; + } + } + } + + }else { + size_t sz = vector_size(elementType, dimensions); + p->dimensions = dimensions; + p->elementType = elementType; + p->vector = sqlite3_malloc(sz); + if(!p->vector) { + sqlite3_result_error_nomem(context); + goto finish; + } + memset(p->vector, 0, sz); + memcpy(p->vector, vector, sz); + p->n = 1; + } + + finish: + cleanup(vector); +} + +static void vec_minFinal(sqlite3_context *context){ + vec_agg_context *p; + p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p)); + if(!p) { + sqlite3_result_error_nomem(context); + return; + } + sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free); + sqlite3_result_subtype(context, p->elementType); +} + +static void vec_maxStep(sqlite3_context *context, int argc, sqlite3_value **argv){ + todo_assert(argc==1); + void *vector; + size_t dimensions; + vector_cleanup cleanup; + char *err; + enum VectorElementType elementType; + + int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType, + &cleanup, &err); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, err, -1); + sqlite3_free(err); + return; + } + + if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) && (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) { + sqlite3_result_error( + context, "only float32 or int8 vectors are supported in vec_max", -1); + goto finish; + } + vec_agg_context *p; + p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p)); + if(!p) { + sqlite3_result_error_nomem(context); + goto finish; + } + if(p->n) { + p->n++; + if(p->elementType != elementType) { + sqlite3_result_error(context, "vec_max(): vector type mismatch.", -1); + goto finish; + } + if(p->dimensions != dimensions) { + sqlite3_result_error(context, "vec_max(): vector dimensions do not match.", -1); + goto finish; + } + for(size_t i = 0; i < dimensions; i++) { + if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + if( ((f32*)vector)[i] > ((f32*)p->vector)[i]) { + ((f32*)p->vector)[i] = ((f32*)vector)[i]; + } + } + else if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { + if( ((i8*)vector)[i] > ((i8*)p->vector)[i]) { + ((i8*)p->vector)[i] = ((i8*)vector)[i]; + } + } + } + + }else { + size_t sz = vector_size(elementType, dimensions); + p->dimensions = dimensions; + p->elementType = elementType; + p->vector = sqlite3_malloc(sz); + if(!p->vector) { + sqlite3_result_error_nomem(context); + goto finish; + } + memset(p->vector, 0, sz); + memcpy(p->vector, vector, sz); + p->n = 1; + } + + finish: + cleanup(vector); +} + +static void vec_maxFinal(sqlite3_context *context){ + vec_agg_context *p; + p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p)); + if(!p) { + sqlite3_result_error_nomem(context); + return; + } + sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free); + sqlite3_result_subtype(context, p->elementType); +} +static void vec_avgStep(sqlite3_context *context, int argc, sqlite3_value **argv){ + todo_assert(argc==1); + void *vector; + size_t dimensions; + vector_cleanup cleanup; + char *err; + enum VectorElementType elementType; + + int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType, + &cleanup, &err); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, err, -1); + sqlite3_free(err); + return; + } + + if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32)) { + sqlite3_result_error( + context, "only float32 vectors are supported in vec_avg", -1); + goto finish; + } + vec_agg_context *p; + p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p)); + if(!p) { + sqlite3_result_error_nomem(context); + goto finish; + } + if(p->n) { + p->n++; + if(p->elementType != elementType) { + sqlite3_result_error(context, "vec_avg(): vector type mismatch.", -1); + goto finish; + } + if(p->dimensions != dimensions) { + sqlite3_result_error(context, "vec_avg(): vector dimensions do not match.", -1); + goto finish; + } + for(size_t i = 0; i < dimensions; i++) { + ((f32*)p->vector)[i] += ((f32*)vector)[i]; + } + + }else { + size_t sz = vector_size(elementType, dimensions); + p->dimensions = dimensions; + p->elementType = elementType; + p->vector = sqlite3_malloc(sz); + if(!p->vector) { + sqlite3_result_error_nomem(context); + goto finish; + } + memset(p->vector, 0, sz); + memcpy(p->vector, vector, sz); + p->n = 1; + } + + finish: + cleanup(vector); +} + +static void vec_avgFinal(sqlite3_context *context){ + vec_agg_context *p; + p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p)); + if(!p) { + sqlite3_result_error_nomem(context); + return; + } + for(size_t i = 0; i < p->dimensions; i++) { + ((f32*)p->vector)[i] = ((f32*)p->vector)[i] / ((float)p->n); + } + sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free); + sqlite3_result_subtype(context, p->elementType); +} + static void _static_text_func(sqlite3_context *context, int argc, sqlite3_value **argv) { UNUSED_PARAMETER(argc); @@ -5726,6 +5976,20 @@ __declspec(dllexport) #endif // clang-format on }; + static const struct { + char *zFName; + void (*xStep)(sqlite3_context *, int, sqlite3_value **); + void (*xFinal)(sqlite3_context *); + int nArg; + int flags; + void *p; + } aAggregateFunc[] = { + // clang-format off + {"vec_min", vec_minStep, vec_minFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_max", vec_maxStep, vec_maxFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + {"vec_avg", vec_avgStep, vec_avgFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL }, + // clang-format on + }; #ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL vec_static_blob_data * static_blob_data; @@ -5758,6 +6022,17 @@ __declspec(dllexport) return rc; } } + for (unsigned long i = 0; + i < sizeof(aAggregateFunc) / sizeof(aAggregateFunc[0]) && rc == SQLITE_OK; i++) { + rc = sqlite3_create_function_v2(db, aAggregateFunc[i].zFName, aAggregateFunc[i].nArg, + aAggregateFunc[i].flags, aAggregateFunc[i].p, NULL, + aAggregateFunc[i].xStep, aAggregateFunc[i].xFinal, NULL); + if (rc != SQLITE_OK) { + *pzErrMsg = sqlite3_mprintf("Error creating function %s: %s", + aAggregateFunc[i].zFName, sqlite3_errmsg(db)); + return rc; + } + } for (unsigned long i = 0; i < countof(aMod) && rc == SQLITE_OK; i++) { rc = sqlite3_create_module_v2(db, aMod[i].name, aMod[i].module, NULL, NULL); diff --git a/test.sql b/test.sql new file mode 100644 index 0000000..f384853 --- /dev/null +++ b/test.sql @@ -0,0 +1,21 @@ +.load dist/vec0 +.mode box +.header on + +create table test as + select value + from json_each('[ + [1.0, 2.0, -3.0], + [-1.0, 2.0, 3.0], + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0], + [1.0, -2.0, 3.0] + ]'); + +select + vec_to_json(vec_min(value)), + vec_to_json(vec_max(value)), + vec_to_json(vec_avg(value)) +from test; + + diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 52f4c3d..1a67aee 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -86,6 +86,7 @@ def spread_args(args): FUNCTIONS = [ "vec_add", + "vec_avg", "vec_bit", "vec_debug", "vec_distance_cosine", @@ -94,6 +95,8 @@ FUNCTIONS = [ "vec_f32", "vec_int8", "vec_length", + "vec_max", + "vec_min", "vec_normalize", "vec_quantize_binary", "vec_quantize_i8", @@ -459,6 +462,40 @@ def test_vec_sub(): ): vec_sub(_int8([2]), _f32([1]), a="vec_int8(?)") +def test_vec_min(): + def vec_min(values, wrap="(vec_f32(?))"): + return db.execute( + "select vec_min(column1) from (values {})".format(", ".join([wrap] * len(values))), values + ).fetchone()[0] + + assert vec_min(["[1]", "[2]"]) == _f32([1]) + assert vec_min(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5,-6,-7,-8]) + + # TODO: int8 tests, block binary vectors, overflowing + + with pytest.raises( + sqlite3.OperationalError, + match=re.escape("vec_min(): vector dimensions do not match."), + ): + vec_min(["[1]", "[2,3]"]) + +def test_vec_max(): + def vec_max(values, wrap="(vec_f32(?))"): + return db.execute( + "select vec_max(column1) from (values {})".format(", ".join([wrap] * len(values))), values + ).fetchone()[0] + + assert vec_max(["[1]", "[2]"]) == _f32([1]) + assert vec_max(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5,-6,-7,-8]) + + # TODO: int8 tests, block binary vectors, overflowing + + with pytest.raises( + sqlite3.OperationalError, + match=re.escape("vec_min(): vector dimensions do not match."), + ): + vec_min(["[1]", "[2,3]"]) + def test_vec_to_json(): vec_to_json = lambda *args, input="?": db.execute(