mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
init pass vec_min vec_max vec_avg
This commit is contained in:
parent
6875f7649c
commit
44dcb3b391
3 changed files with 333 additions and 0 deletions
275
sqlite-vec.c
275
sqlite-vec.c
|
|
@ -1444,6 +1444,256 @@ static void vec_normalize(sqlite3_context *context, int argc,
|
||||||
cleanup(vector);
|
cleanup(vector);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
typedef struct vec_agg_context vec_agg_context;
|
||||||
|
struct vec_agg_context {
|
||||||
|
enum VectorElementType elementType;
|
||||||
|
void * vector;
|
||||||
|
size_t dimensions;
|
||||||
|
int n;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t vector_size(enum VectorElementType elementType, size_t dimensions) {
|
||||||
|
switch (elementType) {
|
||||||
|
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
|
||||||
|
return dimensions * sizeof(f32);
|
||||||
|
case SQLITE_VEC_ELEMENT_TYPE_INT8:
|
||||||
|
return dimensions * sizeof(i8);
|
||||||
|
case SQLITE_VEC_ELEMENT_TYPE_BIT:
|
||||||
|
return dimensions / CHAR_BIT;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void vec_minStep(sqlite3_context *context, int argc, sqlite3_value **argv){
|
||||||
|
todo_assert(argc==1);
|
||||||
|
void *vector;
|
||||||
|
size_t dimensions;
|
||||||
|
vector_cleanup cleanup;
|
||||||
|
char *err;
|
||||||
|
enum VectorElementType elementType;
|
||||||
|
|
||||||
|
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
|
||||||
|
&cleanup, &err);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_result_error(context, err, -1);
|
||||||
|
sqlite3_free(err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) && (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) {
|
||||||
|
sqlite3_result_error(
|
||||||
|
context, "only float32 or int8 vectors are supported in vec_min", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
vec_agg_context *p;
|
||||||
|
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||||
|
if(!p) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
if(p->n) {
|
||||||
|
p->n++;
|
||||||
|
if(p->elementType != elementType) {
|
||||||
|
sqlite3_result_error(context, "vec_min(): vector type mismatch.", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
if(p->dimensions != dimensions) {
|
||||||
|
sqlite3_result_error(context, "vec_min(): vector dimensions do not match.", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
for(size_t i = 0; i < dimensions; i++) {
|
||||||
|
if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
|
||||||
|
if( ((f32*)vector)[i] < ((f32*)p->vector)[i]) {
|
||||||
|
((f32*)p->vector)[i] = ((f32*)vector)[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
|
||||||
|
if( ((i8*)vector)[i] < ((i8*)p->vector)[i]) {
|
||||||
|
((i8*)p->vector)[i] = ((i8*)vector)[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}else {
|
||||||
|
size_t sz = vector_size(elementType, dimensions);
|
||||||
|
p->dimensions = dimensions;
|
||||||
|
p->elementType = elementType;
|
||||||
|
p->vector = sqlite3_malloc(sz);
|
||||||
|
if(!p->vector) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
memset(p->vector, 0, sz);
|
||||||
|
memcpy(p->vector, vector, sz);
|
||||||
|
p->n = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
finish:
|
||||||
|
cleanup(vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void vec_minFinal(sqlite3_context *context){
|
||||||
|
vec_agg_context *p;
|
||||||
|
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||||
|
if(!p) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
|
||||||
|
sqlite3_result_subtype(context, p->elementType);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void vec_maxStep(sqlite3_context *context, int argc, sqlite3_value **argv){
|
||||||
|
todo_assert(argc==1);
|
||||||
|
void *vector;
|
||||||
|
size_t dimensions;
|
||||||
|
vector_cleanup cleanup;
|
||||||
|
char *err;
|
||||||
|
enum VectorElementType elementType;
|
||||||
|
|
||||||
|
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
|
||||||
|
&cleanup, &err);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_result_error(context, err, -1);
|
||||||
|
sqlite3_free(err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) && (elementType != SQLITE_VEC_ELEMENT_TYPE_INT8)) {
|
||||||
|
sqlite3_result_error(
|
||||||
|
context, "only float32 or int8 vectors are supported in vec_max", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
vec_agg_context *p;
|
||||||
|
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||||
|
if(!p) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
if(p->n) {
|
||||||
|
p->n++;
|
||||||
|
if(p->elementType != elementType) {
|
||||||
|
sqlite3_result_error(context, "vec_max(): vector type mismatch.", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
if(p->dimensions != dimensions) {
|
||||||
|
sqlite3_result_error(context, "vec_max(): vector dimensions do not match.", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
for(size_t i = 0; i < dimensions; i++) {
|
||||||
|
if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
|
||||||
|
if( ((f32*)vector)[i] > ((f32*)p->vector)[i]) {
|
||||||
|
((f32*)p->vector)[i] = ((f32*)vector)[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(p->elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
|
||||||
|
if( ((i8*)vector)[i] > ((i8*)p->vector)[i]) {
|
||||||
|
((i8*)p->vector)[i] = ((i8*)vector)[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}else {
|
||||||
|
size_t sz = vector_size(elementType, dimensions);
|
||||||
|
p->dimensions = dimensions;
|
||||||
|
p->elementType = elementType;
|
||||||
|
p->vector = sqlite3_malloc(sz);
|
||||||
|
if(!p->vector) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
memset(p->vector, 0, sz);
|
||||||
|
memcpy(p->vector, vector, sz);
|
||||||
|
p->n = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
finish:
|
||||||
|
cleanup(vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void vec_maxFinal(sqlite3_context *context){
|
||||||
|
vec_agg_context *p;
|
||||||
|
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||||
|
if(!p) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
|
||||||
|
sqlite3_result_subtype(context, p->elementType);
|
||||||
|
}
|
||||||
|
static void vec_avgStep(sqlite3_context *context, int argc, sqlite3_value **argv){
|
||||||
|
todo_assert(argc==1);
|
||||||
|
void *vector;
|
||||||
|
size_t dimensions;
|
||||||
|
vector_cleanup cleanup;
|
||||||
|
char *err;
|
||||||
|
enum VectorElementType elementType;
|
||||||
|
|
||||||
|
int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
|
||||||
|
&cleanup, &err);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_result_error(context, err, -1);
|
||||||
|
sqlite3_free(err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32)) {
|
||||||
|
sqlite3_result_error(
|
||||||
|
context, "only float32 vectors are supported in vec_avg", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
vec_agg_context *p;
|
||||||
|
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||||
|
if(!p) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
if(p->n) {
|
||||||
|
p->n++;
|
||||||
|
if(p->elementType != elementType) {
|
||||||
|
sqlite3_result_error(context, "vec_avg(): vector type mismatch.", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
if(p->dimensions != dimensions) {
|
||||||
|
sqlite3_result_error(context, "vec_avg(): vector dimensions do not match.", -1);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
for(size_t i = 0; i < dimensions; i++) {
|
||||||
|
((f32*)p->vector)[i] += ((f32*)vector)[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
}else {
|
||||||
|
size_t sz = vector_size(elementType, dimensions);
|
||||||
|
p->dimensions = dimensions;
|
||||||
|
p->elementType = elementType;
|
||||||
|
p->vector = sqlite3_malloc(sz);
|
||||||
|
if(!p->vector) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
goto finish;
|
||||||
|
}
|
||||||
|
memset(p->vector, 0, sz);
|
||||||
|
memcpy(p->vector, vector, sz);
|
||||||
|
p->n = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
finish:
|
||||||
|
cleanup(vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void vec_avgFinal(sqlite3_context *context){
|
||||||
|
vec_agg_context *p;
|
||||||
|
p = (vec_agg_context *) sqlite3_aggregate_context(context, sizeof(*p));
|
||||||
|
if(!p) {
|
||||||
|
sqlite3_result_error_nomem(context);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for(size_t i = 0; i < p->dimensions; i++) {
|
||||||
|
((f32*)p->vector)[i] = ((f32*)p->vector)[i] / ((float)p->n);
|
||||||
|
}
|
||||||
|
sqlite3_result_blob(context, p->vector, vector_size(p->elementType, p->dimensions), sqlite3_free);
|
||||||
|
sqlite3_result_subtype(context, p->elementType);
|
||||||
|
}
|
||||||
|
|
||||||
static void _static_text_func(sqlite3_context *context, int argc,
|
static void _static_text_func(sqlite3_context *context, int argc,
|
||||||
sqlite3_value **argv) {
|
sqlite3_value **argv) {
|
||||||
UNUSED_PARAMETER(argc);
|
UNUSED_PARAMETER(argc);
|
||||||
|
|
@ -5726,6 +5976,20 @@ __declspec(dllexport)
|
||||||
#endif
|
#endif
|
||||||
// clang-format on
|
// clang-format on
|
||||||
};
|
};
|
||||||
|
static const struct {
|
||||||
|
char *zFName;
|
||||||
|
void (*xStep)(sqlite3_context *, int, sqlite3_value **);
|
||||||
|
void (*xFinal)(sqlite3_context *);
|
||||||
|
int nArg;
|
||||||
|
int flags;
|
||||||
|
void *p;
|
||||||
|
} aAggregateFunc[] = {
|
||||||
|
// clang-format off
|
||||||
|
{"vec_min", vec_minStep, vec_minFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
|
||||||
|
{"vec_max", vec_maxStep, vec_maxFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
|
||||||
|
{"vec_avg", vec_avgStep, vec_avgFinal, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL },
|
||||||
|
// clang-format on
|
||||||
|
};
|
||||||
|
|
||||||
#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL
|
#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL
|
||||||
vec_static_blob_data * static_blob_data;
|
vec_static_blob_data * static_blob_data;
|
||||||
|
|
@ -5758,6 +6022,17 @@ __declspec(dllexport)
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (unsigned long i = 0;
|
||||||
|
i < sizeof(aAggregateFunc) / sizeof(aAggregateFunc[0]) && rc == SQLITE_OK; i++) {
|
||||||
|
rc = sqlite3_create_function_v2(db, aAggregateFunc[i].zFName, aAggregateFunc[i].nArg,
|
||||||
|
aAggregateFunc[i].flags, aAggregateFunc[i].p, NULL,
|
||||||
|
aAggregateFunc[i].xStep, aAggregateFunc[i].xFinal, NULL);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
*pzErrMsg = sqlite3_mprintf("Error creating function %s: %s",
|
||||||
|
aAggregateFunc[i].zFName, sqlite3_errmsg(db));
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (unsigned long i = 0; i < countof(aMod) && rc == SQLITE_OK; i++) {
|
for (unsigned long i = 0; i < countof(aMod) && rc == SQLITE_OK; i++) {
|
||||||
rc = sqlite3_create_module_v2(db, aMod[i].name, aMod[i].module, NULL, NULL);
|
rc = sqlite3_create_module_v2(db, aMod[i].name, aMod[i].module, NULL, NULL);
|
||||||
|
|
|
||||||
21
test.sql
Normal file
21
test.sql
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
.load dist/vec0
|
||||||
|
.mode box
|
||||||
|
.header on
|
||||||
|
|
||||||
|
create table test as
|
||||||
|
select value
|
||||||
|
from json_each('[
|
||||||
|
[1.0, 2.0, -3.0],
|
||||||
|
[-1.0, 2.0, 3.0],
|
||||||
|
[1.0, 2.0, 3.0],
|
||||||
|
[1.0, 2.0, 3.0],
|
||||||
|
[1.0, -2.0, 3.0]
|
||||||
|
]');
|
||||||
|
|
||||||
|
select
|
||||||
|
vec_to_json(vec_min(value)),
|
||||||
|
vec_to_json(vec_max(value)),
|
||||||
|
vec_to_json(vec_avg(value))
|
||||||
|
from test;
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -86,6 +86,7 @@ def spread_args(args):
|
||||||
|
|
||||||
FUNCTIONS = [
|
FUNCTIONS = [
|
||||||
"vec_add",
|
"vec_add",
|
||||||
|
"vec_avg",
|
||||||
"vec_bit",
|
"vec_bit",
|
||||||
"vec_debug",
|
"vec_debug",
|
||||||
"vec_distance_cosine",
|
"vec_distance_cosine",
|
||||||
|
|
@ -94,6 +95,8 @@ FUNCTIONS = [
|
||||||
"vec_f32",
|
"vec_f32",
|
||||||
"vec_int8",
|
"vec_int8",
|
||||||
"vec_length",
|
"vec_length",
|
||||||
|
"vec_max",
|
||||||
|
"vec_min",
|
||||||
"vec_normalize",
|
"vec_normalize",
|
||||||
"vec_quantize_binary",
|
"vec_quantize_binary",
|
||||||
"vec_quantize_i8",
|
"vec_quantize_i8",
|
||||||
|
|
@ -459,6 +462,40 @@ def test_vec_sub():
|
||||||
):
|
):
|
||||||
vec_sub(_int8([2]), _f32([1]), a="vec_int8(?)")
|
vec_sub(_int8([2]), _f32([1]), a="vec_int8(?)")
|
||||||
|
|
||||||
|
def test_vec_min():
|
||||||
|
def vec_min(values, wrap="(vec_f32(?))"):
|
||||||
|
return db.execute(
|
||||||
|
"select vec_min(column1) from (values {})".format(", ".join([wrap] * len(values))), values
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
assert vec_min(["[1]", "[2]"]) == _f32([1])
|
||||||
|
assert vec_min(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5,-6,-7,-8])
|
||||||
|
|
||||||
|
# TODO: int8 tests, block binary vectors, overflowing
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sqlite3.OperationalError,
|
||||||
|
match=re.escape("vec_min(): vector dimensions do not match."),
|
||||||
|
):
|
||||||
|
vec_min(["[1]", "[2,3]"])
|
||||||
|
|
||||||
|
def test_vec_max():
|
||||||
|
def vec_max(values, wrap="(vec_f32(?))"):
|
||||||
|
return db.execute(
|
||||||
|
"select vec_max(column1) from (values {})".format(", ".join([wrap] * len(values))), values
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
assert vec_max(["[1]", "[2]"]) == _f32([1])
|
||||||
|
assert vec_max(["[1,2,3,4]", "[-5,-6,-7,-8]"]) == _f32([-5,-6,-7,-8])
|
||||||
|
|
||||||
|
# TODO: int8 tests, block binary vectors, overflowing
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
sqlite3.OperationalError,
|
||||||
|
match=re.escape("vec_min(): vector dimensions do not match."),
|
||||||
|
):
|
||||||
|
vec_min(["[1]", "[2,3]"])
|
||||||
|
|
||||||
|
|
||||||
def test_vec_to_json():
|
def test_vec_to_json():
|
||||||
vec_to_json = lambda *args, input="?": db.execute(
|
vec_to_json = lambda *args, input="?": db.execute(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue