init pass vec_min vec_max vec_avg

This commit is contained in:
Alex Garcia 2024-06-12 00:10:00 -07:00
parent 6875f7649c
commit 44dcb3b391
3 changed files with 333 additions and 0 deletions

View file

@ -1444,6 +1444,256 @@ static void vec_normalize(sqlite3_context *context, int argc,
cleanup(vector); cleanup(vector);
} }
/**
 * Running state shared by the vec_min(), vec_max() and vec_avg() aggregates.
 * Instances live in the buffer returned by sqlite3_aggregate_context().
 */
typedef struct vec_agg_context {
  // Element type of the accumulated vector; later inputs must match it.
  enum VectorElementType elementType;
  // Accumulator buffer obtained from sqlite3_malloc() on the first step.
  void *vector;
  // Number of elements in `vector`; later inputs must match it.
  size_t dimensions;
  // Count of vectors accumulated; the step functions treat 0 as "no
  // accumulator yet".
  int n;
} vec_agg_context;
/**
 * Returns the number of bytes needed to store a vector of `dimensions`
 * elements of type `elementType`.
 *
 * Bit vectors are packed CHAR_BIT elements per byte; the division truncates,
 * so `dimensions` is presumably always a multiple of CHAR_BIT for that type
 * (NOTE(review): confirm against the bit-vector parsing code).
 *
 * Returns 0 when `elementType` is not one of the known element types.
 */
size_t vector_size(enum VectorElementType elementType, size_t dimensions) {
  switch (elementType) {
  case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
    return dimensions * sizeof(f32);
  case SQLITE_VEC_ELEMENT_TYPE_INT8:
    return dimensions * sizeof(i8);
  case SQLITE_VEC_ELEMENT_TYPE_BIT:
    return dimensions / CHAR_BIT;
  }
  // Not reachable for valid enum values. Without an explicit return, control
  // would fall off the end of a non-void function, which is undefined
  // behavior as soon as the caller uses the result.
  return 0;
}
/**
 * xStep callback of the vec_min() aggregate.
 *
 * Folds argv[0] into the element-wise minimum kept in the aggregate context.
 * The first vector is copied into a sqlite3_malloc() buffer stored in the
 * context; each later vector must have the same element type and the same
 * number of dimensions. Only float32 and int8 vectors are accepted.
 *
 * Every failure is reported through sqlite3_result_error*() on `context`.
 */
static void vec_minStep(sqlite3_context *context, int argc,
                        sqlite3_value **argv) {
  todo_assert(argc == 1);
  void *vector;
  size_t dimensions;
  vector_cleanup cleanup;
  char *err;
  enum VectorElementType elementType;
  vec_agg_context *p;

  int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
                             &cleanup, &err);
  if (rc != SQLITE_OK) {
    // Nothing to clean up: the vector was never produced.
    sqlite3_result_error(context, err, -1);
    sqlite3_free(err);
    return;
  }

  if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) &&
      (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) {
    sqlite3_result_error(
        context, "only float32 or int8 vectors are supported in vec_min", -1);
    goto finish;
  }

  p = (vec_agg_context *)sqlite3_aggregate_context(context, sizeof(*p));
  if (!p) {
    sqlite3_result_error_nomem(context);
    goto finish;
  }

  if (p->n == 0) {
    // First vector of the group: it becomes the initial accumulator.
    size_t sz = vector_size(elementType, dimensions);
    p->dimensions = dimensions;
    p->elementType = elementType;
    p->vector = sqlite3_malloc(sz);
    if (!p->vector) {
      sqlite3_result_error_nomem(context);
      goto finish;
    }
    // The copy overwrites every byte, so no zero-fill is needed first.
    memcpy(p->vector, vector, sz);
    p->n = 1;
    goto finish;
  }

  if (p->elementType != elementType) {
    sqlite3_result_error(context, "vec_min(): vector type mismatch.", -1);
    goto finish;
  }
  if (p->dimensions != dimensions) {
    sqlite3_result_error(context, "vec_min(): vector dimensions do not match.",
                         -1);
    goto finish;
  }

  // The element type is fixed for the whole vector, so pick the typed loop
  // once instead of re-testing the type for every element.
  if (p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
    const f32 *in = (const f32 *)vector;
    f32 *acc = (f32 *)p->vector;
    for (size_t i = 0; i < dimensions; i++) {
      if (in[i] < acc[i]) {
        acc[i] = in[i];
      }
    }
  } else if (p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
    const i8 *in = (const i8 *)vector;
    i8 *acc = (i8 *)p->vector;
    for (size_t i = 0; i < dimensions; i++) {
      if (in[i] < acc[i]) {
        acc[i] = in[i];
      }
    }
  }
  // Count the vector only once it has actually been folded in.
  p->n++;

