vec_type(), API references

This commit is contained in:
Alex Garcia 2024-07-22 21:24:44 -07:00
parent cfd8e9a46b
commit ff6cf96e2a
6 changed files with 677 additions and 240 deletions

View file

@ -1082,8 +1082,105 @@ finish:
return;
}
static void vec_quantize_i8(sqlite3_context *context, int argc,
/**
 * Map a vector element type to its textual name.
 *
 * Returns a pointer to a string literal ("float32", "int8" or "bit"); the
 * caller must neither modify nor free it. If elementType matches none of the
 * enumerators handled below, an empty string is returned so the function
 * never falls off its end without a value.
 */
char * vec_type_name(enum VectorElementType elementType) {
  switch (elementType) {
  case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
    return "float32";
  case SQLITE_VEC_ELEMENT_TYPE_INT8:
    return "int8";
  case SQLITE_VEC_ELEMENT_TYPE_BIT:
    return "bit";
  }
  // Only reached for a value outside the cases above. Without this return,
  // using the result of the call would be undefined behaviour.
  return "";
}
/**
 * SQL scalar function taking one argument: reports the element type name of
 * the vector passed in argv[0].
 *
 * The argument is decoded with vector_from_value(). On failure the error
 * message it produced becomes the SQL error and is then released with
 * sqlite3_free(). On success the result is the text returned by
 * vec_type_name() for the decoded element type, and the decoded vector is
 * released through the cleanup callback that vector_from_value() supplied.
 */
static void vec_type(sqlite3_context *context, int argc,
                     sqlite3_value **argv) {
  assert(argc == 1);

  void *vec;
  size_t nDims;
  enum VectorElementType kind;
  vector_cleanup release;
  char *errMsg;

  if (vector_from_value(argv[0], &vec, &nDims, &kind, &release, &errMsg) !=
      SQLITE_OK) {
    sqlite3_result_error(context, errMsg, -1);
    sqlite3_free(errMsg);
    return;
  }

  // SQLITE_STATIC: the name is handed to SQLite without a destructor.
  const char *typeName = vec_type_name(kind);
  sqlite3_result_text(context, typeName, -1, SQLITE_STATIC);
  release(vec);
}
/**
 * SQL scalar function taking one argument: binary-quantizes the float32 or
 * int8 vector in argv[0] into a bit vector.
 *
 * Output bit i (byte i / 8, bit position i % 8 counted from the least
 * significant bit) is 1 when element i of the input is strictly greater than
 * zero, and 0 otherwise. The result is a blob of dimensions / CHAR_BIT bytes
 * tagged with the SQLITE_VEC_ELEMENT_TYPE_BIT subtype.
 *
 * SQL errors are raised when:
 *   - argv[0] cannot be decoded by vector_from_value(),
 *   - the vector has zero dimensions,
 *   - the number of dimensions is not divisible by CHAR_BIT,
 *   - the output buffer cannot be allocated (SQLITE_NOMEM),
 *   - the input is already a bit vector.
 *
 * Every path after a successful decode leaves through the cleanup label so
 * the decoded vector is always released.
 */
static void vec_quantize_binary(sqlite3_context *context, int argc,
                                sqlite3_value **argv) {
  assert(argc == 1);
  void *vector;
  size_t dimensions;
  vector_cleanup vectorCleanup;
  char *pzError;
  enum VectorElementType elementType;
  int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
                             &vectorCleanup, &pzError);
  if (rc != SQLITE_OK) {
    sqlite3_result_error(context, pzError, -1);
    sqlite3_free(pzError);
    return;
  }

  if (dimensions == 0) {
    sqlite3_result_error(context, "Zero length vectors are not supported.", -1);
    goto cleanup;
  }
  if ((dimensions % CHAR_BIT) != 0) {
    sqlite3_result_error(context, "Binary quantization requires vectors with a length divisible by 8", -1);
    goto cleanup;
  }

  int sz = (int)(dimensions / CHAR_BIT);
  u8 *out = sqlite3_malloc(sz);
  if (!out) {
    sqlite3_result_error_code(context, SQLITE_NOMEM);
    goto cleanup;
  }
  memset(out, 0, sz);

  switch (elementType) {
  case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
    for (size_t i = 0; i < dimensions; i++) {
      int res = ((f32 *)vector)[i] > 0.0;
      out[i / 8] |= (res << (i % 8));
    }
    break;
  }
  case SQLITE_VEC_ELEMENT_TYPE_INT8: {
    for (size_t i = 0; i < dimensions; i++) {
      int res = ((i8 *)vector)[i] > 0;
      out[i / 8] |= (res << (i % 8));
    }
    break;
  }
  case SQLITE_VEC_ELEMENT_TYPE_BIT: {
    sqlite3_result_error(context, "Can only binary quantize float or int8 vectors", -1);
    sqlite3_free(out);
    // Leave through the shared cleanup path; returning here would leak the
    // decoded input vector.
    goto cleanup;
  }
  }

  // SQLite takes ownership of `out` and frees it with sqlite3_free.
  sqlite3_result_blob(context, out, sz, sqlite3_free);
  sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);

