fmt + better typedef common types

This commit is contained in:
Alex Garcia 2024-04-20 17:02:19 -07:00
parent e1737b1043
commit 29507aa45d
2 changed files with 169 additions and 173 deletions

View file

@ -1,19 +1,18 @@
#include "sqlite-vec.h"
#include <assert.h> #include <assert.h>
#include <errno.h> #include <errno.h>
#include <inttypes.h>
#include <limits.h> #include <limits.h>
#include <math.h> #include <math.h>
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <stdint.h>
#include <inttypes.h>
#include "sqlite-vec.h"
#include "sqlite3ext.h" #include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1 SQLITE_EXTENSION_INIT1
#ifndef UINT32_TYPE #ifndef UINT32_TYPE
#ifdef HAVE_UINT32_T #ifdef HAVE_UINT32_T
#define UINT32_TYPE uint32_t #define UINT32_TYPE uint32_t
@ -59,6 +58,15 @@ typedef u_int16_t uint16_t;
typedef u_int64_t uint64_t; typedef u_int64_t uint64_t;
#endif #endif
/*
 * Short aliases for the scalar types used in this file.
 *
 * Note that i64 is an alias of sqlite3_int64 (SQLite's own 64-bit integer
 * type) rather than int64_t, while the other integer aliases map to the
 * <stdint.h> fixed-width types.
 */