finish:
  cleanup(vector);
}
/**
 * xFinal callback of the vec_min() aggregate.
 *
 * Returns the accumulated minimum vector as a blob tagged with the vector's
 * element type as subtype, or SQL NULL when no vector was accumulated.
 * Ownership of the accumulator buffer passes to SQLite through the
 * sqlite3_free destructor.
 */
static void vec_minFinal(sqlite3_context *context) {
  // N=0: only fetch a context created by the step function. Asking for
  // sizeof(*p) here would allocate one just to report an empty result, and
  // could turn "no rows" into a spurious out-of-memory error.
  vec_agg_context *p = (vec_agg_context *)sqlite3_aggregate_context(context, 0);
  if (!p || !p->vector) {
    sqlite3_result_null(context);
    return;
  }
  sqlite3_result_blob(context, p->vector,
                      (int)vector_size(p->elementType, p->dimensions),
                      sqlite3_free);
  sqlite3_result_subtype(context, p->elementType);
}
/**
 * xStep callback of the vec_max() aggregate.
 *
 * Folds argv[0] into the element-wise maximum kept in the aggregate context.
 * The first vector is copied into a sqlite3_malloc() buffer stored in the
 * context; each later vector must have the same element type and the same
 * number of dimensions. Only float32 and int8 vectors are accepted.
 *
 * Every failure is reported through sqlite3_result_error*() on `context`.
 */
static void vec_maxStep(sqlite3_context *context, int argc,
                        sqlite3_value **argv) {
  todo_assert(argc == 1);
  void *vector;
  size_t dimensions;
  vector_cleanup cleanup;
  char *err;
  enum VectorElementType elementType;
  vec_agg_context *p;

  int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
                             &cleanup, &err);
  if (rc != SQLITE_OK) {
    // Nothing to clean up: the vector was never produced.
    sqlite3_result_error(context, err, -1);
    sqlite3_free(err);
    return;
  }

  if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) &&
      (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) {
    sqlite3_result_error(
        context, "only float32 or int8 vectors are supported in vec_max", -1);
    goto finish;
  }

  p = (vec_agg_context *)sqlite3_aggregate_context(context, sizeof(*p));
  if (!p) {
    sqlite3_result_error_nomem(context);
    goto finish;
  }

  if (p->n == 0) {
    // First vector of the group: it becomes the initial accumulator.
    size_t sz = vector_size(elementType, dimensions);
    p->dimensions = dimensions;
    p->elementType = elementType;
    p->vector = sqlite3_malloc(sz);
    if (!p->vector) {
      sqlite3_result_error_nomem(context);
      goto finish;
    }
    // The copy overwrites every byte, so no zero-fill is needed first.
    memcpy(p->vector, vector, sz);
    p->n = 1;
    goto finish;
  }

  if (p->elementType != elementType) {
    sqlite3_result_error(context, "vec_max(): vector type mismatch.", -1);
    goto finish;
  }
  if (p->dimensions != dimensions) {
    sqlite3_result_error(context, "vec_max(): vector dimensions do not match.",
                         -1);
    goto finish;
  }

  // The element type is fixed for the whole vector, so pick the typed loop
  // once instead of re-testing the type for every element.
  if (p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
    const f32 *in = (const f32 *)vector;
    f32 *acc = (f32 *)p->vector;
    for (size_t i = 0; i < dimensions; i++) {
      if (in[i] > acc[i]) {
        acc[i] = in[i];
      }
    }
  } else if (p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
    const i8 *in = (const i8 *)vector;
    i8 *acc = (i8 *)p->vector;
    for (size_t i = 0; i < dimensions; i++) {
      if (in[i] > acc[i]) {
        acc[i] = in[i];
      }
    }
  }
  // Count the vector only once it has actually been folded in.
  p->n++;

finish:
  cleanup(vector);
}
/**
 * xFinal callback of the vec_max() aggregate.
 *
 * Returns the accumulated maximum vector as a blob tagged with the vector's
 * element type as subtype, or SQL NULL when no vector was accumulated.
 * Ownership of the accumulator buffer passes to SQLite through the
 * sqlite3_free destructor.
 */
static void vec_maxFinal(sqlite3_context *context) {
  // N=0: only fetch a context created by the step function. Asking for
  // sizeof(*p) here would allocate one just to report an empty result, and
  // could turn "no rows" into a spurious out-of-memory error.
  vec_agg_context *p = (vec_agg_context *)sqlite3_aggregate_context(context, 0);
  if (!p || !p->vector) {
    sqlite3_result_null(context);
    return;
  }
  sqlite3_result_blob(context, p->vector,
                      (int)vector_size(p->elementType, p->dimensions),
                      sqlite3_free);
  sqlite3_result_subtype(context, p->elementType);
}
/**
 * xStep callback of the vec_avg() aggregate.
 *
 * Adds argv[0] element-wise to the running sum kept in the aggregate context
 * and counts it in `n`; vec_avgFinal divides the sum by that count. The
 * first vector is copied into a sqlite3_malloc() buffer stored in the
 * context; each later vector must have the same element type and the same
 * number of dimensions. Only float32 vectors are accepted.
 *
 * Every failure is reported through sqlite3_result_error*() on `context`.
 */