cleanup:
  vectorCleanup(vector);
}
static void vec_quantize_int8(sqlite3_context *context, int argc,
sqlite3_value **argv) {
assert(argc == 2);
f32 *srcVector;
size_t dimensions;
fvec_cleanup srcCleanup;
@ -1099,39 +1196,23 @@ static void vec_quantize_i8(sqlite3_context *context, int argc,
int sz = dimensions * sizeof(i8);
out = sqlite3_malloc(sz);
if (!out) {
rc = SQLITE_NOMEM;
sqlite3_result_error_nomem(context);
goto cleanup;
}
memset(out, 0, sz);
if (argc == 2) {
if ((sqlite3_value_type(argv[1]) != SQLITE_TEXT) ||
(sqlite3_value_bytes(argv[1]) != strlen("unit")) ||
(sqlite3_stricmp((const char *)sqlite3_value_text(argv[1]), "unit") !=
0)) {
sqlite3_result_error(context,
"2nd argument to vec_quantize_i8() must be 'unit', "
"or ranges must be provided.",
-1);
sqlite3_free(out);
goto cleanup;
}
f32 step = (1.0 - (-1.0)) / 255;
for (size_t i = 0; i < dimensions; i++) {
out[i] = ((srcVector[i] - (-1.0)) / step) - 128;
}
} else if (argc == 3) {
// f32 * minVector, maxVector;
// size_t d;
// fvec_cleanup minCleanup, maxCleanup;
// int rc = fvec_from_value(argv[1], )
if ((sqlite3_value_type(argv[1]) != SQLITE_TEXT) ||
(sqlite3_value_bytes(argv[1]) != strlen("unit")) ||
(sqlite3_stricmp((const char *)sqlite3_value_text(argv[1]), "unit") !=
0)) {
sqlite3_result_error(context, "2nd argument to vec_quantize_i8() must be 'unit'.", -1);
sqlite3_free(out);
// TODO
sqlite3_result_error(
context, "ranges parameter not supported in vec_quantize_i8 yet.", -1);
goto cleanup;
}
f32 step = (1.0 - (-1.0)) / 255;
for (size_t i = 0; i < dimensions; i++) {
out[i] = ((srcVector[i] - (-1.0)) / step) - 128;
}
sqlite3_result_blob(context, out, dimensions * sizeof(i8), sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8);
@ -1140,58 +1221,6 @@ cleanup:
srcCleanup(srcVector);
}
/**
 * SQL scalar function taking one argument: binary-quantizes the float32 or
 * int8 vector in argv[0] into a bit vector.
 *
 * Output bit i (byte i / 8, bit position i % 8 counted from the least
 * significant bit) is 1 when element i of the input is strictly greater than
 * zero. The result is a blob of dimensions / CHAR_BIT bytes tagged with the
 * SQLITE_VEC_ELEMENT_TYPE_BIT subtype.
 *
 * SQL errors are raised when argv[0] cannot be decoded, when the vector is
 * empty, when its length is not divisible by CHAR_BIT (the output buffer
 * would otherwise be too small for the bits written), when allocation fails,
 * or when the input is neither float32 nor int8.
 *
 * After a successful decode every path leaves through the `done` label so
 * the decoded vector is released exactly once.
 */
static void vec_quantize_binary(sqlite3_context *context, int argc,
                                sqlite3_value **argv) {
  assert(argc == 1);
  void *vector;
  size_t dimensions;
  vector_cleanup cleanup;
  char *pzError;
  enum VectorElementType elementType;
  int rc = vector_from_value(argv[0], &vector, &dimensions, &elementType,
                             &cleanup, &pzError);
  if (rc != SQLITE_OK) {
    sqlite3_result_error(context, pzError, -1);
    sqlite3_free(pzError);
    return;
  }

  if (elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32 &&
      elementType != SQLITE_VEC_ELEMENT_TYPE_INT8) {
    sqlite3_result_error(context,
                         "Can only binary quantize float or int8 vectors", -1);
    goto done;
  }
  if (dimensions == 0) {
    sqlite3_result_error(context, "Zero length vectors are not supported.", -1);
    goto done;
  }
  // A length that is not a whole number of bytes would truncate `sz` and let
  // the loops below write past the end of `out`.
  if ((dimensions % CHAR_BIT) != 0) {
    sqlite3_result_error(context, "Binary quantization requires vectors with a length divisible by 8", -1);
    goto done;
  }

  int sz = (int)(dimensions / CHAR_BIT);
  u8 *out = sqlite3_malloc(sz);
  if (!out) {
    sqlite3_result_error_code(context, SQLITE_NOMEM);
    goto done;
  }
  memset(out, 0, sz);

  if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
    for (size_t i = 0; i < dimensions; i++) {
      int res = ((f32 *)vector)[i] > 0.0;
      out[i / 8] |= (res << (i % 8));
    }
  } else {
    for (size_t i = 0; i < dimensions; i++) {
      int res = ((i8 *)vector)[i] > 0;
      out[i / 8] |= (res << (i % 8));
    }
  }

  // SQLite takes ownership of `out` and frees it with sqlite3_free.
  sqlite3_result_blob(context, out, sz, sqlite3_free);
  sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);

done:
  cleanup(vector);
}
static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) {
assert(argc == 2);
@ -2778,7 +2807,7 @@ static int vec_npy_eachColumnBuffer(vec_npy_each_cursor *pCur,
}
case SQLITE_VEC_ELEMENT_TYPE_INT8:
case SQLITE_VEC_ELEMENT_TYPE_BIT: {
// TODO
// https://github.com/asg017/sqlite-vec/issues/42
sqlite3_result_error(context,
"vec_npy_each only supports float32 vectors", -1);
break;
@ -2806,7 +2835,7 @@ static int vec_npy_eachColumnFile(vec_npy_each_cursor *pCur,
}
case SQLITE_VEC_ELEMENT_TYPE_INT8:
case SQLITE_VEC_ELEMENT_TYPE_BIT: {
// TODO
// https://github.com/asg017/sqlite-vec/issues/42
sqlite3_result_error(context,
"vec_npy_each only supports float32 vectors", -1);
break;
@ -5902,13 +5931,13 @@ static sqlite3_module vec0Module = {
/* xCommit */ 0,
/* xRollback */ 0,
/* xFindFunction */ 0,
/* xRename */ 0, // TODO
/* xRename */ 0, // https://github.com/asg017/sqlite-vec/issues/43
/* xSavepoint */ 0,
/* xRelease */ 0,
/* xRollbackTo */ 0,
/* xShadowName */ vec0ShadowName,
#if SQLITE_VERSION_NUMBER >= 3044000
/* xIntegrity */ 0, // TODO
/* xIntegrity */ 0, // https://github.com/asg017/sqlite-vec/issues/44
#endif
};
#pragma endregion
@ -6661,6 +6690,7 @@ __declspec(dllexport)
{"vec_distance_hamming",vec_distance_hamming, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
{"vec_distance_cosine", vec_distance_cosine, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
{"vec_length", vec_length, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
{"vec_type", vec_type, 1, DEFAULT_FLAGS, },
{"vec_to_json", vec_to_json, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_add", vec_add, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_sub", vec_sub, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
@ -6669,8 +6699,7 @@ __declspec(dllexport)
{"vec_f32", vec_f32, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_bit", vec_bit, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_int8", vec_int8, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_quantize_i8", vec_quantize_i8, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_quantize_i8", vec_quantize_i8, 3, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_quantize_int8", vec_quantize_int8, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_quantize_binary", vec_quantize_binary, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_static_blob_from_raw", vec_static_blob_from_raw, 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE },
// clang-format on