fix msvc build (#14)

This commit is contained in:
k.h.lai 2024-06-09 14:53:12 +08:00 committed by GitHub
parent 0c75fd292f
commit 80531b53e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 23 additions and 15 deletions

View file

@ -9,6 +9,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <float.h>
#include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1
@ -341,6 +342,12 @@ static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) {
}
return (f32)same;
}
#ifdef _MSC_VER
# include <intrin.h>
# define __builtin_popcountl __popcnt64
#endif
static f32 distance_hamming_u64(u64 *a, u64 *b, size_t n) {
int same = 0;
for (unsigned long i = 0; i < n; i++) {
@ -429,7 +436,7 @@ int array_append(struct Array *array, const void *element) {
return SQLITE_NOMEM;
}
}
memcpy(&array->z[array->length * array->element_size], element,
memcpy((char *)(&array->z)[array->length * array->element_size], element,
array->element_size);
array->length++;
return SQLITE_OK;
@ -2510,7 +2517,7 @@ static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur,
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
sqlite3_result_blob(
context,
&pCur->vector[pCur->iRowid * pCur->nDimensions * sizeof(f32)],
(f32 *)(&pCur->vector)[pCur->iRowid * pCur->nDimensions * sizeof(f32)],
pCur->nDimensions * sizeof(f32), SQLITE_STATIC);
break;
}
@ -2524,7 +2531,7 @@ static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur,
switch (pCur->elementType) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
sqlite3_result_blob(context,
&pCur->fileBuffer[pCur->bufferIndex *
(f32 *)(&pCur->fileBuffer)[pCur->bufferIndex *
pCur->nDimensions * sizeof(f32)],
pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT);
break;
@ -3678,7 +3685,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
f32 *topk_distances = sqlite3_malloc(k * sizeof(f32));
todo_assert(topk_distances);
for (int i = 0; i < k; i++) {
topk_distances[i] = __FLT_MAX__;
topk_distances[i] = FLT_MAX;
}
// for each chunk, get top min(k, chunk_size) rowid + distances to query vec.
@ -3746,7 +3753,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
// Ensure the current vector is "valid" in the validity bitmap.
// If not, skip and continue on
if (!(((chunkValidity[i / CHAR_BIT]) >> (i % CHAR_BIT)) & 1)) {
chunk_distances[i] = __FLT_MAX__;
chunk_distances[i] = FLT_MAX;
continue;
};
// If pre-filtering, make sure the rowid appears in the `rowid in (...)`
@ -3756,7 +3763,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
void *in = bsearch(&rowid, arrayRowidsIn->z, arrayRowidsIn->length,
sizeof(i64), _cmp);
if (!in) {
chunk_distances[i] = __FLT_MAX__;
chunk_distances[i] = FLT_MAX;
continue;
}
}
@ -3988,7 +3995,7 @@ static int vec0Eof(sqlite3_vtab_cursor *cur) {
todo_assert(pCur->knn_data);
return (pCur->knn_data->current_idx >= pCur->knn_data->k) ||
(pCur->knn_data->distances[pCur->knn_data->current_idx] ==
__FLT_MAX__);
FLT_MAX);
}
case SQLITE_VEC0_QUERYPLAN_POINT: {
todo_assert(pCur->point_data);
@ -5392,7 +5399,7 @@ static int vec_expoFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
f32 *chunk_distances = sqlite3_malloc(chunk_size * sizeof(f32));
todo_assert(chunk_distances);
for (int i = 0; i < k; i++) {
topk_distances[i] = __FLT_MAX__;
topk_distances[i] = FLT_MAX;
}
i64 *chunk_rowids = sqlite3_malloc(chunk_size * sizeof(i64));
todo_assert(chunk_rowids);
@ -5674,9 +5681,6 @@ __declspec(dllexport)
#define SQLITE_RESULT_SUBTYPE 0x001000000
#endif
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
SQLITE_EXTENSION_INIT2(pApi);
@ -5684,7 +5688,7 @@ __declspec(dllexport)
const int DEFAULT_FLAGS =
SQLITE_UTF8 | SQLITE_INNOCUOUS | SQLITE_DETERMINISTIC;
static const struct {
const struct {
char *zFName;
void (*xFunc)(sqlite3_context *, int, sqlite3_value **);
int nArg;
@ -5767,9 +5771,6 @@ __declspec(dllexport)
return SQLITE_OK;
}
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_vec_fs_read_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
UNUSED_PARAMETER(pzErrMsg);