static void vec_avgStep(sqlite3_context *context, int argc,
                        sqlite3_value **argv) {
  todo_assert(argc == 1);
  void *vector;
  size_t dimensions;
  vector_cleanup cleanup;
  char *err;
  enum VectorElementType elementType;
  vec_agg_context *p;

  int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
                             &cleanup, &err);
  if (rc != SQLITE_OK) {
    // Nothing to clean up: the vector was never produced.
    sqlite3_result_error(context, err, -1);
    sqlite3_free(err);
    return;
  }

  if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32)) {
    sqlite3_result_error(
        context, "only float32 vectors are supported in vec_avg", -1);
    goto finish;
  }

  p = (vec_agg_context *)sqlite3_aggregate_context(context, sizeof(*p));
  if (!p) {
    sqlite3_result_error_nomem(context);
    goto finish;
  }

  if (p->n == 0) {
    // First vector of the group: it becomes the initial running sum.
    size_t sz = vector_size(elementType, dimensions);
    p->dimensions = dimensions;
    p->elementType = elementType;
    p->vector = sqlite3_malloc(sz);
    if (!p->vector) {
      sqlite3_result_error_nomem(context);
      goto finish;
    }
    // The copy overwrites every byte, so no zero-fill is needed first.
    memcpy(p->vector, vector, sz);
    p->n = 1;
    goto finish;
  }

  if (p->elementType != elementType) {
    sqlite3_result_error(context, "vec_avg(): vector type mismatch.", -1);
    goto finish;
  }
  if (p->dimensions != dimensions) {
    sqlite3_result_error(context, "vec_avg(): vector dimensions do not match.",
                         -1);
    goto finish;
  }

  const f32 *in = (const f32 *)vector;
  f32 *sum = (f32 *)p->vector;
  for (size_t i = 0; i < dimensions; i++) {
    sum[i] += in[i];
  }
  // Count the vector only once it has actually been added to the sum.
  p->n++;

finish:
  cleanup(vector);
}
/**
 * xFinal callback of the vec_avg() aggregate.
 *
 * Divides the accumulated element-wise sum by the number of accumulated
 * vectors and returns the result as a blob tagged with the vector's element
 * type as subtype, or SQL NULL when no vector was accumulated. Ownership of
 * the accumulator buffer passes to SQLite through the sqlite3_free
 * destructor.
 */
