typedef float f32

This commit is contained in:
Alex Garcia 2024-04-20 17:05:37 -07:00
parent 29507aa45d
commit cc46a6f2f0

View file

@ -94,15 +94,15 @@ enum VectorElementType {
#define PORTABLE_ALIGN32 __attribute__((aligned(32))) #define PORTABLE_ALIGN32 __attribute__((aligned(32)))
#define PORTABLE_ALIGN64 __attribute__((aligned(64))) #define PORTABLE_ALIGN64 __attribute__((aligned(64)))
static float l2_sqr_float_avx(const void *pVect1v, const void *pVect2v, static f32 l2_sqr_float_avx(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) { const void *qty_ptr) {
float *pVect1 = (float *)pVect1v; f32 *pVect1 = (f32 *)pVect1v;
float *pVect2 = (float *)pVect2v; f32 *pVect2 = (f32 *)pVect2v;
size_t qty = *((size_t *)qty_ptr); size_t qty = *((size_t *)qty_ptr);
float PORTABLE_ALIGN32 TmpRes[8]; f32 PORTABLE_ALIGN32 TmpRes[8];
size_t qty16 = qty >> 4; size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4); const f32 *pEnd1 = pVect1 + (qty16 << 4);
__m256 diff, v1, v2; __m256 diff, v1, v2;
__m256 sum = _mm256_set1_ps(0); __m256 sum = _mm256_set1_ps(0);
@ -135,14 +135,14 @@ static float l2_sqr_float_avx(const void *pVect1v, const void *pVect2v,
#define PORTABLE_ALIGN32 __attribute__((aligned(32))) #define PORTABLE_ALIGN32 __attribute__((aligned(32)))
// thx https://github.com/nmslib/hnswlib/pull/299/files // thx https://github.com/nmslib/hnswlib/pull/299/files
static float l2_sqr_float_neon(const void *pVect1v, const void *pVect2v, static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) { const void *qty_ptr) {
float *pVect1 = (float *)pVect1v; f32 *pVect1 = (f32 *)pVect1v;
float *pVect2 = (float *)pVect2v; f32 *pVect2 = (f32 *)pVect2v;
size_t qty = *((size_t *)qty_ptr); size_t qty = *((size_t *)qty_ptr);
size_t qty16 = qty >> 4; size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4); const f32 *pEnd1 = pVect1 + (qty16 << 4);
float32x4_t diff, v1, v2; float32x4_t diff, v1, v2;
float32x4_t sum0 = vdupq_n_f32(0); float32x4_t sum0 = vdupq_n_f32(0);
@ -185,15 +185,15 @@ static float l2_sqr_float_neon(const void *pVect1v, const void *pVect2v,
} }
#endif #endif
static float l2_sqr_float(const void *pVect1v, const void *pVect2v, static f32 l2_sqr_float(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) { const void *qty_ptr) {
float *pVect1 = (float *)pVect1v; f32 *pVect1 = (f32 *)pVect1v;
float *pVect2 = (float *)pVect2v; f32 *pVect2 = (f32 *)pVect2v;
size_t qty = *((size_t *)qty_ptr); size_t qty = *((size_t *)qty_ptr);
float res = 0; f32 res = 0;
for (size_t i = 0; i < qty; i++) { for (size_t i = 0; i < qty; i++) {
float t = *pVect1 - *pVect2; f32 t = *pVect1 - *pVect2;
pVect1++; pVect1++;
pVect2++; pVect2++;
res += t * t; res += t * t;
@ -201,14 +201,14 @@ static float l2_sqr_float(const void *pVect1v, const void *pVect2v,
return sqrt(res); return sqrt(res);
} }
static float l2_sqr_int8(const void *pA, const void *pB, const void *pD) { static f32 l2_sqr_int8(const void *pA, const void *pB, const void *pD) {
i8 *a = (i8 *)pA; i8 *a = (i8 *)pA;
i8 *b = (i8 *)pB; i8 *b = (i8 *)pB;
size_t d = *((size_t *)pD); size_t d = *((size_t *)pD);
float res = 0; f32 res = 0;
for (size_t i = 0; i < d; i++) { for (size_t i = 0; i < d; i++) {
float t = *a - *b; f32 t = *a - *b;
a++; a++;
b++; b++;
res += t * t; res += t * t;
@ -216,8 +216,7 @@ static float l2_sqr_int8(const void *pA, const void *pB, const void *pD) {
return sqrt(res); return sqrt(res);
} }
static float distance_l2_sqr_float(const void *a, const void *b, static f32 distance_l2_sqr_float(const void *a, const void *b, const void *d) {
const void *d) {
#ifdef SQLITE_VEC_ENABLE_NEON #ifdef SQLITE_VEC_ENABLE_NEON
if (((*(const size_t *)d) % 16 == 0)) { if (((*(const size_t *)d) % 16 == 0)) {
return l2_sqr_float_neon(a, b, d); return l2_sqr_float_neon(a, b, d);
@ -231,19 +230,19 @@ static float distance_l2_sqr_float(const void *a, const void *b,
return l2_sqr_float(a, b, d); return l2_sqr_float(a, b, d);
} }
/**
 * Thin wrapper around l2_sqr_int8(): forwards all three arguments unchanged
 * and returns whatever l2_sqr_int8() returns.
 *
 * @param a first vector, passed through as an opaque pointer
 * @param b second vector, passed through as an opaque pointer
 * @param d opaque pointer passed through to l2_sqr_int8()
 * @return the value produced by l2_sqr_int8(a, b, d)
 */
static f32 distance_l2_sqr_int8(const void *a, const void *b, const void *d) {
  const f32 dist = l2_sqr_int8(a, b, d);
  return dist;
}
static float distance_cosine_float(const void *pVect1v, const void *pVect2v, static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) { const void *qty_ptr) {
float *pVect1 = (float *)pVect1v; f32 *pVect1 = (f32 *)pVect1v;
float *pVect2 = (float *)pVect2v; f32 *pVect2 = (f32 *)pVect2v;
size_t qty = *((size_t *)qty_ptr); size_t qty = *((size_t *)qty_ptr);
float dot = 0; f32 dot = 0;
float aMag = 0; f32 aMag = 0;
float bMag = 0; f32 bMag = 0;
for (size_t i = 0; i < qty; i++) { for (size_t i = 0; i < qty; i++) {
dot += *pVect1 * *pVect2; dot += *pVect1 * *pVect2;
aMag += *pVect1 * *pVect1; aMag += *pVect1 * *pVect1;
@ -253,15 +252,15 @@ static float distance_cosine_float(const void *pVect1v, const void *pVect2v,
} }
return 1 - (dot / (sqrt(aMag) * sqrt(bMag))); return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
} }
static float distance_cosine_int8(const void *pA, const void *pB, static f32 distance_cosine_int8(const void *pA, const void *pB,
const void *pD) { const void *pD) {
i8 *a = (i8 *)pA; i8 *a = (i8 *)pA;
i8 *b = (i8 *)pB; i8 *b = (i8 *)pB;
size_t d = *((size_t *)pD); size_t d = *((size_t *)pD);
float dot = 0; f32 dot = 0;
float aMag = 0; f32 aMag = 0;
float bMag = 0; f32 bMag = 0;
for (size_t i = 0; i < d; i++) { for (size_t i = 0; i < d; i++) {
dot += *a * *b; dot += *a * *b;
aMag += *a * *a; aMag += *a * *a;
@ -286,22 +285,22 @@ static u8 hamdist_table[256] = {
4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};
static float distance_hamming_u8(u8 *a, u8 *b, size_t n) { static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) {
int same = 0; int same = 0;
for (unsigned long i = 0; i < n; i++) { for (unsigned long i = 0; i < n; i++) {
same += hamdist_table[a[i] ^ b[i]]; same += hamdist_table[a[i] ^ b[i]];
} }
return (float)same; return (f32)same;
} }
static float distance_hamming_u64(u64 *a, u64 *b, size_t n) { static f32 distance_hamming_u64(u64 *a, u64 *b, size_t n) {
int same = 0; int same = 0;
for (unsigned long i = 0; i < n; i++) { for (unsigned long i = 0; i < n; i++) {
same += __builtin_popcountl(a[i] ^ b[i]); same += __builtin_popcountl(a[i] ^ b[i]);
} }
return (float)same; return (f32)same;
} }
static float distance_hamming(const void *a, const void *b, const void *d) { static f32 distance_hamming(const void *a, const void *b, const void *d) {
size_t dimensions = *((size_t *)d); size_t dimensions = *((size_t *)d);
todo_assert((dimensions % CHAR_BIT) == 0); todo_assert((dimensions % CHAR_BIT) == 0);
@ -380,11 +379,11 @@ void array_cleanup(struct Array *array) {
array->z = NULL; array->z = NULL;
} }
/** Pointer to a routine that is handed an f32 vector buffer to clean up. */
typedef void (*fvec_cleanup)(f32 *vector);

/**
 * fvec_cleanup implementation that does nothing with its argument; the
 * parameter is only marked as unused.
 */
void fvec_cleanup_noop(f32 *_) {
  UNUSED_PARAMETER(_);
}
static int fvec_from_value(sqlite3_value *value, float **vector, static int fvec_from_value(sqlite3_value *value, f32 **vector,
size_t *dimensions, fvec_cleanup *cleanup, size_t *dimensions, fvec_cleanup *cleanup,
char **pzErr) { char **pzErr) {
int value_type = sqlite3_value_type(value); int value_type = sqlite3_value_type(value);
@ -395,14 +394,14 @@ static int fvec_from_value(sqlite3_value *value, float **vector,
*pzErr = sqlite3_mprintf("zero-length vectors are not supported."); *pzErr = sqlite3_mprintf("zero-length vectors are not supported.");
return SQLITE_ERROR; return SQLITE_ERROR;
} }
if ((bytes % sizeof(float)) != 0) { if ((bytes % sizeof(f32)) != 0) {
*pzErr = sqlite3_mprintf("invalid float32 vector BLOB length. Must be " *pzErr = sqlite3_mprintf("invalid float32 vector BLOB length. Must be "
"divisible by %d, found %d", "divisible by %d, found %d",
sizeof(float), bytes); sizeof(f32), bytes);
return SQLITE_ERROR; return SQLITE_ERROR;
} }
*vector = (float *)blob; *vector = (f32 *)blob;
*dimensions = bytes / sizeof(float); *dimensions = bytes / sizeof(f32);
*cleanup = fvec_cleanup_noop; *cleanup = fvec_cleanup_noop;
return SQLITE_OK; return SQLITE_OK;
} }
@ -413,7 +412,7 @@ static int fvec_from_value(sqlite3_value *value, float **vector,
int i = 0; int i = 0;
struct Array x; struct Array x;
int rc = array_init(&x, sizeof(float), ceil(source_len / 2.0)); int rc = array_init(&x, sizeof(f32), ceil(source_len / 2.0));
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
// advance leading whitespace to first '[' // advance leading whitespace to first '['
@ -463,7 +462,7 @@ static int fvec_from_value(sqlite3_value *value, float **vector,
goto done; goto done;
} }
float res = (float)result; f32 res = (f32)result;
array_append(&x, (const void *)&res); array_append(&x, (const void *)&res);
offset += (endptr - ptr); offset += (endptr - ptr);
@ -485,7 +484,7 @@ static int fvec_from_value(sqlite3_value *value, float **vector,
done: done:
if (x.length > 0) { if (x.length > 0) {
*vector = (float *)x.z; *vector = (f32 *)x.z;
*dimensions = x.length; *dimensions = x.length;
*cleanup = (fvec_cleanup)sqlite3_free; *cleanup = (fvec_cleanup)sqlite3_free;
return SQLITE_OK; return SQLITE_OK;
@ -558,7 +557,7 @@ int vector_from_value(sqlite3_value *value, void **vector, size_t *dimensions,
int subtype = sqlite3_value_subtype(value); int subtype = sqlite3_value_subtype(value);
if (!subtype || (subtype == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) || if (!subtype || (subtype == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) ||
(subtype == JSON_SUBTYPE)) { (subtype == JSON_SUBTYPE)) {
int rc = fvec_from_value(value, (float **)vector, dimensions, int rc = fvec_from_value(value, (f32 **)vector, dimensions,
(fvec_cleanup *)cleanup, pzErrorMessage); (fvec_cleanup *)cleanup, pzErrorMessage);
if (rc == SQLITE_OK) { if (rc == SQLITE_OK) {
*element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32; *element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32;
@ -669,7 +668,7 @@ static void vec_npy_file(sqlite3_context *context, int argc,
static void vec_f32(sqlite3_context *context, int argc, sqlite3_value **argv) { static void vec_f32(sqlite3_context *context, int argc, sqlite3_value **argv) {
todo_assert(argc == 1); todo_assert(argc == 1);
int rc; int rc;
float *vector; f32 *vector;
size_t dimensions; size_t dimensions;
fvec_cleanup cleanup; fvec_cleanup cleanup;
char *errmsg; char *errmsg;
@ -679,7 +678,7 @@ static void vec_f32(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_free(errmsg); sqlite3_free(errmsg);
return; return;
} }
sqlite3_result_blob(context, vector, dimensions * sizeof(float), sqlite3_result_blob(context, vector, dimensions * sizeof(f32),
SQLITE_TRANSIENT); SQLITE_TRANSIENT);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32);
cleanup(vector); cleanup(vector);
@ -764,12 +763,12 @@ static void vec_distance_cosine(sqlite3_context *context, int argc,
goto finish; goto finish;
} }
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
float result = distance_cosine_float(a, b, &dimensions); f32 result = distance_cosine_float(a, b, &dimensions);
sqlite3_result_double(context, result); sqlite3_result_double(context, result);
goto finish; goto finish;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: { case SQLITE_VEC_ELEMENT_TYPE_INT8: {
float result = distance_cosine_int8(a, b, &dimensions); f32 result = distance_cosine_int8(a, b, &dimensions);
sqlite3_result_double(context, result); sqlite3_result_double(context, result);
goto finish; goto finish;
} }
@ -805,12 +804,12 @@ static void vec_distance_l2(sqlite3_context *context, int argc,
goto finish; goto finish;
} }
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
float result = distance_l2_sqr_float(a, b, &dimensions); f32 result = distance_l2_sqr_float(a, b, &dimensions);
sqlite3_result_double(context, result); sqlite3_result_double(context, result);
goto finish; goto finish;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: { case SQLITE_VEC_ELEMENT_TYPE_INT8: {
float result = distance_l2_sqr_int8(a, b, &dimensions); f32 result = distance_l2_sqr_int8(a, b, &dimensions);
sqlite3_result_double(context, result); sqlite3_result_double(context, result);
goto finish; goto finish;
} }
@ -865,7 +864,7 @@ finish:
static void vec_quantize_i8(sqlite3_context *context, int argc, static void vec_quantize_i8(sqlite3_context *context, int argc,
sqlite3_value **argv) { sqlite3_value **argv) {
float *srcVector; f32 *srcVector;
size_t dimensions; size_t dimensions;
fvec_cleanup cleanup; fvec_cleanup cleanup;
char *err; char *err;
@ -887,12 +886,12 @@ static void vec_quantize_i8(sqlite3_context *context, int argc,
sqlite3_free(out); sqlite3_free(out);
return; return;
} }
float step = (1.0 - (-1.0)) / 255; f32 step = (1.0 - (-1.0)) / 255;
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
out[i] = ((srcVector[i] - (-1.0)) / step) - 128; out[i] = ((srcVector[i] - (-1.0)) / step) - 128;
} }
} else if (argc == 3) { } else if (argc == 3) {
// float * minVector, maxVector; // f32 * minVector, maxVector;
// size_t d; // size_t d;
// fvec_cleanup minCleanup, maxCleanup; // fvec_cleanup minCleanup, maxCleanup;
// int rc = fvec_from_value(argv[1], ) // int rc = fvec_from_value(argv[1], )
@ -925,7 +924,7 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
u8 *out = sqlite3_malloc(dimensions / CHAR_BIT); u8 *out = sqlite3_malloc(dimensions / CHAR_BIT);
todo_assert(out); todo_assert(out);
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
int res = ((float *)vector)[i] > 0.0; int res = ((f32 *)vector)[i] > 0.0;
out[i / 8] |= (res << (i % 8)); out[i / 8] |= (res << (i % 8));
} }
sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free); sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free);
@ -966,14 +965,14 @@ static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) {
goto finish; goto finish;
} }
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
size_t outSize = dimensions * sizeof(float); size_t outSize = dimensions * sizeof(f32);
float *out = sqlite3_malloc(outSize); f32 *out = sqlite3_malloc(outSize);
if (!out) { if (!out) {
sqlite3_result_error_nomem(context); sqlite3_result_error_nomem(context);
goto finish; goto finish;
} }
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
out[i] = ((float *)a)[i] + ((float *)b)[i]; out[i] = ((f32 *)a)[i] + ((f32 *)b)[i];
} }
sqlite3_result_blob(context, out, outSize, sqlite3_free); sqlite3_result_blob(context, out, outSize, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32);
@ -1022,14 +1021,14 @@ static void vec_sub(sqlite3_context *context, int argc, sqlite3_value **argv) {
goto finish; goto finish;
} }
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
size_t outSize = dimensions * sizeof(float); size_t outSize = dimensions * sizeof(f32);
float *out = sqlite3_malloc(outSize); f32 *out = sqlite3_malloc(outSize);
if (!out) { if (!out) {
sqlite3_result_error_nomem(context); sqlite3_result_error_nomem(context);
goto finish; goto finish;
} }
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
out[i] = ((float *)a)[i] - ((float *)b)[i]; out[i] = ((f32 *)a)[i] - ((f32 *)b)[i];
} }
sqlite3_result_blob(context, out, outSize, sqlite3_free); sqlite3_result_blob(context, out, outSize, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32);
@ -1107,15 +1106,15 @@ static void vec_slice(sqlite3_context *context, int argc,
switch (elementType) { switch (elementType) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
float *out = sqlite3_malloc(n * sizeof(float)); f32 *out = sqlite3_malloc(n * sizeof(f32));
if (!out) { if (!out) {
sqlite3_result_error_nomem(context); sqlite3_result_error_nomem(context);
return; return;
} }
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
out[i] = ((float *)vector)[start + i]; out[i] = ((f32 *)vector)[start + i];
} }
sqlite3_result_blob(context, out, n * sizeof(float), sqlite3_free); sqlite3_result_blob(context, out, n * sizeof(f32), sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32);
goto done; goto done;
} }
@ -1183,7 +1182,7 @@ static void vec_to_json(sqlite3_context *context, int argc,
sqlite3_str_appendall(str, ","); sqlite3_str_appendall(str, ",");
} }
if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
sqlite3_str_appendf(str, "%f", ((float *)vector)[i]); sqlite3_str_appendf(str, "%f", ((f32 *)vector)[i]);
} else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
sqlite3_str_appendf(str, "%d", ((i8 *)vector)[i]); sqlite3_str_appendf(str, "%d", ((i8 *)vector)[i]);
} else if (elementType == SQLITE_VEC_ELEMENT_TYPE_BIT) { } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_BIT) {
@ -1226,11 +1225,11 @@ static void vec_normalize(sqlite3_context *context, int argc,
return; return;
} }
float *out = sqlite3_malloc(dimensions * sizeof(float)); f32 *out = sqlite3_malloc(dimensions * sizeof(f32));
todo_assert(out); todo_assert(out);
float *v = (float *)vector; f32 *v = (f32 *)vector;
float norm = 0; f32 norm = 0;
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
norm += v[i] * v[i]; norm += v[i] * v[i];
} }
@ -1239,7 +1238,7 @@ static void vec_normalize(sqlite3_context *context, int argc,
out[i] = v[i] / norm; out[i] = v[i] / norm;
} }
sqlite3_result_blob(context, out, dimensions * sizeof(float), sqlite3_free); sqlite3_result_blob(context, out, dimensions * sizeof(f32), sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_FLOAT32);
} }
@ -1482,7 +1481,7 @@ struct VectorColumnDefinition {
size_t vector_column_byte_size(struct VectorColumnDefinition column) { size_t vector_column_byte_size(struct VectorColumnDefinition column) {
switch (column.element_type) { switch (column.element_type) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
return column.dimensions * sizeof(float); return column.dimensions * sizeof(f32);
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
return column.dimensions * sizeof(i8); return column.dimensions * sizeof(i8);
case SQLITE_VEC_ELEMENT_TYPE_BIT: case SQLITE_VEC_ELEMENT_TYPE_BIT:
@ -1729,7 +1728,7 @@ static int vec_eachColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context,
case VEC_EACH_COLUMN_VALUE: case VEC_EACH_COLUMN_VALUE:
switch (pCur->vector_type) { switch (pCur->vector_type) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
sqlite3_result_double(context, ((float *)pCur->vector)[pCur->iRowid]); sqlite3_result_double(context, ((f32 *)pCur->vector)[pCur->iRowid]);
break; break;
} }
case SQLITE_VEC_ELEMENT_TYPE_BIT: { case SQLITE_VEC_ELEMENT_TYPE_BIT: {
@ -2014,7 +2013,7 @@ int parse_npy(const unsigned char *buffer, size_t bufferLength, void **data,
int element_size = 0; int element_size = 0;
// TODO bit // TODO bit
if (*element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { if (*element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
element_size = sizeof(float); element_size = sizeof(f32);
} }
todo_assert((*numElements * *numDimensions * element_size) == dataSize); todo_assert((*numElements * *numDimensions * element_size) == dataSize);
@ -2232,7 +2231,7 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
int element_size = 0; int element_size = 0;
if (element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { if (element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
element_size = sizeof(float); element_size = sizeof(f32);
} else { } else {
todo("non-f32 numpy array"); todo("non-f32 numpy array");
} }
@ -2321,8 +2320,8 @@ static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur,
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
sqlite3_result_blob( sqlite3_result_blob(
context, context,
&pCur->vector[pCur->iRowid * pCur->nDimensions * sizeof(float)], &pCur->vector[pCur->iRowid * pCur->nDimensions * sizeof(f32)],
pCur->nDimensions * sizeof(float), SQLITE_STATIC); pCur->nDimensions * sizeof(f32), SQLITE_STATIC);
break; break;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
@ -2334,11 +2333,10 @@ static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur,
} else if (pCur->input_type == VEC_NPY_EACH_INPUT_FILE) { } else if (pCur->input_type == VEC_NPY_EACH_INPUT_FILE) {
switch (pCur->elementType) { switch (pCur->elementType) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
sqlite3_result_blob( sqlite3_result_blob(context,
context, &pCur->fileBuffer[pCur->bufferIndex *
&pCur->fileBuffer[pCur->bufferIndex * pCur->nDimensions * pCur->nDimensions * sizeof(f32)],
sizeof(float)], pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT);
pCur->nDimensions * sizeof(float), SQLITE_TRANSIENT);
break; break;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
@ -2744,7 +2742,7 @@ int vec0_new_chunk(vec0_vtab *p, i64 *chunk_rowid) {
switch (p->vector_columns[i].element_type) { switch (p->vector_columns[i].element_type) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
vectorsSize = vectorsSize =
p->chunk_size * p->vector_columns[i].dimensions * sizeof(float); p->chunk_size * p->vector_columns[i].dimensions * sizeof(f32);
break; break;
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
vectorsSize = vectorsSize =
@ -2815,7 +2813,7 @@ struct vec0_query_knn_data {
// Array of rowids of size k. Must be freed with sqlite3_free(). // Array of rowids of size k. Must be freed with sqlite3_free().
i64 *rowids; i64 *rowids;
// Array of distances of size k. Must be freed with sqlite3_free(). // Array of distances of size k. Must be freed with sqlite3_free().
float *distances; f32 *distances;
i64 current_idx; i64 current_idx;
}; };
int vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) { int vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) {
@ -3341,13 +3339,13 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
// forward declaration because vec0Filter uses it // forward declaration because vec0Filter uses it
static int vec0Next(sqlite3_vtab_cursor *cur); static int vec0Next(sqlite3_vtab_cursor *cur);
void dethrone(int k, float *base_distances, i64 *base_rowids, size_t chunk_size, void dethrone(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size,
i32 *chunk_top_idx, float *chunk_distances, i64 *chunk_rowids, i32 *chunk_top_idx, f32 *chunk_distances, i64 *chunk_rowids,
i64 **out_rowids, float **out_distances) { i64 **out_rowids, f32 **out_distances) {
*out_rowids = sqlite3_malloc(k * sizeof(i64)); *out_rowids = sqlite3_malloc(k * sizeof(i64));
todo_assert(out_rowids); todo_assert(out_rowids);
*out_distances = sqlite3_malloc(k * sizeof(float)); *out_distances = sqlite3_malloc(k * sizeof(f32));
todo_assert(out_distances); todo_assert(out_distances);
size_t ptrA = 0; size_t ptrA = 0;
@ -3374,14 +3372,14 @@ void dethrone(int k, float *base_distances, i64 *base_rowids, size_t chunk_size,
* @brief Finds the minimum k items in distances, and writes the indices to * @brief Finds the minimum k items in distances, and writes the indices to
* out. * out.
* *
* @param distances input float array of size n, the items to consider. * @param distances input f32 array of size n, the items to consider.
* @param n: size of distances array. * @param n: size of distances array.
* @param out: Output array of size k, will contain the minimum k element * @param out: Output array of size k, will contain the minimum k element
* indices * indices
* @param k: Size of output array * @param k: Size of output array
* @return int * @return int
*/ */
int min_idx(const float *distances, i32 n, i32 *out, i32 k) { int min_idx(const f32 *distances, i32 n, i32 *out, i32 k) {
todo_assert(k > 0); todo_assert(k > 0);
todo_assert(k <= n); todo_assert(k <= n);
@ -3473,7 +3471,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
// TODO do we need to ensure that rowid is never -1? // TODO do we need to ensure that rowid is never -1?
topk_rowids[i] = -1; topk_rowids[i] = -1;
} }
float *topk_distances = sqlite3_malloc(k * sizeof(float)); f32 *topk_distances = sqlite3_malloc(k * sizeof(f32));
todo_assert(topk_distances); todo_assert(topk_distances);
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
topk_distances[i] = __FLT_MAX__; topk_distances[i] = __FLT_MAX__;
@ -3536,7 +3534,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
// TODO realloc here, like baseVectors // TODO realloc here, like baseVectors
float *chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(float)); f32 *chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(f32));
todo_assert(chunk_distances); todo_assert(chunk_distances);
for (int i = 0; i < p->chunk_size; i++) { for (int i = 0; i < p->chunk_size; i++) {
@ -3559,25 +3557,25 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
} }
} }
float result; f32 result;
switch (vector_column->element_type) { switch (vector_column->element_type) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
const float *base_i = const f32 *base_i =
((float *)baseVectors) + (i * vector_column->dimensions); ((f32 *)baseVectors) + (i * vector_column->dimensions);
switch (vector_column->distance_metric) { switch (vector_column->distance_metric) {
case VEC0_DISTANCE_METRIC_L2: { case VEC0_DISTANCE_METRIC_L2: {
result = distance_l2_sqr_float(base_i, (float *)queryVector, result = distance_l2_sqr_float(base_i, (f32 *)queryVector,
&vector_column->dimensions); &vector_column->dimensions);
break; break;
} }
case VEC0_DISTANCE_METRIC_COSINE: { case VEC0_DISTANCE_METRIC_COSINE: {
result = distance_cosine_float(base_i, (float *)queryVector, result = distance_cosine_float(base_i, (f32 *)queryVector,
&vector_column->dimensions); &vector_column->dimensions);
break; break;
} }
} }
// result = distance_cosine(base_i, (float *) queryVector, & // result = distance_cosine(base_i, (f32 *) queryVector, &
// vector_column->dimensions); // vector_column->dimensions);
break; break;
} }
@ -3619,7 +3617,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
k <= p->chunk_size ? k : p->chunk_size); k <= p->chunk_size ? k : p->chunk_size);
i64 *out_rowids; i64 *out_rowids;
float *out_distances; f32 *out_distances;
dethrone(k, topk_distances, topk_rowids, p->chunk_size, chunk_topk_idxs, dethrone(k, topk_distances, topk_rowids, p->chunk_size, chunk_topk_idxs,
chunk_distances, chunkRowids, chunk_distances, chunkRowids,
@ -4058,8 +4056,8 @@ static int vec0Update_InsertWriteFinalStepVectors(
switch (element_type) { switch (element_type) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
n = dimensions * sizeof(float); n = dimensions * sizeof(f32);
offset = chunk_offset * dimensions * sizeof(float); offset = chunk_offset * dimensions * sizeof(f32);
break; break;
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
n = dimensions * sizeof(i8); n = dimensions * sizeof(i8);
@ -4117,7 +4115,7 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid,
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) == todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) ==
p->chunk_size * p->vector_columns[i].dimensions * p->chunk_size * p->vector_columns[i].dimensions *
sizeof(float)); sizeof(f32));
break; break;
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) == todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) ==
@ -4325,7 +4323,7 @@ int vec0Update_UpdateOnRowid(sqlite3_vtab *pVTab, int argc,
void *vector = (void *)sqlite3_value_blob(valueVector); void *vector = (void *)sqlite3_value_blob(valueVector);
switch (p->vector_columns[i].element_type) { switch (p->vector_columns[i].element_type) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
dimensions = sqlite3_value_bytes(valueVector) / sizeof(float); dimensions = sqlite3_value_bytes(valueVector) / sizeof(f32);
break; break;
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
dimensions = sqlite3_value_bytes(valueVector) * sizeof(i8); dimensions = sqlite3_value_bytes(valueVector) * sizeof(i8);