typedef int8_t i8;
typedef uint8_t u8;
typedef int32_t i32;
typedef sqlite3_int64 i64;
typedef uint32_t u32;
typedef uint64_t u64;
typedef float f32;
typedef size_t usize;
#ifndef UNUSED_PARAMETER #ifndef UNUSED_PARAMETER
#define UNUSED_PARAMETER(X) (void)(X) #define UNUSED_PARAMETER(X) (void)(X)
#endif #endif
@ -194,8 +202,8 @@ static float l2_sqr_float(const void *pVect1v, const void *pVect2v,
} }
static float l2_sqr_int8(const void *pA, const void *pB, const void *pD) { static float l2_sqr_int8(const void *pA, const void *pB, const void *pD) {
int8_t *a = (int8_t *)pA; i8 *a = (i8 *)pA;
int8_t *b = (int8_t *)pB; i8 *b = (i8 *)pB;
size_t d = *((size_t *)pD); size_t d = *((size_t *)pD);
float res = 0; float res = 0;
@ -247,8 +255,8 @@ static float distance_cosine_float(const void *pVect1v, const void *pVect2v,
} }
static float distance_cosine_int8(const void *pA, const void *pB, static float distance_cosine_int8(const void *pA, const void *pB,
const void *pD) { const void *pD) {
int8_t *a = (int8_t *)pA; i8 *a = (i8 *)pA;
int8_t *b = (int8_t *)pB; i8 *b = (i8 *)pB;
size_t d = *((size_t *)pD); size_t d = *((size_t *)pD);
float dot = 0; float dot = 0;
@ -265,7 +273,7 @@ static float distance_cosine_int8(const void *pA, const void *pB,
} }
// https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34 // https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34
static uint8_t hamdist_table[256] = { static u8 hamdist_table[256] = {
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4,
@ -278,14 +286,14 @@ static uint8_t hamdist_table[256] = {
4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};
/**
 * Hamming distance between two byte arrays.
 *
 * Each pair of bytes is XORed and the result is used as an index into
 * hamdist_table (a 256-entry per-byte lookup table); the looked-up values are
 * summed over all n bytes.
 *
 * @param a first byte array, at least n bytes long
 * @param b second byte array, at least n bytes long
 * @param n number of bytes to compare
 * @return the accumulated table values (number of differing bits), as a float
 */
static float distance_hamming_u8(u8 *a, u8 *b, size_t n) {
  int differing_bits = 0;
  // Index with size_t so the counter is as wide as the bound on every ABI;
  // unsigned long is only 32 bits on LLP64 targets (e.g. 64-bit Windows).
  for (size_t i = 0; i < n; i++) {
    differing_bits += hamdist_table[a[i] ^ b[i]];
  }
  return (float)differing_bits;
}
static float distance_hamming_u64(uint64_t *a, uint64_t *b, size_t n) { static float distance_hamming_u64(u64 *a, u64 *b, size_t n) {
int same = 0; int same = 0;
for (unsigned long i = 0; i < n; i++) { for (unsigned long i = 0; i < n; i++) {
same += __builtin_popcountl(a[i] ^ b[i]); same += __builtin_popcountl(a[i] ^ b[i]);
@ -298,10 +306,9 @@ static float distance_hamming(const void *a, const void *b, const void *d) {
todo_assert((dimensions % CHAR_BIT) == 0); todo_assert((dimensions % CHAR_BIT) == 0);
if ((dimensions % 64) == 0) { if ((dimensions % 64) == 0) {
return distance_hamming_u64((uint64_t *)a, (uint64_t *)b, return distance_hamming_u64((u64 *)a, (u64 *)b, dimensions / 8 / CHAR_BIT);
dimensions / 8 / CHAR_BIT);
} }
return distance_hamming_u8((uint8_t *)a, (uint8_t *)b, dimensions / CHAR_BIT); return distance_hamming_u8((u8 *)a, (u8 *)b, dimensions / CHAR_BIT);
} }
// from SQLite source: // from SQLite source:
@ -493,7 +500,7 @@ static int fvec_from_value(sqlite3_value *value, float **vector,
return SQLITE_ERROR; return SQLITE_ERROR;
} }
static int bitvec_from_value(sqlite3_value *value, uint8_t **vector, static int bitvec_from_value(sqlite3_value *value, u8 **vector,
size_t *dimensions, vector_cleanup *cleanup, size_t *dimensions, vector_cleanup *cleanup,
char **pzErr) { char **pzErr) {
int value_type = sqlite3_value_type(value); int value_type = sqlite3_value_type(value);
@ -504,7 +511,7 @@ static int bitvec_from_value(sqlite3_value *value, uint8_t **vector,
*pzErr = sqlite3_mprintf("zero-length vectors are not supported."); *pzErr = sqlite3_mprintf("zero-length vectors are not supported.");
return SQLITE_ERROR; return SQLITE_ERROR;
} }
*vector = (uint8_t *)blob; *vector = (u8 *)blob;
*dimensions = bytes * CHAR_BIT; *dimensions = bytes * CHAR_BIT;
*cleanup = vector_cleanup_noop; *cleanup = vector_cleanup_noop;
return SQLITE_OK; return SQLITE_OK;
@ -513,7 +520,7 @@ static int bitvec_from_value(sqlite3_value *value, uint8_t **vector,
return SQLITE_ERROR; return SQLITE_ERROR;
} }
static int int8_vec_from_value(sqlite3_value *value, int8_t **vector, static int int8_vec_from_value(sqlite3_value *value, i8 **vector,
size_t *dimensions, vector_cleanup *cleanup, size_t *dimensions, vector_cleanup *cleanup,
char **pzErr) { char **pzErr) {
int value_type = sqlite3_value_type(value); int value_type = sqlite3_value_type(value);
@ -524,7 +531,7 @@ static int int8_vec_from_value(sqlite3_value *value, int8_t **vector,
*pzErr = sqlite3_mprintf("zero-length vectors are not supported."); *pzErr = sqlite3_mprintf("zero-length vectors are not supported.");
return SQLITE_ERROR; return SQLITE_ERROR;
} }
*vector = (int8_t *)blob; *vector = (i8 *)blob;
*dimensions = bytes; *dimensions = bytes;
*cleanup = vector_cleanup_noop; *cleanup = vector_cleanup_noop;
return SQLITE_OK; return SQLITE_OK;
@ -560,7 +567,7 @@ int vector_from_value(sqlite3_value *value, void **vector, size_t *dimensions,
} }
if (subtype == SQLITE_VEC_ELEMENT_TYPE_BIT) { if (subtype == SQLITE_VEC_ELEMENT_TYPE_BIT) {
int rc = bitvec_from_value(value, (uint8_t **)vector, dimensions, cleanup, int rc = bitvec_from_value(value, (u8 **)vector, dimensions, cleanup,
pzErrorMessage); pzErrorMessage);
if (rc == SQLITE_OK) { if (rc == SQLITE_OK) {
*element_type = SQLITE_VEC_ELEMENT_TYPE_BIT; *element_type = SQLITE_VEC_ELEMENT_TYPE_BIT;
@ -568,7 +575,7 @@ int vector_from_value(sqlite3_value *value, void **vector, size_t *dimensions,
return rc; return rc;
} }
if (subtype == SQLITE_VEC_ELEMENT_TYPE_INT8) { if (subtype == SQLITE_VEC_ELEMENT_TYPE_INT8) {
int rc = int8_vec_from_value(value, (int8_t **)vector, dimensions, cleanup, int rc = int8_vec_from_value(value, (i8 **)vector, dimensions, cleanup,
pzErrorMessage); pzErrorMessage);
if (rc == SQLITE_OK) { if (rc == SQLITE_OK) {
*element_type = SQLITE_VEC_ELEMENT_TYPE_INT8; *element_type = SQLITE_VEC_ELEMENT_TYPE_INT8;
@ -640,9 +647,7 @@ int ensure_vector_match(sqlite3_value *aValue, sqlite3_value *bValue, void **a,
return SQLITE_OK; return SQLITE_OK;
} }
int _cmp(const void *a, const void *b) { int _cmp(const void *a, const void *b) { return (*(i64 *)a - *(i64 *)b); }
return (*(sqlite3_int64 *)a - *(sqlite3_int64 *)b);
}
struct VecNpyFile { struct VecNpyFile {
char *path; char *path;
@ -682,7 +687,7 @@ static void vec_f32(sqlite3_context *context, int argc, sqlite3_value **argv) {
static void vec_bit(sqlite3_context *context, int argc, sqlite3_value **argv) { static void vec_bit(sqlite3_context *context, int argc, sqlite3_value **argv) {
todo_assert(argc == 1); todo_assert(argc == 1);
int rc; int rc;
uint8_t *vector; u8 *vector;
size_t dimensions; size_t dimensions;
vector_cleanup cleanup; vector_cleanup cleanup;
char *errmsg; char *errmsg;
@ -699,7 +704,7 @@ static void vec_bit(sqlite3_context *context, int argc, sqlite3_value **argv) {
static void vec_int8(sqlite3_context *context, int argc, sqlite3_value **argv) { static void vec_int8(sqlite3_context *context, int argc, sqlite3_value **argv) {
todo_assert(argc == 1); todo_assert(argc == 1);
int rc; int rc;
int8_t *vector; i8 *vector;
size_t dimensions; size_t dimensions;
vector_cleanup cleanup; vector_cleanup cleanup;
char *errmsg; char *errmsg;
@ -866,7 +871,7 @@ static void vec_quantize_i8(sqlite3_context *context, int argc,
char *err; char *err;
int rc = fvec_from_value(argv[0], &srcVector, &dimensions, &cleanup, &err); int rc = fvec_from_value(argv[0], &srcVector, &dimensions, &cleanup, &err);
assert(rc == SQLITE_OK); assert(rc == SQLITE_OK);
int8_t *out = sqlite3_malloc(dimensions * sizeof(int8_t)); i8 *out = sqlite3_malloc(dimensions * sizeof(i8));
assert(out); assert(out);
if (argc == 2) { if (argc == 2) {
@ -895,7 +900,7 @@ static void vec_quantize_i8(sqlite3_context *context, int argc,
} }
cleanup(srcVector); cleanup(srcVector);
sqlite3_result_blob(context, out, dimensions * sizeof(int8_t), sqlite3_free); sqlite3_result_blob(context, out, dimensions * sizeof(i8), sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8);
return; return;
} }
@ -917,7 +922,7 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
} }
if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
uint8_t *out = sqlite3_malloc(dimensions / CHAR_BIT); u8 *out = sqlite3_malloc(dimensions / CHAR_BIT);
todo_assert(out); todo_assert(out);
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
int res = ((float *)vector)[i] > 0.0; int res = ((float *)vector)[i] > 0.0;
@ -926,10 +931,10 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free); sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);
} else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
uint8_t *out = sqlite3_malloc(dimensions / CHAR_BIT); u8 *out = sqlite3_malloc(dimensions / CHAR_BIT);
todo_assert(out); todo_assert(out);
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
int res = ((int8_t *)vector)[i] > 0; int res = ((i8 *)vector)[i] > 0;
out[i / 8] |= (res << (i % 8)); out[i / 8] |= (res << (i % 8));
} }
sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free); sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free);
@ -975,14 +980,14 @@ static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) {
goto finish; goto finish;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: { case SQLITE_VEC_ELEMENT_TYPE_INT8: {
size_t outSize = dimensions * sizeof(int8_t); size_t outSize = dimensions * sizeof(i8);
int8_t *out = sqlite3_malloc(outSize); i8 *out = sqlite3_malloc(outSize);
if (!out) { if (!out) {
sqlite3_result_error_nomem(context); sqlite3_result_error_nomem(context);
goto finish; goto finish;
} }
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
out[i] = ((int8_t *)a)[i] + ((int8_t *)b)[i]; out[i] = ((i8 *)a)[i] + ((i8 *)b)[i];
} }
sqlite3_result_blob(context, out, outSize, sqlite3_free); sqlite3_result_blob(context, out, outSize, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8);
@ -1031,14 +1036,14 @@ static void vec_sub(sqlite3_context *context, int argc, sqlite3_value **argv) {
goto finish; goto finish;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: { case SQLITE_VEC_ELEMENT_TYPE_INT8: {
size_t outSize = dimensions * sizeof(int8_t); size_t outSize = dimensions * sizeof(i8);
int8_t *out = sqlite3_malloc(outSize); i8 *out = sqlite3_malloc(outSize);
if (!out) { if (!out) {
sqlite3_result_error_nomem(context); sqlite3_result_error_nomem(context);
goto finish; goto finish;
} }
for (size_t i = 0; i < dimensions; i++) { for (size_t i = 0; i < dimensions; i++) {
out[i] = ((int8_t *)a)[i] - ((int8_t *)b)[i]; out[i] = ((i8 *)a)[i] - ((i8 *)b)[i];
} }
sqlite3_result_blob(context, out, outSize, sqlite3_free); sqlite3_result_blob(context, out, outSize, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8);
@ -1115,15 +1120,15 @@ static void vec_slice(sqlite3_context *context, int argc,
goto done; goto done;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: { case SQLITE_VEC_ELEMENT_TYPE_INT8: {
int8_t *out = sqlite3_malloc(n * sizeof(int8_t)); i8 *out = sqlite3_malloc(n * sizeof(i8));
if (!out) { if (!out) {
sqlite3_result_error_nomem(context); sqlite3_result_error_nomem(context);
return; return;
} }
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
out[i] = ((int8_t *)vector)[start + i]; out[i] = ((i8 *)vector)[start + i];
} }
sqlite3_result_blob(context, out, n * sizeof(int8_t), sqlite3_free); sqlite3_result_blob(context, out, n * sizeof(i8), sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_INT8);
goto done; goto done;
} }
@ -1137,13 +1142,13 @@ static void vec_slice(sqlite3_context *context, int argc,
goto done; goto done;
} }
uint8_t *out = sqlite3_malloc(n / CHAR_BIT); u8 *out = sqlite3_malloc(n / CHAR_BIT);
if (!out) { if (!out) {
sqlite3_result_error_nomem(context); sqlite3_result_error_nomem(context);
return; return;
} }
for (size_t i = 0; i < n / CHAR_BIT; i++) { for (size_t i = 0; i < n / CHAR_BIT; i++) {
out[i] = ((uint8_t *)vector)[(start / CHAR_BIT) + i]; out[i] = ((u8 *)vector)[(start / CHAR_BIT) + i];
} }
sqlite3_result_blob(context, out, n / CHAR_BIT, sqlite3_free); sqlite3_result_blob(context, out, n / CHAR_BIT, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);
@ -1180,9 +1185,9 @@ static void vec_to_json(sqlite3_context *context, int argc,
if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
sqlite3_str_appendf(str, "%f", ((float *)vector)[i]); sqlite3_str_appendf(str, "%f", ((float *)vector)[i]);
} else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
sqlite3_str_appendf(str, "%d", ((int8_t *)vector)[i]); sqlite3_str_appendf(str, "%d", ((i8 *)vector)[i]);
} else if (elementType == SQLITE_VEC_ELEMENT_TYPE_BIT) { } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_BIT) {
uint8_t b = (((uint8_t *)vector)[i / 8] >> (i % CHAR_BIT)) & 1; u8 b = (((u8 *)vector)[i / 8] >> (i % CHAR_BIT)) & 1;
sqlite3_str_appendf(str, "%d", b); sqlite3_str_appendf(str, "%d", b);
} }
} }
@ -1479,7 +1484,7 @@ size_t vector_column_byte_size(struct VectorColumnDefinition column) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
return column.dimensions * sizeof(float); return column.dimensions * sizeof(float);
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
return column.dimensions * sizeof(int8_t); return column.dimensions * sizeof(i8);
case SQLITE_VEC_ELEMENT_TYPE_BIT: case SQLITE_VEC_ELEMENT_TYPE_BIT:
return column.dimensions / CHAR_BIT; return column.dimensions / CHAR_BIT;
} }
@ -1601,7 +1606,7 @@ struct vec_each_vtab {
typedef struct vec_each_cursor vec_each_cursor; typedef struct vec_each_cursor vec_each_cursor;
struct vec_each_cursor { struct vec_each_cursor {
sqlite3_vtab_cursor base; sqlite3_vtab_cursor base;
sqlite3_int64 iRowid; i64 iRowid;
enum VectorElementType vector_type; enum VectorElementType vector_type;
void *vector; void *vector;
size_t dimensions; size_t dimensions;
@ -1708,7 +1713,7 @@ static int vec_eachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) {
static int vec_eachEof(sqlite3_vtab_cursor *cur) { static int vec_eachEof(sqlite3_vtab_cursor *cur) {
vec_each_cursor *pCur = (vec_each_cursor *)cur; vec_each_cursor *pCur = (vec_each_cursor *)cur;
return pCur->iRowid >= (sqlite3_int64)pCur->dimensions; return pCur->iRowid >= (i64)pCur->dimensions;
} }
static int vec_eachNext(sqlite3_vtab_cursor *cur) { static int vec_eachNext(sqlite3_vtab_cursor *cur) {
@ -1728,13 +1733,13 @@ static int vec_eachColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context,
break; break;
} }
case SQLITE_VEC_ELEMENT_TYPE_BIT: { case SQLITE_VEC_ELEMENT_TYPE_BIT: {
uint8_t x = ((uint8_t *)pCur->vector)[pCur->iRowid / CHAR_BIT]; u8 x = ((u8 *)pCur->vector)[pCur->iRowid / CHAR_BIT];
sqlite3_result_int(context, sqlite3_result_int(context,
(x & (0b10000000 >> ((pCur->iRowid % CHAR_BIT)))) > 0); (x & (0b10000000 >> ((pCur->iRowid % CHAR_BIT)))) > 0);
break; break;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: { case SQLITE_VEC_ELEMENT_TYPE_INT8: {
sqlite3_result_int(context, ((int8_t *)pCur->vector)[pCur->iRowid]); sqlite3_result_int(context, ((i8 *)pCur->vector)[pCur->iRowid]);
break; break;
} }
} }
@ -1986,8 +1991,8 @@ int parse_npy(const unsigned char *buffer, size_t bufferLength, void **data,
for (size_t i = 0; i < sizeof(NPY_MAGIC); i++) { for (size_t i = 0; i < sizeof(NPY_MAGIC); i++) {
todo_assert(NPY_MAGIC[i] == buffer[i]); todo_assert(NPY_MAGIC[i] == buffer[i]);
} }
uint8_t major = buffer[6]; u8 major = buffer[6];
uint8_t minor = buffer[7]; u8 minor = buffer[7];
uint16_t headerLength = 0; uint16_t headerLength = 0;
memcpy(&headerLength, &buffer[8], sizeof(uint16_t)); memcpy(&headerLength, &buffer[8], sizeof(uint16_t));
@ -2030,7 +2035,7 @@ typedef enum {
typedef struct vec_npy_each_cursor vec_npy_each_cursor; typedef struct vec_npy_each_cursor vec_npy_each_cursor;
struct vec_npy_each_cursor { struct vec_npy_each_cursor {
sqlite3_vtab_cursor base; sqlite3_vtab_cursor base;
sqlite3_int64 iRowid; i64 iRowid;
// sqlite-vec compatible type of vector // sqlite-vec compatible type of vector
enum VectorElementType elementType; enum VectorElementType elementType;
// number of vectors in the npy array // number of vectors in the npy array
@ -2199,8 +2204,8 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
for (size_t i = 0; i < countof(NPY_MAGIC); i++) { for (size_t i = 0; i < countof(NPY_MAGIC); i++) {
todo_assert(NPY_MAGIC[i] == header[i]); todo_assert(NPY_MAGIC[i] == header[i]);
} }
uint8_t major = header[6]; u8 major = header[6];
uint8_t minor = header[7]; u8 minor = header[7];
uint16_t headerLength = 0; uint16_t headerLength = 0;
memcpy(&headerLength, &header[8], sizeof(uint16_t)); memcpy(&headerLength, &header[8], sizeof(uint16_t));
@ -2516,8 +2521,8 @@ struct vec0_vtab {
* Parameters: * Parameters:
* 1: rowid of the row/vector to lookup * 1: rowid of the row/vector to lookup
* Result columns: * Result columns:
* 0: chunk_id (sqlite3_int64) * 0: chunk_id (i64)
* 1: chunk_offset (sqlite3_int64) * 1: chunk_offset (i64)
* SQL: "SELECT chunk_id, chunk_offset FROM _rowids WHERE rowid = ?" * SQL: "SELECT chunk_id, chunk_offset FROM _rowids WHERE rowid = ?"
* *
* Must be cleaned up with sqlite3_finalize(). * Must be cleaned up with sqlite3_finalize().
@ -2579,7 +2584,7 @@ int vec0_column_idx_to_vector_idx(vec0_vtab *pVtab, int column_idx) {
* Must be cleaned up with sqlite3_value_free(). * Must be cleaned up with sqlite3_value_free().
* @returns SQLITE_OK on success, error code on failure * @returns SQLITE_OK on success, error code on failure
*/ */
int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, sqlite3_int64 rowid, int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, i64 rowid,
sqlite3_value **out) { sqlite3_value **out) {
// TODO different stmt than stmtRowidsGetChunkPosition? // TODO different stmt than stmtRowidsGetChunkPosition?
// TODO return rc instead // TODO return rc instead
@ -2597,8 +2602,7 @@ int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, sqlite3_int64 rowid,
} }
// TODO make sure callees use the return value of this function // TODO make sure callees use the return value of this function
int vec0_result_id(vec0_vtab *p, sqlite3_context *context, int vec0_result_id(vec0_vtab *p, sqlite3_context *context, i64 rowid) {
sqlite3_int64 rowid) {
if (!p->pkIsText) { if (!p->pkIsText) {
sqlite3_result_int64(context, rowid); sqlite3_result_int64(context, rowid);
return SQLITE_OK; return SQLITE_OK;
@ -2629,9 +2633,8 @@ int vec0_result_id(vec0_vtab *p, sqlite3_context *context,
* will be stored. * will be stored.
* @return int SQLITE_OK on success. * @return int SQLITE_OK on success.
*/ */
int vec0_get_vector_data(vec0_vtab *pVtab, sqlite3_int64 rowid, int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
int vector_column_idx, void **outVector, void **outVector, int *outVectorSize) {
int *outVectorSize) {
todo_assert((vector_column_idx >= 0) && todo_assert((vector_column_idx >= 0) &&
(vector_column_idx < pVtab->numVectorColumns)); (vector_column_idx < pVtab->numVectorColumns));
@ -2640,10 +2643,8 @@ int vec0_get_vector_data(vec0_vtab *pVtab, sqlite3_int64 rowid,
sqlite3_bind_int64(pVtab->stmtRowidsGetChunkPosition, 1, rowid); sqlite3_bind_int64(pVtab->stmtRowidsGetChunkPosition, 1, rowid);
int rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition); int rc = sqlite3_step(pVtab->stmtRowidsGetChunkPosition);
todo_assert(rc == SQLITE_ROW); todo_assert(rc == SQLITE_ROW);
sqlite3_int64 chunk_id = i64 chunk_id = sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 1);
sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 1); i64 chunk_offset = sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 2);
sqlite3_int64 chunk_offset =
sqlite3_column_int64(pVtab->stmtRowidsGetChunkPosition, 2);
rc = sqlite3_blob_reopen(pVtab->vectorBlobs[vector_column_idx], chunk_id); rc = sqlite3_blob_reopen(pVtab->vectorBlobs[vector_column_idx], chunk_id);
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
@ -2673,9 +2674,8 @@ int vec0_get_vector_data(vec0_vtab *pVtab, sqlite3_int64 rowid,
* @param chunk_offset: Output chunk_offset of the row * @param chunk_offset: Output chunk_offset of the row
* @return int: SQLITE_OK on success, error code on failure * @return int: SQLITE_OK on success, error code on failure
*/ */
int vec0_get_chunk_position(vec0_vtab *p, sqlite3_int64 rowid, int vec0_get_chunk_position(vec0_vtab *p, i64 rowid, i64 *chunk_id,
sqlite3_int64 *chunk_id, i64 *chunk_offset) {
sqlite3_int64 *chunk_offset) {
int rc; int rc;
sqlite3_reset(p->stmtRowidsGetChunkPosition); sqlite3_reset(p->stmtRowidsGetChunkPosition);
sqlite3_clear_bindings(p->stmtRowidsGetChunkPosition); sqlite3_clear_bindings(p->stmtRowidsGetChunkPosition);
@ -2701,11 +2701,11 @@ int vec0_get_chunk_position(vec0_vtab *p, sqlite3_int64 rowid,
* new chunk rowid. * new chunk rowid.
* @return int SQLITE_OK on success, error code otherwise. * @return int SQLITE_OK on success, error code otherwise.
*/ */
int vec0_new_chunk(vec0_vtab *p, sqlite3_int64 *chunk_rowid) { int vec0_new_chunk(vec0_vtab *p, i64 *chunk_rowid) {
int rc; int rc;
char *zSql; char *zSql;
sqlite3_stmt *stmt; sqlite3_stmt *stmt;
sqlite3_int64 rowid; i64 rowid;
// Step 1: Insert a new row in _chunks, capture that new rowid // Step 1: Insert a new row in _chunks, capture that new rowid
zSql = sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_CHUNKS_NAME zSql = sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_CHUNKS_NAME
@ -2725,7 +2725,7 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_int64 *chunk_rowid) {
p->chunk_size / CHAR_BIT); // validity bitmap p->chunk_size / CHAR_BIT); // validity bitmap
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
rc = sqlite3_bind_zeroblob(stmt, 3, rc = sqlite3_bind_zeroblob(stmt, 3,
p->chunk_size * sizeof(sqlite3_int64)); // rowids p->chunk_size * sizeof(i64)); // rowids
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
rc = sqlite3_step(stmt); rc = sqlite3_step(stmt);
todo_assert(rc == SQLITE_DONE); todo_assert(rc == SQLITE_DONE);
@ -2740,7 +2740,7 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_int64 *chunk_rowid) {
for (int i = 0; i < p->numVectorColumns; i++) { for (int i = 0; i < p->numVectorColumns; i++) {
sqlite3_int64 vectorsSize = 0; i64 vectorsSize = 0;
switch (p->vector_columns[i].element_type) { switch (p->vector_columns[i].element_type) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
vectorsSize = vectorsSize =
@ -2748,7 +2748,7 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_int64 *chunk_rowid) {
break; break;
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
vectorsSize = vectorsSize =
p->chunk_size * p->vector_columns[i].dimensions * sizeof(int8_t); p->chunk_size * p->vector_columns[i].dimensions * sizeof(i8);
break; break;
case SQLITE_VEC_ELEMENT_TYPE_BIT: case SQLITE_VEC_ELEMENT_TYPE_BIT:
vectorsSize = vectorsSize =
@ -2797,7 +2797,7 @@ typedef enum {
/**
 * Per-query state for the "fullscan" query plan (purpose inferred from the
 * struct name -- confirm against the code that fills it in).
 */
struct vec0_query_fullscan_data {
  // SQLite prepared-statement handle; presumably the statement that walks
  // the table's rowids -- TODO confirm where it is prepared and finalized.
  sqlite3_stmt *rowids_stmt;
  // Flag kept in an 8-bit integer; presumably non-zero once the scan has
  // no more rows -- TODO confirm against the cursor's next/eof callbacks.
  i8 done;
};
int vec0_query_fullscan_data_clear( int vec0_query_fullscan_data_clear(
struct vec0_query_fullscan_data *fullscan_data) { struct vec0_query_fullscan_data *fullscan_data) {
@ -2811,12 +2811,12 @@ int vec0_query_fullscan_data_clear(
} }
/**
 * Results of a KNN query: two arrays of size k (rowids and distances) plus a
 * position within them.
 */
struct vec0_query_knn_data {
  // Size of the rowids and distances arrays below.
  i64 k;
  // Array of rowids of size k. Must be freed with sqlite3_free().
  i64 *rowids;
  // Array of distances of size k. Must be freed with sqlite3_free().
  float *distances;
  // Position within the arrays; presumably the index of the result the
  // cursor is currently on -- TODO confirm against the cursor callbacks.
  i64 current_idx;
};
int vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) { int vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) {
if (knn_data->rowids) { if (knn_data->rowids) {
@ -2831,7 +2831,7 @@ int vec0_query_knn_data_clear(struct vec0_query_knn_data *knn_data) {
} }
/**
 * Per-query state for a "point" query (purpose inferred from the struct
 * name -- confirm against the code that fills it in).
 */
struct vec0_query_point_data {
  // Rowid associated with this query.
  i64 rowid;
  // Fixed-size table of VEC0_MAX_VECTOR_COLUMNS pointers; presumably one
  // vector buffer per vector column -- TODO confirm ownership and lifetime.
  void *vectors[VEC0_MAX_VECTOR_COLUMNS];
  // Integer flag; presumably non-zero once the single row has been
  // returned -- TODO confirm against the cursor's next/eof callbacks.
  int done;
};
@ -3219,7 +3219,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
#endif #endif
for (int i = 0; i < pIdxInfo->nConstraint; i++) { for (int i = 0; i < pIdxInfo->nConstraint; i++) {
uint8_t vtabIn = 0; u8 vtabIn = 0;
// sqlite3_vtab_in() was added in SQLite version 3.38 (2022-02-22) // sqlite3_vtab_in() was added in SQLite version 3.38 (2022-02-22)
// ref: https://www.sqlite.org/changes.html#version_3_38_0 // ref: https://www.sqlite.org/changes.html#version_3_38_0
if (sqlite3_libversion_number() >= 3038000) { if (sqlite3_libversion_number() >= 3038000) {
@ -3341,12 +3341,11 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
// forward declaration because vec0Filter uses it // forward declaration because vec0Filter uses it
static int vec0Next(sqlite3_vtab_cursor *cur); static int vec0Next(sqlite3_vtab_cursor *cur);
void dethrone(int k, float *base_distances, sqlite3_int64 *base_rowids, void dethrone(int k, float *base_distances, i64 *base_rowids, size_t chunk_size,
size_t chunk_size, int32_t *chunk_top_idx, float *chunk_distances, i32 *chunk_top_idx, float *chunk_distances, i64 *chunk_rowids,
sqlite3_int64 *chunk_rowids,
sqlite3_int64 **out_rowids, float **out_distances) { i64 **out_rowids, float **out_distances) {
*out_rowids = sqlite3_malloc(k * sizeof(sqlite3_int64)); *out_rowids = sqlite3_malloc(k * sizeof(i64));
todo_assert(out_rowids); todo_assert(out_rowids);
*out_distances = sqlite3_malloc(k * sizeof(float)); *out_distances = sqlite3_malloc(k * sizeof(float));
todo_assert(out_distances); todo_assert(out_distances);
@ -3382,7 +3381,7 @@ void dethrone(int k, float *base_distances, sqlite3_int64 *base_rowids,
* @param k: Size of output array * @param k: Size of output array
* @return int * @return int
*/ */
int min_idx(const float *distances, int32_t n, int32_t *out, int32_t k) { int min_idx(const float *distances, i32 n, i32 *out, i32 k) {
todo_assert(k > 0); todo_assert(k > 0);
todo_assert(k <= n); todo_assert(k <= n);
@ -3433,12 +3432,12 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
enum VectorElementType elementType; enum VectorElementType elementType;
vector_cleanup cleanup; vector_cleanup cleanup;
char *err; char *err;
rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, &cleanup, &err); rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
&cleanup, &err);
todo_assert(elementType == vector_column->element_type); todo_assert(elementType == vector_column->element_type);
todo_assert(dimensions == vector_column->dimensions); todo_assert(dimensions == vector_column->dimensions);
i64 k = sqlite3_value_int64(argv[1]);
sqlite3_int64 k = sqlite3_value_int64(argv[1]);
todo_assert(k >= 0); todo_assert(k >= 0);
if (k == 0) { if (k == 0) {
knn_data->k = 0; knn_data->k = 0;
@ -3455,11 +3454,11 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
int rc; int rc;
arrayRowidsIn = sqlite3_malloc(sizeof(struct Array)); arrayRowidsIn = sqlite3_malloc(sizeof(struct Array));
todo_assert(arrayRowidsIn); todo_assert(arrayRowidsIn);
rc = array_init(arrayRowidsIn, sizeof(sqlite3_int64), 32); rc = array_init(arrayRowidsIn, sizeof(i64), 32);
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
for (rc = sqlite3_vtab_in_first(argv[2], &item); rc == SQLITE_OK && item; for (rc = sqlite3_vtab_in_first(argv[2], &item); rc == SQLITE_OK && item;
rc = sqlite3_vtab_in_next(argv[2], &item)) { rc = sqlite3_vtab_in_next(argv[2], &item)) {
sqlite3_int64 rowid = sqlite3_value_int64(item); i64 rowid = sqlite3_value_int64(item);
rc = array_append(arrayRowidsIn, &rowid); rc = array_append(arrayRowidsIn, &rowid);
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
} }
@ -3468,7 +3467,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
_cmp); _cmp);
} }
sqlite3_int64 *topk_rowids = sqlite3_malloc(k * sizeof(sqlite3_int64)); i64 *topk_rowids = sqlite3_malloc(k * sizeof(i64));
todo_assert(topk_rowids); todo_assert(topk_rowids);
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
// TODO do we need to ensure that rowid is never -1? // TODO do we need to ensure that rowid is never -1?
@ -3497,7 +3496,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
void *baseVectors = NULL; void *baseVectors = NULL;
sqlite3_int64 baseVectorsSize = 0; i64 baseVectorsSize = 0;
while (true) { while (true) {
rc = sqlite3_step(stmtChunks); rc = sqlite3_step(stmtChunks);
@ -3506,22 +3505,21 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
if (rc != SQLITE_ROW) { if (rc != SQLITE_ROW) {
todo("chunks iter error"); todo("chunks iter error");
} }
sqlite3_int64 chunk_id = sqlite3_column_int64(stmtChunks, 0); i64 chunk_id = sqlite3_column_int64(stmtChunks, 0);
unsigned char *chunkValidity = unsigned char *chunkValidity =
(unsigned char *)sqlite3_column_blob(stmtChunks, 1); (unsigned char *)sqlite3_column_blob(stmtChunks, 1);
sqlite3_int64 validitySize = sqlite3_column_bytes(stmtChunks, 1); i64 validitySize = sqlite3_column_bytes(stmtChunks, 1);
todo_assert(validitySize == p->chunk_size / CHAR_BIT); todo_assert(validitySize == p->chunk_size / CHAR_BIT);
sqlite3_int64 *chunkRowids = i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2);
(sqlite3_int64 *)sqlite3_column_blob(stmtChunks, 2); i64 rowidsSize = sqlite3_column_bytes(stmtChunks, 2);
sqlite3_int64 rowidsSize = sqlite3_column_bytes(stmtChunks, 2); todo_assert(rowidsSize == p->chunk_size * sizeof(i64));
todo_assert(rowidsSize == p->chunk_size * sizeof(sqlite3_int64));
// open the vector chunk blob for the current chunk // open the vector chunk blob for the current chunk
rc = sqlite3_blob_open(p->db, p->schemaName, rc = sqlite3_blob_open(p->db, p->schemaName,
p->shadowVectorChunksNames[vectorColumnIdx], p->shadowVectorChunksNames[vectorColumnIdx],
"vectors", chunk_id, 0, &blobVectors); "vectors", chunk_id, 0, &blobVectors);
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
sqlite3_int64 currentBaseVectorsSize = sqlite3_blob_bytes(blobVectors); i64 currentBaseVectorsSize = sqlite3_blob_bytes(blobVectors);
todo_assert((unsigned long)currentBaseVectorsSize == todo_assert((unsigned long)currentBaseVectorsSize ==
p->chunk_size * vector_column_byte_size(*vector_column)); p->chunk_size * vector_column_byte_size(*vector_column));
@ -3552,9 +3550,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
// If pre-filtering, make sure the rowid appears in the `rowid in (...)` // If pre-filtering, make sure the rowid appears in the `rowid in (...)`
// list. // list.
if (arrayRowidsIn) { if (arrayRowidsIn) {
sqlite3_int64 rowid = chunkRowids[i]; i64 rowid = chunkRowids[i];
void *in = bsearch(&rowid, arrayRowidsIn->z, arrayRowidsIn->length, void *in = bsearch(&rowid, arrayRowidsIn->z, arrayRowidsIn->length,
sizeof(sqlite3_int64), _cmp); sizeof(i64), _cmp);
if (!in) { if (!in) {
chunk_distances[i] = __FLT_MAX__; chunk_distances[i] = __FLT_MAX__;
continue; continue;
@ -3584,17 +3582,17 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
break; break;
} }
case SQLITE_VEC_ELEMENT_TYPE_INT8: { case SQLITE_VEC_ELEMENT_TYPE_INT8: {
const int8_t *base_i = const i8 *base_i =
((int8_t *)baseVectors) + (i * vector_column->dimensions); ((i8 *)baseVectors) + (i * vector_column->dimensions);
switch (vector_column->distance_metric) { switch (vector_column->distance_metric) {
case VEC0_DISTANCE_METRIC_L2: { case VEC0_DISTANCE_METRIC_L2: {
result = distance_l2_sqr_int8(base_i, (int8_t *)queryVector, result = distance_l2_sqr_int8(base_i, (i8 *)queryVector,
&vector_column->dimensions); &vector_column->dimensions);
break; break;
} }
case VEC0_DISTANCE_METRIC_COSINE: { case VEC0_DISTANCE_METRIC_COSINE: {
result = distance_cosine_int8(base_i, (int8_t *)queryVector, result = distance_cosine_int8(base_i, (i8 *)queryVector,
&vector_column->dimensions); &vector_column->dimensions);
break; break;
} }
@ -3603,9 +3601,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
break; break;
} }
case SQLITE_VEC_ELEMENT_TYPE_BIT: { case SQLITE_VEC_ELEMENT_TYPE_BIT: {
const uint8_t *base_i = ((uint8_t *)baseVectors) + const u8 *base_i = ((u8 *)baseVectors) +
(i * (vector_column->dimensions / CHAR_BIT)); (i * (vector_column->dimensions / CHAR_BIT));
result = distance_hamming(base_i, (uint8_t *)queryVector, result = distance_hamming(base_i, (u8 *)queryVector,
&vector_column->dimensions); &vector_column->dimensions);
break; break;
} }
@ -3615,12 +3613,12 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
} }
// now that we have the distances // now that we have the distances
int32_t *chunk_topk_idxs = sqlite3_malloc(k * sizeof(int32_t)); i32 *chunk_topk_idxs = sqlite3_malloc(k * sizeof(i32));
todo_assert(chunk_topk_idxs); todo_assert(chunk_topk_idxs);
min_idx(chunk_distances, p->chunk_size, chunk_topk_idxs, min_idx(chunk_distances, p->chunk_size, chunk_topk_idxs,
k <= p->chunk_size ? k : p->chunk_size); k <= p->chunk_size ? k : p->chunk_size);
sqlite3_int64 *out_rowids; i64 *out_rowids;
float *out_distances; float *out_distances;
dethrone(k, topk_distances, topk_rowids, p->chunk_size, chunk_topk_idxs, dethrone(k, topk_distances, topk_rowids, p->chunk_size, chunk_topk_idxs,
chunk_distances, chunkRowids, chunk_distances, chunkRowids,
@ -3699,7 +3697,7 @@ int vec0Filter_point(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
UNUSED_PARAMETER(idxStr); UNUSED_PARAMETER(idxStr);
int rc; int rc;
todo_assert(argc == 1); todo_assert(argc == 1);
sqlite3_int64 rowid = sqlite3_value_int64(argv[0]); i64 rowid = sqlite3_value_int64(argv[0]);
pCur->query_plan = SQLITE_VEC0_QUERYPLAN_POINT; pCur->query_plan = SQLITE_VEC0_QUERYPLAN_POINT;
struct vec0_query_point_data *point_data = struct vec0_query_point_data *point_data =
@ -3800,8 +3798,7 @@ static int vec0Eof(sqlite3_vtab_cursor *cur) {
static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur, static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur,
sqlite3_context *context, int i) { sqlite3_context *context, int i) {
todo_assert(pCur->fullscan_data); todo_assert(pCur->fullscan_data);
sqlite3_int64 rowid = i64 rowid = sqlite3_column_int64(pCur->fullscan_data->rowids_stmt, 0);
sqlite3_column_int64(pCur->fullscan_data->rowids_stmt, 0);
if (i == VEC0_COLUMN_ID) { if (i == VEC0_COLUMN_ID) {
vec0_result_id(pVtab, context, rowid); vec0_result_id(pVtab, context, rowid);
} else if (vec0_column_idx_is_vector(pVtab, i)) { } else if (vec0_column_idx_is_vector(pVtab, i)) {
@ -3853,7 +3850,7 @@ static int vec0Column_knn(vec0_vtab *pVtab, vec0_cursor *pCur,
sqlite3_context *context, int i) { sqlite3_context *context, int i) {
todo_assert(pCur->knn_data); todo_assert(pCur->knn_data);
if (i == VEC0_COLUMN_ID) { if (i == VEC0_COLUMN_ID) {
sqlite3_int64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx]; i64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx];
vec0_result_id(pVtab, context, rowid); vec0_result_id(pVtab, context, rowid);
return SQLITE_OK; return SQLITE_OK;
} }
@ -3902,16 +3899,16 @@ static int vec0Column(sqlite3_vtab_cursor *cur, sqlite3_context *context,
* *
* @param p: virtual table * @param p: virtual table
* @param idValue: Value containing the inserted rowid/id value. * @param idValue: Value containing the inserted rowid/id value.
* @param rowid: Output rowid, will point to the "real" sqlite3_int64 rowid * @param rowid: Output rowid, will point to the "real" i64 rowid
* value that was inserted * value that was inserted
* @return int SQLITE_OK on success, error code on failure * @return int SQLITE_OK on success, error code on failure
*/ */
int vec0Update_InsertRowidStep(vec0_vtab *p, sqlite3_value *idValue, int vec0Update_InsertRowidStep(vec0_vtab *p, sqlite3_value *idValue,
sqlite3_int64 *rowid) { i64 *rowid) {
/** /**
* An insert into a vec0 table can happen a few different ways: * An insert into a vec0 table can happen a few different ways:
* 1) With default INTEGER primary key: With a supplied sqlite3_int64 rowid * 1) With default INTEGER primary key: With a supplied i64 rowid
* 2) With default INTEGER primary key: WITHOUT a supplied rowid * 2) With default INTEGER primary key: WITHOUT a supplied rowid
* 3) With TEXT primary key: supplied text rowid * 3) With TEXT primary key: supplied text rowid
*/ */
@ -3937,9 +3934,9 @@ int vec0Update_InsertRowidStep(vec0_vtab *p, sqlite3_value *idValue,
#endif #endif
} }
// Option 1: User supplied a sqlite3_int64 rowid // Option 1: User supplied an i64 rowid
else if (sqlite3_value_type(idValue) == SQLITE_INTEGER) { else if (sqlite3_value_type(idValue) == SQLITE_INTEGER) {
sqlite3_int64 suppliedRowid = sqlite3_value_int64(idValue); i64 suppliedRowid = sqlite3_value_int64(idValue);
sqlite3_reset(p->stmtRowidsInsertRowid); sqlite3_reset(p->stmtRowidsInsertRowid);
sqlite3_clear_bindings(p->stmtRowidsInsertRowid); sqlite3_clear_bindings(p->stmtRowidsInsertRowid);
@ -3987,12 +3984,12 @@ int vec0Update_InsertRowidStep(vec0_vtab *p, sqlite3_value *idValue,
* @return int SQLITE_OK on success, error code on failure * @return int SQLITE_OK on success, error code on failure
*/ */
int vec0Update_InsertNextAvailableStep( int vec0Update_InsertNextAvailableStep(
vec0_vtab *p, sqlite3_int64 *chunk_rowid, sqlite3_int64 *chunk_offset, vec0_vtab *p, i64 *chunk_rowid, i64 *chunk_offset,
sqlite3_blob **blobChunksValidity, sqlite3_blob **blobChunksValidity,
const unsigned char **bufferChunksValidity) { const unsigned char **bufferChunksValidity) {
int rc; int rc;
sqlite3_int64 validitySize; i64 validitySize;
*chunk_offset = -1; *chunk_offset = -1;
sqlite3_reset(p->stmtLatestChunk); sqlite3_reset(p->stmtLatestChunk);
@ -4054,7 +4051,7 @@ done:
} }
static int vec0Update_InsertWriteFinalStepVectors( static int vec0Update_InsertWriteFinalStepVectors(
sqlite3_blob *blobVectors, const void *bVector, sqlite3_int64 chunk_offset, sqlite3_blob *blobVectors, const void *bVector, i64 chunk_offset,
size_t dimensions, enum VectorElementType element_type) { size_t dimensions, enum VectorElementType element_type) {
int n; int n;
int offset; int offset;
@ -4065,8 +4062,8 @@ static int vec0Update_InsertWriteFinalStepVectors(
offset = chunk_offset * dimensions * sizeof(float); offset = chunk_offset * dimensions * sizeof(float);
break; break;
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
n = dimensions * sizeof(int8_t); n = dimensions * sizeof(i8);
offset = chunk_offset * dimensions * sizeof(int8_t); offset = chunk_offset * dimensions * sizeof(i8);
break; break;
case SQLITE_VEC_ELEMENT_TYPE_BIT: case SQLITE_VEC_ELEMENT_TYPE_BIT:
n = dimensions / CHAR_BIT; n = dimensions / CHAR_BIT;
@ -4092,9 +4089,9 @@ static int vec0Update_InsertWriteFinalStepVectors(
* assigned chunk. * assigned chunk.
* @return int SQLITE_OK on success, error code on failure * @return int SQLITE_OK on success, error code on failure
*/ */
int vec0Update_InsertWriteFinalStep(vec0_vtab *p, sqlite3_int64 chunk_rowid, int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid,
sqlite3_int64 chunk_offset, i64 chunk_offset, i64 rowid,
sqlite3_int64 rowid, void *vectorDatas[], void *vectorDatas[],
sqlite3_blob *blobChunksValidity, sqlite3_blob *blobChunksValidity,
const unsigned char *bufferChunksValidity) { const unsigned char *bufferChunksValidity) {
int rc; int rc;
@ -4124,8 +4121,7 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, sqlite3_int64 chunk_rowid,
break; break;
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) == todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) ==
p->chunk_size * p->vector_columns[i].dimensions * p->chunk_size * p->vector_columns[i].dimensions * sizeof(i8));
sizeof(int8_t));
break; break;
case SQLITE_VEC_ELEMENT_TYPE_BIT: case SQLITE_VEC_ELEMENT_TYPE_BIT:
todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) == todo_assert((unsigned long)sqlite3_blob_bytes(blobVectors) ==
@ -4145,9 +4141,9 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, sqlite3_int64 chunk_rowid,
chunk_rowid, 1, &blobChunksRowids); chunk_rowid, 1, &blobChunksRowids);
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
todo_assert(sqlite3_blob_bytes(blobChunksRowids) == todo_assert(sqlite3_blob_bytes(blobChunksRowids) ==
p->chunk_size * sizeof(sqlite3_int64)); p->chunk_size * sizeof(i64));
rc = sqlite3_blob_write(blobChunksRowids, &rowid, sizeof(sqlite3_int64), rc = sqlite3_blob_write(blobChunksRowids, &rowid, sizeof(i64),
chunk_offset * sizeof(sqlite3_int64)); chunk_offset * sizeof(i64));
todo_assert(rc == SQLITE_OK); todo_assert(rc == SQLITE_OK);
sqlite3_blob_close(blobChunksRowids); sqlite3_blob_close(blobChunksRowids);
@ -4176,16 +4172,16 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
int rc; int rc;
// Rowid for the inserted row, determined by the inserted ID + _rowids shadow // Rowid for the inserted row, determined by the inserted ID + _rowids shadow
// table // table
sqlite3_int64 rowid; i64 rowid;
// Array to hold the vector data of the inserted row. Individual elements will // Array to hold the vector data of the inserted row. Individual elements will
// have a lifetime bound to the argv[..] values. // have a lifetime bound to the argv[..] values.
void *vectorDatas[VEC0_MAX_VECTOR_COLUMNS]; void *vectorDatas[VEC0_MAX_VECTOR_COLUMNS];
// Rowid of the chunk in the _chunks shadow table that the row will be a part // Rowid of the chunk in the _chunks shadow table that the row will be a part
// of. // of.
sqlite3_int64 chunk_rowid; i64 chunk_rowid;
// offset within the chunk where the rowid belongs // offset within the chunk where the rowid belongs
sqlite3_int64 chunk_offset; i64 chunk_offset;
// a write-able blob of the validity column for the given chunk. Used to mark // a write-able blob of the validity column for the given chunk. Used to mark
// validity bit // validity bit
@ -4262,8 +4258,8 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite_int64 rowid) { int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite_int64 rowid) {
vec0_vtab *p = (vec0_vtab *)pVTab; vec0_vtab *p = (vec0_vtab *)pVTab;
int rc; int rc;
sqlite3_int64 chunk_id; i64 chunk_id;
sqlite3_int64 chunk_offset; i64 chunk_offset;
sqlite3_blob *blobChunksValidity = NULL; sqlite3_blob *blobChunksValidity = NULL;
// 1. get chunk_id and chunk_offset from _rowids // 1. get chunk_id and chunk_offset from _rowids
@ -4312,9 +4308,9 @@ int vec0Update_UpdateOnRowid(sqlite3_vtab *pVTab, int argc,
UNUSED_PARAMETER(argc); UNUSED_PARAMETER(argc);
vec0_vtab *p = (vec0_vtab *)pVTab; vec0_vtab *p = (vec0_vtab *)pVTab;
int rc; int rc;
sqlite3_int64 chunk_id; i64 chunk_id;
sqlite3_int64 chunk_offset; i64 chunk_offset;
sqlite3_int64 rowid = sqlite3_value_int64(argv[0]); i64 rowid = sqlite3_value_int64(argv[0]);
// 1. get chunk_id and chunk_offset from _rowids // 1. get chunk_id and chunk_offset from _rowids
rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset); rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset);
@ -4332,7 +4328,7 @@ int vec0Update_UpdateOnRowid(sqlite3_vtab *pVTab, int argc,
dimensions = sqlite3_value_bytes(valueVector) / sizeof(float); dimensions = sqlite3_value_bytes(valueVector) / sizeof(float);
break; break;
case SQLITE_VEC_ELEMENT_TYPE_INT8: case SQLITE_VEC_ELEMENT_TYPE_INT8:
dimensions = sqlite3_value_bytes(valueVector) * sizeof(int8_t); dimensions = sqlite3_value_bytes(valueVector) * sizeof(i8);
break; break;
case SQLITE_VEC_ELEMENT_TYPE_BIT: case SQLITE_VEC_ELEMENT_TYPE_BIT:
dimensions = sqlite3_value_bytes(valueVector) * CHAR_BIT; dimensions = sqlite3_value_bytes(valueVector) * CHAR_BIT;
@ -4477,7 +4473,7 @@ int sqlite3_mmap_warm(sqlite3 *db, const char *zDb) {
sqlite3_file *pFd = 0; sqlite3_file *pFd = 0;
rc = sqlite3_file_control(db, zDb, SQLITE_FCNTL_FILE_POINTER, &pFd); rc = sqlite3_file_control(db, zDb, SQLITE_FCNTL_FILE_POINTER, &pFd);
if (rc == SQLITE_OK && pFd->pMethods && pFd->pMethods->iVersion >= 3) { if (rc == SQLITE_OK && pFd->pMethods && pFd->pMethods->iVersion >= 3) {
sqlite3_int64 iPg = 1; i64 iPg = 1;
sqlite3_io_methods const *p = pFd->pMethods; sqlite3_io_methods const *p = pFd->pMethods;
while (1) { while (1) {
unsigned char *pMap; unsigned char *pMap;

View file

@ -16,6 +16,7 @@ EXT_PATH = "./dist/vec0"
SUPPORTS_SUBTYPE = sqlite3.version_info[1] > 38 SUPPORTS_SUBTYPE = sqlite3.version_info[1] > 38
def bitmap_full(n: int) -> bytearray: def bitmap_full(n: int) -> bytearray:
assert (n % 8) == 0 assert (n % 8) == 0
return bytes([0xFF] * int(n / 8)) return bytes([0xFF] * int(n / 8))
@ -614,18 +615,17 @@ def test_smoke():
assert re.match( assert re.match(
"SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 0:knn:", "SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 0:knn:",
explain_query_plan( explain_query_plan(
"select * from vec_xyz where a match X'' and k = 10 order by distance" "select * from vec_xyz where a match X'' and k = 10 order by distance"
) ),
) )
assert re.match( assert re.match(
"SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 0:fullscan", "SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 0:fullscan",
explain_query_plan("select * from vec_xyz") explain_query_plan("select * from vec_xyz"),
) )
assert re.match( assert re.match(
"SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 3:point", "SCAN (TABLE )?vec_xyz VIRTUAL TABLE INDEX 3:point",
explain_query_plan("select * from vec_xyz where rowid = 4") explain_query_plan("select * from vec_xyz where rowid = 4"),
) )
db.execute("insert into vec_xyz(rowid, a) select 1, X'000000000000803f'") db.execute("insert into vec_xyz(rowid, a) select 1, X'000000000000803f'")