static void vec_avgFinal(sqlite3_context *context) {
  // N=0: only fetch a context created by the step function. Asking for
  // sizeof(*p) here would allocate one just to report an empty result, and
  // could turn "no rows" into a spurious out-of-memory error.
  vec_agg_context *p = (vec_agg_context *)sqlite3_aggregate_context(context, 0);
  if (!p || !p->vector) {
    sqlite3_result_null(context);
    return;
  }
  // The step function sets n to 1 when it allocates the buffer, so a non-NULL
  // buffer means the divisor is at least 1.
  f32 *sum = (f32 *)p->vector;
  for (size_t i = 0; i < p->dimensions; i++) {
    sum[i] = sum[i] / ((float)p->n);
  }
  sqlite3_result_blob(context, p->vector,
                      (int)vector_size(p->elementType, p->dimensions),
                      sqlite3_free);
  sqlite3_result_subtype(context, p->elementType);
}
static void _static_text_func(sqlite3_context *context, int argc, static void _static_text_func(sqlite3_context *context, int argc,
sqlite3_value **argv) { sqlite3_value **argv) {
UNUSED_PARAMETER(argc); UNUSED_PARAMETER(argc);
@ -5726,6 +5976,20 @@ __declspec(dllexport)
#endif #endif
// clang-format on // clang-format on
}; };
static const struct {
char *zFName;
void (*xStep)(sqlite3_context *, int, sqlite3_value **);
void (*xFinal)(sqlite3_context *);
int nArg;
int flags;
void *p;
} aAggregateFunc[] = {
// clang-format off
{"vec_min", vec_minStep, vec_minFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
{"vec_max", vec_maxStep, vec_maxFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
{"vec_avg", vec_avgStep, vec_avgFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
// clang-format on
};
#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL #ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL
vec_static_blob_data * static_blob_data; vec_static_blob_data * static_blob_data;
@ -5758,6 +6022,17 @@ __declspec(dllexport)
return rc; return rc;
} }
} }
for (unsigned long i = 0;
i < sizeof(aAggregateFunc) / sizeof(aAggregateFunc[0]) && rc == SQLITE_OK; i++) {
rc = sqlite3_create_function_v2(db, aAggregateFunc[i].zFName, aAggregateFunc[i].nArg,
aAggregateFunc[i].flags, aAggregateFunc[i].p, NULL,
aAggregateFunc[i].xStep, aAggregateFunc[i].xFinal, NULL);
if (rc != SQLITE_OK) {
*pzErrMsg = sqlite3_mprintf("Error creating function %s: %s",
aAggregateFunc[i].zFName, sqlite3_errmsg(db));
return rc;
}
}
for (unsigned long i = 0; i < countof(aMod) && rc == SQLITE_OK; i++) { for (unsigned long i = 0; i < countof(aMod) && rc == SQLITE_OK; i++) {
rc = sqlite3_create_module_v2(db, aMod[i].name, aMod[i].module, NULL, NULL); rc = sqlite3_create_module_v2(db, aMod[i].name, aMod[i].module, NULL, NULL);

21
test.sql Normal file
View file

@ -0,0 +1,21 @@
-- Manual smoke test for the vec_min / vec_max / vec_avg aggregates.
-- Uses sqlite3 shell dot-commands; presumably meant to be run from the
-- repository root after building the extension into dist/ (TODO confirm).
.load dist/vec0
.mode box
.header on
-- One row per inner JSON array: five 3-dimensional vectors.
create table test as
select value
from json_each('[
[1.0, 2.0, -3.0],
[-1.0, 2.0, 3.0],
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0],
[1.0, -2.0, 3.0]
]');
-- For the data above the element-wise results should be
--   min = [-1, -2, -3], max = [1, 2, 3], avg ~= [0.6, 1.2, 1.8].
select
vec_to_json(vec_min(value)),
vec_to_json(vec_max(value)),
vec_to_json(vec_avg(value))
from test;

View file

@ -86,6 +86,7 @@ def spread_args(args):
FUNCTIONS = [ FUNCTIONS = [
"vec_add", "vec_add",
"vec_avg",
"vec_bit", "vec_bit",
"vec_debug", "vec_debug",
"vec_distance_cosine", "vec_distance_cosine",
@ -94,6 +95,8 @@ FUNCTIONS = [
"vec_f32", "vec_f32",
"vec_int8", "vec_int8",
"vec_length", "vec_length",
"vec_max",
"vec_min",
"vec_normalize", "vec_normalize",
"vec_quantize_binary", "vec_quantize_binary",
"vec_quantize_i8", "vec_quantize_i8",
@ -459,6 +462,40 @@ def test_vec_sub():
): ):
vec_sub(_int8([2]), _f32([1]), a="vec_int8(?)") vec_sub(_int8([2]), _f32([1]), a="vec_int8(?)")
def test_vec_min():
    """vec_min() yields the element-wise minimum of the aggregated vectors."""

    def vec_min(values, wrap="(vec_f32(?))"):
        # One `wrap` placeholder per input value, all fed through a VALUES list.
        placeholders = ", ".join(wrap for _ in values)
        sql = "select vec_min(column1) from (values {})".format(placeholders)
        row = db.execute(sql, values).fetchone()
        return row[0]

    assert vec_min(["[1]", "[2]"]) == _f32([1])
    assert vec_min(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5, -6, -7, -8])

    # TODO: int8 tests, block binary vectors, overflowing
    expected_message = re.escape("vec_min(): vector dimensions do not match.")
    with pytest.raises(sqlite3.OperationalError, match=expected_message):
        vec_min(["[1]", "[2,3]"])
def test_vec_max():
    """vec_max() yields the element-wise maximum of the aggregated vectors."""

    def vec_max(values, wrap="(vec_f32(?))"):
        # One `wrap` placeholder per input value, all fed through a VALUES list.
        return db.execute(
            "select vec_max(column1) from (values {})".format(
                ", ".join([wrap] * len(values))
            ),
            values,
        ).fetchone()[0]

    # Expected values are the maxima (the copied vec_min expectations were
    # the minima and could never match a correct vec_max).
    assert vec_max(["[1]", "[2]"]) == _f32([2])
    assert vec_max(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([1, 2, 3, 4])

    # TODO: int8 tests, block binary vectors, overflowing
    # Exercise vec_max itself and match its own error message; the vec_min
    # helper from test_vec_min is not in scope in this test.
    with pytest.raises(
        sqlite3.OperationalError,
        match=re.escape("vec_max(): vector dimensions do not match."),
    ):
        vec_max(["[1]", "[2,3]"])
def test_vec_to_json(): def test_vec_to_json():
vec_to_json = lambda *args, input="?": db.execute( vec_to_json = lambda *args, input="?": db.execute(