mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
init pass vec_min vec_max vec_avg
This commit is contained in:
parent
6875f7649c
commit
44dcb3b391
3 changed files with 333 additions and 0 deletions
275
sqlite-vec.c
275
sqlite-vec.c
|
|
@ -1444,6 +1444,256 @@ static void vec_normalize(sqlite3_context *context, int argc,
|
|||
cleanup(vector);
|
||||
}
|
||||
|
||||
|
||||
typedef struct vec_agg_context vec_agg_context;
|
||||
struct vec_agg_context {
|
||||
enum VectorElementType elementType;
|
||||
void * vector;
|
||||
size_t dimensions;
|
||||
int n;
|
||||
};
|
||||
|
||||
size_t vector_size(enum VectorElementType elementType, size_t dimensions) {
|
||||
switch (elementType) {
|
||||
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
|
||||
return dimensions * sizeof(f32);
|
||||
case SQLITE_VEC_ELEMENT_TYPE_INT8:
|
||||
return dimensions * sizeof(i8);
|
||||
case SQLITE_VEC_ELEMENT_TYPE_BIT:
|
||||
return dimensions / CHAR_BIT;
|
||||
}
|
||||
}
|
||||
|
||||
static void vec_minStep(sqlite3_context *context, int argc, sqlite3_value **argv){
|
||||
todo_assert(argc==1);
|
||||
void *vector;
|
||||
size_t dimensions;
|
||||
vector_cleanup cleanup;
|
||||
char *err;
|
||||
enum VectorElementType elementType;
|
||||
|
||||
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
|
||||
&cleanup, &err);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_result_error(context, err, -1);
|
||||
sqlite3_free(err);
|
||||
return;
|
||||
}
|
||||
|
||||
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) && (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) {
|
||||
sqlite3_result_error(
|
||||
context, "only float32 or int8 vectors are supported in vec_min", -1);
|
||||
goto finish;
|
||||
}
|
||||
vec_agg_context *p;
|
||||
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||
if(!p) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
goto finish;
|
||||
}
|
||||
if(p->n) {
|
||||
p->n++;
|
||||
if(p->elementType != elementType) {
|
||||
sqlite3_result_error(context, "vec_min(): vector type mismatch.", -1);
|
||||
goto finish;
|
||||
}
|
||||
if(p->dimensions != dimensions) {
|
||||
sqlite3_result_error(context, "vec_min(): vector dimensions do not match.", -1);
|
||||
goto finish;
|
||||
}
|
||||
for(size_t i = 0; i < dimensions; i++) {
|
||||
if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
|
||||
if( ((f32*)vector)[i] < ((f32*)p->vector)[i]) {
|
||||
((f32*)p->vector)[i] = ((f32*)vector)[i];
|
||||
}
|
||||
}
|
||||
else if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
|
||||
if( ((i8*)vector)[i] < ((i8*)p->vector)[i]) {
|
||||
((i8*)p->vector)[i] = ((i8*)vector)[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}else {
|
||||
size_t sz = vector_size(elementType, dimensions);
|
||||
p->dimensions = dimensions;
|
||||
p->elementType = elementType;
|
||||
p->vector = sqlite3_malloc(sz);
|
||||
if(!p->vector) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
goto finish;
|
||||
}
|
||||
memset(p->vector, 0, sz);
|
||||
memcpy(p->vector, vector, sz);
|
||||
p->n = 1;
|
||||
}
|
||||
|
||||
finish:
|
||||
cleanup(vector);
|
||||
}
|
||||
|
||||
static void vec_minFinal(sqlite3_context *context){
|
||||
vec_agg_context *p;
|
||||
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||
if(!p) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
return;
|
||||
}
|
||||
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
|
||||
sqlite3_result_subtype(context, p->elementType);
|
||||
}
|
||||
|
||||
static void vec_maxStep(sqlite3_context *context, int argc, sqlite3_value **argv){
|
||||
todo_assert(argc==1);
|
||||
void *vector;
|
||||
size_t dimensions;
|
||||
vector_cleanup cleanup;
|
||||
char *err;
|
||||
enum VectorElementType elementType;
|
||||
|
||||
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
|
||||
&cleanup, &err);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_result_error(context, err, -1);
|
||||
sqlite3_free(err);
|
||||
return;
|
||||
}
|
||||
|
||||
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) && (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) {
|
||||
sqlite3_result_error(
|
||||
context, "only float32 or int8 vectors are supported in vec_max", -1);
|
||||
goto finish;
|
||||
}
|
||||
vec_agg_context *p;
|
||||
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||
if(!p) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
goto finish;
|
||||
}
|
||||
if(p->n) {
|
||||
p->n++;
|
||||
if(p->elementType != elementType) {
|
||||
sqlite3_result_error(context, "vec_max(): vector type mismatch.", -1);
|
||||
goto finish;
|
||||
}
|
||||
if(p->dimensions != dimensions) {
|
||||
sqlite3_result_error(context, "vec_max(): vector dimensions do not match.", -1);
|
||||
goto finish;
|
||||
}
|
||||
for(size_t i = 0; i < dimensions; i++) {
|
||||
if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
|
||||
if( ((f32*)vector)[i] > ((f32*)p->vector)[i]) {
|
||||
((f32*)p->vector)[i] = ((f32*)vector)[i];
|
||||
}
|
||||
}
|
||||
else if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
|
||||
if( ((i8*)vector)[i] > ((i8*)p->vector)[i]) {
|
||||
((i8*)p->vector)[i] = ((i8*)vector)[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}else {
|
||||
size_t sz = vector_size(elementType, dimensions);
|
||||
p->dimensions = dimensions;
|
||||
p->elementType = elementType;
|
||||
p->vector = sqlite3_malloc(sz);
|
||||
if(!p->vector) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
goto finish;
|
||||
}
|
||||
memset(p->vector, 0, sz);
|
||||
memcpy(p->vector, vector, sz);
|
||||
p->n = 1;
|
||||
}
|
||||
|
||||
finish:
|
||||
cleanup(vector);
|
||||
}
|
||||
|
||||
static void vec_maxFinal(sqlite3_context *context){
|
||||
vec_agg_context *p;
|
||||
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||
if(!p) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
return;
|
||||
}
|
||||
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
|
||||
sqlite3_result_subtype(context, p->elementType);
|
||||
}
|
||||
static void vec_avgStep(sqlite3_context *context, int argc, sqlite3_value **argv){
|
||||
todo_assert(argc==1);
|
||||
void *vector;
|
||||
size_t dimensions;
|
||||
vector_cleanup cleanup;
|
||||
char *err;
|
||||
enum VectorElementType elementType;
|
||||
|
||||
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
|
||||
&cleanup, &err);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_result_error(context, err, -1);
|
||||
sqlite3_free(err);
|
||||
return;
|
||||
}
|
||||
|
||||
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32)) {
|
||||
sqlite3_result_error(
|
||||
context, "only float32 vectors are supported in vec_avg", -1);
|
||||
goto finish;
|
||||
}
|
||||
vec_agg_context *p;
|
||||
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||
if(!p) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
goto finish;
|
||||
}
|
||||
if(p->n) {
|
||||
p->n++;
|
||||
if(p->elementType != elementType) {
|
||||
sqlite3_result_error(context, "vec_avg(): vector type mismatch.", -1);
|
||||
goto finish;
|
||||
}
|
||||
if(p->dimensions != dimensions) {
|
||||
sqlite3_result_error(context, "vec_avg(): vector dimensions do not match.", -1);
|
||||
goto finish;
|
||||
}
|
||||
for(size_t i = 0; i < dimensions; i++) {
|
||||
((f32*)p->vector)[i] += ((f32*)vector)[i];
|
||||
}
|
||||
|
||||
}else {
|
||||
size_t sz = vector_size(elementType, dimensions);
|
||||
p->dimensions = dimensions;
|
||||
p->elementType = elementType;
|
||||
p->vector = sqlite3_malloc(sz);
|
||||
if(!p->vector) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
goto finish;
|
||||
}
|
||||
memset(p->vector, 0, sz);
|
||||
memcpy(p->vector, vector, sz);
|
||||
p->n = 1;
|
||||
}
|
||||
|
||||
finish:
|
||||
cleanup(vector);
|
||||
}
|
||||
|
||||
static void vec_avgFinal(sqlite3_context *context){
|
||||
vec_agg_context *p;
|
||||
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||
if(!p) {
|
||||
sqlite3_result_error_nomem(context);
|
||||
return;
|
||||
}
|
||||
for(size_t i = 0; i < p->dimensions; i++) {
|
||||
((f32*)p->vector)[i] = ((f32*)p->vector)[i] / ((float)p->n);
|
||||
}
|
||||
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
|
||||
sqlite3_result_subtype(context, p->elementType);
|
||||
}
|
||||
|
||||
static void _static_text_func(sqlite3_context *context, int argc,
|
||||
sqlite3_value **argv) {
|
||||
UNUSED_PARAMETER(argc);
|
||||
|
|
@ -5726,6 +5976,20 @@ __declspec(dllexport)
|
|||
#endif
|
||||
// clang-format on
|
||||
};
|
||||
static const struct {
|
||||
char *zFName;
|
||||
void (*xStep)(sqlite3_context *, int, sqlite3_value **);
|
||||
void (*xFinal)(sqlite3_context *);
|
||||
int nArg;
|
||||
int flags;
|
||||
void *p;
|
||||
} aAggregateFunc[] = {
|
||||
// clang-format off
|
||||
{"vec_min", vec_minStep, vec_minFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
|
||||
{"vec_max", vec_maxStep, vec_maxFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
|
||||
{"vec_avg", vec_avgStep, vec_avgFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL
|
||||
vec_static_blob_data * static_blob_data;
|
||||
|
|
@ -5758,6 +6022,17 @@ __declspec(dllexport)
|
|||
return rc;
|
||||
}
|
||||
}
|
||||
for (unsigned long i = 0;
|
||||
i < sizeof(aAggregateFunc) / sizeof(aAggregateFunc[0]) && rc == SQLITE_OK; i++) {
|
||||
rc = sqlite3_create_function_v2(db, aAggregateFunc[i].zFName, aAggregateFunc[i].nArg,
|
||||
aAggregateFunc[i].flags, aAggregateFunc[i].p, NULL,
|
||||
aAggregateFunc[i].xStep, aAggregateFunc[i].xFinal, NULL);
|
||||
if (rc != SQLITE_OK) {
|
||||
*pzErrMsg = sqlite3_mprintf("Error creating function %s: %s",
|
||||
aAggregateFunc[i].zFName, sqlite3_errmsg(db));
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned long i = 0; i < countof(aMod) && rc == SQLITE_OK; i++) {
|
||||
rc = sqlite3_create_module_v2(db, aMod[i].name, aMod[i].module, NULL, NULL);
|
||||
|
|
|
|||
21
test.sql
Normal file
21
test.sql
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
.load dist/vec0
|
||||
.mode box
|
||||
.header on
|
||||
|
||||
create table test as
|
||||
select value
|
||||
from json_each('[
|
||||
[1.0, 2.0, -3.0],
|
||||
[-1.0, 2.0, 3.0],
|
||||
[1.0, 2.0, 3.0],
|
||||
[1.0, 2.0, 3.0],
|
||||
[1.0, -2.0, 3.0]
|
||||
]');
|
||||
|
||||
select
|
||||
vec_to_json(vec_min(value)),
|
||||
vec_to_json(vec_max(value)),
|
||||
vec_to_json(vec_avg(value))
|
||||
from test;
|
||||
|
||||
|
||||
|
|
@ -86,6 +86,7 @@ def spread_args(args):
|
|||
|
||||
FUNCTIONS = [
|
||||
"vec_add",
|
||||
"vec_avg",
|
||||
"vec_bit",
|
||||
"vec_debug",
|
||||
"vec_distance_cosine",
|
||||
|
|
@ -94,6 +95,8 @@ FUNCTIONS = [
|
|||
"vec_f32",
|
||||
"vec_int8",
|
||||
"vec_length",
|
||||
"vec_max",
|
||||
"vec_min",
|
||||
"vec_normalize",
|
||||
"vec_quantize_binary",
|
||||
"vec_quantize_i8",
|
||||
|
|
@ -459,6 +462,40 @@ def test_vec_sub():
|
|||
):
|
||||
vec_sub(_int8([2]), _f32([1]), a="vec_int8(?)")
|
||||
|
||||
def test_vec_min():
|
||||
def vec_min(values, wrap="(vec_f32(?))"):
|
||||
return db.execute(
|
||||
"select vec_min(column1) from (values {})".format(", ".join([wrap] * len(values))), values
|
||||
).fetchone()[0]
|
||||
|
||||
assert vec_min(["[1]", "[2]"]) == _f32([1])
|
||||
assert vec_min(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5,-6,-7,-8])
|
||||
|
||||
# TODO: int8 tests, block binary vectors, overflowing
|
||||
|
||||
with pytest.raises(
|
||||
sqlite3.OperationalError,
|
||||
match=re.escape("vec_min(): vector dimensions do not match."),
|
||||
):
|
||||
vec_min(["[1]", "[2,3]"])
|
||||
|
||||
def test_vec_max():
|
||||
def vec_max(values, wrap="(vec_f32(?))"):
|
||||
return db.execute(
|
||||
"select vec_max(column1) from (values {})".format(", ".join([wrap] * len(values))), values
|
||||
).fetchone()[0]
|
||||
|
||||
assert vec_max(["[1]", "[2]"]) == _f32([1])
|
||||
assert vec_max(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5,-6,-7,-8])
|
||||
|
||||
# TODO: int8 tests, block binary vectors, overflowing
|
||||
|
||||
with pytest.raises(
|
||||
sqlite3.OperationalError,
|
||||
match=re.escape("vec_min(): vector dimensions do not match."),
|
||||
):
|
||||
vec_min(["[1]", "[2,3]"])
|
||||
|
||||
|
||||
def test_vec_to_json():
|
||||
vec_to_json = lambda *args, input="?": db.execute(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue