vec_expo experiment

This commit is contained in:
Alex Garcia 2024-05-26 22:59:29 -07:00
parent 8f3a8c2faf
commit 8c418f9349

View file

@ -5024,6 +5024,399 @@ static sqlite3_module vec_static_blob_entriesModule = {
/* xShadowName */ 0};
#pragma endregion
#pragma region vec_expo() table function
/*
 * Merge the current best-k candidates ("base") with the candidates of one
 * chunk into a freshly allocated best-k list.
 *
 *   k               number of result slots to fill
 *   base_distances  distances of the current candidates (k entries are read
 *                   at most, front to back)
 *   base_rowids     rowids matching base_distances
 *   chunk_size      number of entries available in chunk_top_idx
 *   chunk_top_idx   indexes into chunk_distances/chunk_rowids, consumed front
 *                   to back
 *   out_rowids,
 *   out_distances   receive arrays of k entries allocated with
 *                   sqlite3_malloc(); this function does not free them
 *
 * Each output slot takes whichever front element has the strictly smaller
 * distance; on ties the base candidate wins.
 * NOTE(review): this only yields a sorted top-k if both inputs are already
 * in ascending distance order - confirm against the callers.
 */
void dethrone2(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size,
               i32 *chunk_top_idx, f32 *chunk_distances, i64 *chunk_rowids,
               i64 **out_rowids, f32 **out_distances) {
  *out_rowids = sqlite3_malloc(k * sizeof(i64));
  todo_assert(*out_rowids);
  *out_distances = sqlite3_malloc(k * sizeof(f32));
  todo_assert(*out_distances);

  i64 *merged_rowids = *out_rowids;
  f32 *merged_distances = *out_distances;
  size_t chunk_pos = 0; /* next unread entry of chunk_top_idx */
  size_t base_pos = 0;  /* next unread entry of the base arrays */

  for (int slot = 0; slot < k; slot++) {
    int take_chunk = 0;
    if (chunk_pos < chunk_size) {
      if (base_pos >= k) {
        /* base exhausted: only chunk candidates remain */
        take_chunk = 1;
      } else {
        take_chunk = chunk_distances[chunk_top_idx[chunk_pos]] <
                     base_distances[base_pos];
      }
    }

    if (take_chunk) {
      const i32 src = chunk_top_idx[chunk_pos];
      merged_rowids[slot] = chunk_rowids[src];
      merged_distances[slot] = chunk_distances[src];
      chunk_pos++;
    } else if (base_pos < k) {
      merged_rowids[slot] = base_rowids[base_pos];
      merged_distances[slot] = base_distances[base_pos];
      base_pos++;
    }
  }
}
/* Per-table state of a vec_expo virtual table. */
typedef struct vec_expo_vtab {
  sqlite3_vtab base; /* must stay first: the struct is cast to/from sqlite3_vtab */
  sqlite3 *db;       /* connection handed to xConnect */
  char *table;       /* sqlite3_mprintf() copy of module argument argv[3] */
  char *column;      /* sqlite3_mprintf() copy of module argument argv[4] */
} vec_expo_vtab;
/* Per-scan state of a vec_expo cursor. */
typedef struct vec_expo_cursor {
  sqlite3_vtab_cursor base;             /* must stay first: cast to/from sqlite3_vtab_cursor */
  sqlite3_int64 iRowid;                 /* position used by the full-scan plan */
  vec_sbe_query_plan query_plan;        /* plan selected by the last xFilter call */
  struct vec0_query_knn_data *knn_data; /* KNN results; set by xFilter for the KNN plan */
} vec_expo_cursor;
/*
 * xConnect: declare the vec_expo schema and allocate the vtab object.
 *
 * SQLite passes the module name, database name and table name in argv[0..2];
 * the user-supplied module arguments follow. Exactly two are required:
 * argv[3] (stored as the base table name) and argv[4] (stored as the vector
 * column name).
 *
 * Returns SQLITE_OK on success, SQLITE_ERROR (with *pzErr set) on a wrong
 * argument count, SQLITE_NOMEM on allocation failure, or the error code of
 * sqlite3_declare_vtab().
 */
static int vec_expoConnect(sqlite3 *db, void *pAux, int argc,
                           const char *const *argv, sqlite3_vtab **ppVtab,
                           char **pzErr) {
  vec_expo_vtab *pNew;
  int rc;
  (void)pAux;

  /* The argument count comes from user SQL, so it must be checked at
  ** runtime: an assert() disappears under NDEBUG and argv[3]/argv[4] would
  ** then be read out of bounds. */
  if (argc != 5) {
    *pzErr = sqlite3_mprintf(
        "vec_expo requires exactly 2 arguments: a table name and a column name");
    return SQLITE_ERROR;
  }
#define VEC_EXPO_VECTOR 0
#define VEC_EXPO_DISTANCE 1
#define VEC_EXPO_K 2
  rc = sqlite3_declare_vtab(
      db, "CREATE TABLE x(vector, distance hidden, k hidden)");
  if (rc != SQLITE_OK) {
    return rc;
  }

  pNew = sqlite3_malloc(sizeof(*pNew));
  if (pNew == 0) {
    return SQLITE_NOMEM;
  }
  memset(pNew, 0, sizeof(*pNew));
  pNew->db = db;
  pNew->table = sqlite3_mprintf("%s", argv[3]);
  pNew->column = sqlite3_mprintf("%s", argv[4]);
  if (pNew->table == 0 || pNew->column == 0) {
    /* sqlite3_mprintf() returns NULL when out of memory */
    sqlite3_free(pNew->table);
    sqlite3_free(pNew->column);
    sqlite3_free(pNew);
    return SQLITE_NOMEM;
  }

  *ppVtab = (sqlite3_vtab *)pNew;
  return SQLITE_OK;
}
/*
 * xCreate: identical to xConnect for this module; every argument is
 * forwarded unchanged to vec_expoConnect() and its result is returned.
 */
static int vec_expoCreate(sqlite3 *db, void *pAux, int argc,
                          const char *const *argv, sqlite3_vtab **ppVtab,
                          char **pzErr) {
  const int rc = vec_expoConnect(db, pAux, argc, argv, ppVtab, pzErr);
  return rc;
}
/*
 * xDisconnect: release the vtab object created by vec_expoConnect().
 *
 * The table and column names are separate sqlite3_mprintf() allocations
 * owned by the vtab, so they are released before the object itself.
 * Always returns SQLITE_OK.
 */
static int vec_expoDisconnect(sqlite3_vtab *pVtab) {
  vec_expo_vtab *p = (vec_expo_vtab *)pVtab;
  sqlite3_free(p->table);
  sqlite3_free(p->column);
  sqlite3_free(p);
  return SQLITE_OK;
}
/*
 * xOpen: allocate a zero-initialised cursor and hand its base object back
 * through *ppCursor. Returns SQLITE_NOMEM if the allocation fails, otherwise
 * SQLITE_OK. The vtab argument is not used.
 */
static int vec_expoOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) {
  vec_expo_cursor *cursor = sqlite3_malloc(sizeof(vec_expo_cursor));
  if (!cursor) {
    return SQLITE_NOMEM;
  }
  memset(cursor, 0, sizeof(vec_expo_cursor));
  *ppCursor = &cursor->base;
  return SQLITE_OK;
}
/*
 * xClose: destroy a cursor created by vec_expoOpen().
 *
 * For the KNN plan vec_expoFilter() attaches a heap-allocated result set to
 * the cursor (the knn_data struct plus its rowids and distances arrays);
 * those are released here as well so a finished query does not leak them.
 * Always returns SQLITE_OK.
 */
static int vec_expoClose(sqlite3_vtab_cursor *cur) {
  vec_expo_cursor *pCur = (vec_expo_cursor *)cur;
  if (pCur->knn_data) {
    /* sqlite3_free(NULL) is a no-op, so the k == 0 result set (whose arrays
    ** are NULL) needs no special casing. */
    sqlite3_free(pCur->knn_data->rowids);
    sqlite3_free(pCur->knn_data->distances);
    sqlite3_free(pCur->knn_data);
    pCur->knn_data = 0;
  }
  sqlite3_free(pCur);
  return SQLITE_OK;
}
/*
 * xBestIndex: choose between the KNN plan and the full-scan plan.
 *
 * The KNN plan is selected when a usable MATCH constraint on the vector
 * column is present. It additionally requires
 *   - exactly one of: a LIMIT constraint, or an equality constraint on k;
 *   - exactly one ORDER BY term, on the distance column, ascending.
 * The MATCH operand is passed to xFilter as argv[0] and the LIMIT/k operand
 * as argv[1]. Without a MATCH constraint the full-scan plan is reported.
 *
 * Returns SQLITE_OK, SQLITE_ERROR (duplicate MATCH, or missing/duplicate
 * limit source) or SQLITE_CONSTRAINT (unsupported ORDER BY, with a vtab
 * error message set).
 */
static int vec_expoBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
  vec_expo_vtab *p = (vec_expo_vtab *)pVTab;
  int matchTerm = -1; /* index of the MATCH constraint on the vector column */
  int limitTerm = -1; /* index of the LIMIT constraint */
  int kTerm = -1;     /* index of the "k = ?" constraint */
  /* TODO point query (rowid constraint) is not handled yet */

  for (int i = 0; i < pIdxInfo->nConstraint; i++) {
    if (!pIdxInfo->aConstraint[i].usable) {
      continue;
    }
    const int column = pIdxInfo->aConstraint[i].iColumn;
    const int op = pIdxInfo->aConstraint[i].op;

    switch (op) {
    case SQLITE_INDEX_CONSTRAINT_MATCH:
      if (column == VEC_EXPO_VECTOR) {
        if (matchTerm > -1) {
          // TODO only 1 match operator at a time
          return SQLITE_ERROR;
        }
        matchTerm = i;
      }
      break;
    case SQLITE_INDEX_CONSTRAINT_LIMIT:
      limitTerm = i;
      break;
    case SQLITE_INDEX_CONSTRAINT_EQ:
      if (column == VEC_EXPO_K) {
        kTerm = i;
      }
      break;
    default:
      break;
    }
  }

  /* No MATCH on the vector column: fall back to a full scan. */
  if (matchTerm < 0) {
    pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_FULLSCAN;
    pIdxInfo->estimatedCost = 10000.0;
    pIdxInfo->estimatedRows = 10000;
    return SQLITE_OK;
  }

  if (limitTerm < 0 && kTerm < 0) {
    // TODO: error, match on vector1 should require a limit for KNN
    return SQLITE_ERROR;
  }
  if (limitTerm >= 0 && kTerm >= 0) {
    return SQLITE_ERROR; // limit or k, not both
  }

  if (pIdxInfo->nOrderBy < 1) {
    SET_VTAB_ERROR("ORDER BY distance required");
    return SQLITE_CONSTRAINT;
  }
  if (pIdxInfo->nOrderBy > 1) {
    // TODO error, orderByConsumed is all or nothing, only 1 order by allowed
    SET_VTAB_ERROR("more than 1 ORDER BY clause provided");
    return SQLITE_CONSTRAINT;
  }
  if (pIdxInfo->aOrderBy[0].iColumn != VEC_EXPO_DISTANCE) {
    SET_VTAB_ERROR("ORDER BY must be on the distance column");
    return SQLITE_CONSTRAINT;
  }
  if (pIdxInfo->aOrderBy[0].desc) {
    SET_VTAB_ERROR("Only ascending in ORDER BY distance clause is supported, "
                   "DESC is not supported yet.");
    return SQLITE_CONSTRAINT;
  }

  /* Exactly one of limitTerm/kTerm is set here; it supplies argv[1]. */
  const int countTerm = (limitTerm >= 0) ? limitTerm : kTerm;

  pIdxInfo->aConstraintUsage[matchTerm].argvIndex = 1;
  pIdxInfo->aConstraintUsage[matchTerm].omit = 1;
  pIdxInfo->aConstraintUsage[countTerm].argvIndex = 2;
  pIdxInfo->aConstraintUsage[countTerm].omit = 1;

  pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_KNN;
  pIdxInfo->estimatedCost = (double)10; // TODO vtab_value(?) as hint?
  pIdxInfo->estimatedRows = 10;         // TODO vtab_value(?) as hint?
  pIdxInfo->orderByConsumed = 1;
  return SQLITE_OK;
}
/*
 * xFilter: start a scan on the cursor.
 *
 * Full-scan plan: only resets the cursor; vec_expoEof() reports end-of-scan
 * immediately for that plan.
 *
 * KNN plan: argv[0] is the query vector and argv[1] the number of neighbours
 * (k). The rowids of the base table are walked in chunks of 200 rows; for
 * each chunk the vector blobs are read, their squared L2 distance to the
 * query is computed, and the chunk is merged into the running best-k list
 * with dethrone2(). The final list is attached to the cursor as knn_data.
 * If the table holds fewer than k rows, only that many results are exposed.
 *
 * Returns SQLITE_OK on success, SQLITE_NOMEM on allocation failure,
 * SQLITE_ERROR on invalid arguments or a vector blob of unexpected size, or
 * the error code of the failing SQLite call. On failure no result set is
 * attached and every temporary resource is released.
 */
static int vec_expoFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
                          const char *idxStr, int argc, sqlite3_value **argv) {
  vec_expo_cursor *pCur = (vec_expo_cursor *)pVtabCursor;
  vec_expo_vtab *p = (vec_expo_vtab *)pCur->base.pVtab;
  (void)idxStr;

  /* xFilter may run several times on one cursor: drop the previous result
  ** set so it is not leaked. */
  if (pCur->knn_data) {
    sqlite3_free(pCur->knn_data->rowids);
    sqlite3_free(pCur->knn_data->distances);
    sqlite3_free(pCur->knn_data);
    pCur->knn_data = NULL;
  }

  /* Until a KNN result is attached the cursor stays in the full-scan state,
  ** for which xEof always reports end-of-scan. */
  pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN;
  pCur->iRowid = 0;
  if (idxNum != VEC_SBE__QUERYPLAN_KNN) {
    return SQLITE_OK;
  }
  if (argc < 2) {
    return SQLITE_ERROR;
  }

  const int chunk_size = 200; /* rows processed per batch */
  int rc = SQLITE_OK;
  int haveVector = 0; /* queryVector must be released with cleanup() */
  int exhausted = 0;  /* the rowid statement returned SQLITE_DONE */
  i64 k = 0;
  i64 nseen = 0;      /* rows merged so far */
  size_t vecBytes = 0;
  struct vec0_query_knn_data *knn_data = NULL;
  void *queryVector = NULL;
  size_t dimensions = 0;
  enum VectorElementType elementType;
  vector_cleanup cleanup;
  char *err = NULL;
  i64 *topk_rowids = NULL;
  f32 *topk_distances = NULL;
  char *zSql = NULL;
  sqlite3_stmt *stmtRowids = NULL;
  sqlite3_blob *baseVectorsBlob = NULL;
  float *chunk = NULL;
  f32 *chunk_distances = NULL;
  i64 *chunk_rowids = NULL;
  i32 *chunk_top_idxs = NULL;

  rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
                         &cleanup, &err);
  if (rc != SQLITE_OK) {
    /* NOTE(review): assumes vector_from_value() allocates the message with
    ** the sqlite3 allocator, so the vtab can take ownership - confirm. */
    sqlite3_free(pVtabCursor->pVtab->zErrMsg);
    pVtabCursor->pVtab->zErrMsg = err;
    return rc;
  }
  haveVector = 1;

  if (elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32 || dimensions == 0) {
    rc = SQLITE_ERROR;
    goto done;
  }
  vecBytes = dimensions * sizeof(float);

  k = sqlite3_value_int64(argv[1]);
  /* dethrone2() takes k as an int and allocates k * sizeof(i64) bytes with
  ** sqlite3_malloc(int); reject values that would overflow either. */
  if (k < 0 || (sqlite3_uint64)k > (sqlite3_uint64)0x7fffffff / sizeof(i64)) {
    rc = SQLITE_ERROR;
    goto done;
  }

  knn_data = sqlite3_malloc(sizeof(*knn_data));
  if (!knn_data) {
    rc = SQLITE_NOMEM;
    goto done;
  }
  memset(knn_data, 0, sizeof(*knn_data));

  if (k == 0) {
    goto done; /* empty result set, nothing to scan */
  }

  topk_rowids = sqlite3_malloc64((sqlite3_uint64)k * sizeof(i64));
  topk_distances = sqlite3_malloc64((sqlite3_uint64)k * sizeof(f32));
  chunk = sqlite3_malloc64((sqlite3_uint64)vecBytes * chunk_size);
  chunk_distances = sqlite3_malloc64((sqlite3_uint64)chunk_size * sizeof(f32));
  chunk_rowids = sqlite3_malloc64((sqlite3_uint64)chunk_size * sizeof(i64));
  chunk_top_idxs = sqlite3_malloc64((sqlite3_uint64)chunk_size * sizeof(i32));
  if (!topk_rowids || !topk_distances || !chunk || !chunk_distances ||
      !chunk_rowids || !chunk_top_idxs) {
    rc = SQLITE_NOMEM;
    goto done;
  }
  /* Sentinel entries: any real distance below FLT_MAX replaces them. */
  for (i64 i = 0; i < k; i++) {
    topk_distances[i] = __FLT_MAX__;
    topk_rowids[i] = 0;
  }

  zSql = sqlite3_mprintf("select rowid from \"%w\" ", p->table);
  if (!zSql) {
    rc = SQLITE_NOMEM;
    goto done;
  }
  rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtRowids, NULL);
  if (rc != SQLITE_OK) {
    goto done;
  }

  while (!exhausted) {
    int nused = 0;

    /* Fill one chunk with up to chunk_size vectors. */
    while (nused < chunk_size) {
      rc = sqlite3_step(stmtRowids);
      if (rc == SQLITE_DONE) {
        exhausted = 1;
        break;
      }
      if (rc != SQLITE_ROW) {
        goto done;
      }
      i64 rowid = sqlite3_column_int64(stmtRowids, 0);

      /* Open the blob handle on the first row actually present instead of
      ** assuming that rowid 1 exists; afterwards just re-point it. */
      if (!baseVectorsBlob) {
        rc = sqlite3_blob_open(p->db, "main", p->table, p->column, rowid, 0,
                               &baseVectorsBlob);
      } else {
        rc = sqlite3_blob_reopen(baseVectorsBlob, rowid);
      }
      if (rc != SQLITE_OK) {
        goto done;
      }

      int blobBytes = sqlite3_blob_bytes(baseVectorsBlob);
      if (blobBytes < 0 || (size_t)blobBytes != vecBytes) {
        rc = SQLITE_ERROR; /* stored vector has a different dimension */
        goto done;
      }
      rc = sqlite3_blob_read(baseVectorsBlob,
                             &chunk[(size_t)nused * dimensions], blobBytes, 0);
      if (rc != SQLITE_OK) {
        goto done;
      }
      chunk_rowids[nused] = rowid;
      nused++;
    }
    rc = SQLITE_OK;

    /* An empty trailing chunk (empty table, or a row count that is a
    ** multiple of chunk_size) has nothing to merge. */
    if (nused == 0) {
      break;
    }

    for (int i = 0; i < nused; i++) {
      const f32 *base_i = (chunk) + (i * dimensions);
      chunk_distances[i] =
          distance_l2_sqr_float(base_i, (f32 *)queryVector, &dimensions);
    }
    min_idx(chunk_distances, nused, chunk_top_idxs, nused);

    i64 *out_rowids = NULL;
    f32 *out_distances = NULL;
    dethrone2((int)k, topk_distances, topk_rowids, /*chunk_size*/ nused,
              chunk_top_idxs, chunk_distances, chunk_rowids, &out_rowids,
              &out_distances);
    if (!out_rowids || !out_distances) {
      sqlite3_free(out_rowids);
      sqlite3_free(out_distances);
      rc = SQLITE_NOMEM;
      goto done;
    }
    /* dethrone2() returns fresh arrays: adopt them instead of copying. */
    sqlite3_free(topk_rowids);
    sqlite3_free(topk_distances);
    topk_rowids = out_rowids;
    topk_distances = out_distances;
    nseen += nused;
  }

done:
  sqlite3_free(chunk_top_idxs);
  sqlite3_free(chunk_rowids);
  sqlite3_free(chunk_distances);
  sqlite3_free(chunk);
  sqlite3_blob_close(baseVectorsBlob); /* NULL handle is a harmless no-op */
  sqlite3_finalize(stmtRowids);        /* NULL statement is a harmless no-op */
  sqlite3_free(zSql);
  if (haveVector) {
    cleanup(queryVector);
  }

  if (rc != SQLITE_OK) {
    sqlite3_free(topk_rowids);
    sqlite3_free(topk_distances);
    sqlite3_free(knn_data);
    return rc;
  }

  knn_data->current_idx = 0;
  /* Never expose more rows than were scanned: slots beyond that still hold
  ** the FLT_MAX sentinels. */
  knn_data->k = (nseen < k) ? nseen : k;
  knn_data->rowids = topk_rowids;
  knn_data->distances = topk_distances;
  pCur->knn_data = knn_data;
  pCur->query_plan = VEC_SBE__QUERYPLAN_KNN;
  return SQLITE_OK;
}
/*
 * xRowid: report the rowid of the current row.
 *
 * KNN plan: the rowid stored at the cursor's current result index.
 * Full-scan plan: the cursor's own counter.
 * For any other plan value *pRowid is left untouched. Always returns
 * SQLITE_OK.
 */
static int vec_expoRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) {
  vec_expo_cursor *cursor = (vec_expo_cursor *)cur;

  if (cursor->query_plan == VEC_SBE__QUERYPLAN_FULLSCAN) {
    *pRowid = cursor->iRowid;
  } else if (cursor->query_plan == VEC_SBE__QUERYPLAN_KNN) {
    *pRowid = cursor->knn_data->rowids[cursor->knn_data->current_idx];
  }
  return SQLITE_OK;
}
/*
 * xNext: advance the cursor by one row.
 *
 * Full-scan plan: increments the rowid counter. KNN plan: increments the
 * index into the result set. Returns SQLITE_OK for both, and SQLITE_ERROR
 * for an unrecognised plan so the function never falls off its end without
 * returning a value.
 */
static int vec_expoNext(sqlite3_vtab_cursor *cur) {
  vec_expo_cursor *pCur = (vec_expo_cursor *)cur;
  switch (pCur->query_plan) {
  case VEC_SBE__QUERYPLAN_FULLSCAN: {
    pCur->iRowid++;
    return SQLITE_OK;
  }
  case VEC_SBE__QUERYPLAN_KNN: {
    pCur->knn_data->current_idx++;
    return SQLITE_OK;
  }
  default:
    return SQLITE_ERROR;
  }
}
/*
 * xEof: return non-zero when the cursor has no more rows.
 *
 * Full-scan plan: not implemented yet, so it is always at end-of-scan.
 * KNN plan: at end once the current index reaches the number of results;
 * a cursor without a result set is treated as exhausted.
 * Any other plan value is reported as end-of-scan so the function always
 * returns a defined value.
 */
static int vec_expoEof(sqlite3_vtab_cursor *cur) {
  vec_expo_cursor *pCur = (vec_expo_cursor *)cur;
  switch (pCur->query_plan) {
  case VEC_SBE__QUERYPLAN_FULLSCAN: {
    return 1; /* TODO: full scan over the base table is not implemented */
  }
  case VEC_SBE__QUERYPLAN_KNN: {
    if (!pCur->knn_data) {
      return 1;
    }
    return pCur->knn_data->current_idx >= pCur->knn_data->k;
  }
  default:
    return 1;
  }
}
/*
 * xColumn: produce the value of column i for the current row.
 *
 * Only the distance column of the KNN plan yields a value (the distance at
 * the current result index, as a double). Every other column/plan
 * combination sets no result, which SQLite reports as NULL. Always returns
 * SQLITE_OK, including for an unrecognised plan, so the function never falls
 * off its end without returning a value.
 */
static int vec_expoColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context,
                          int i) {
  vec_expo_cursor *pCur = (vec_expo_cursor *)cur;
  switch (pCur->query_plan) {
  case VEC_SBE__QUERYPLAN_FULLSCAN: {
    return SQLITE_OK;
  }
  case VEC_SBE__QUERYPLAN_KNN: {
    switch (i) {
    case VEC_EXPO_VECTOR: {
      /* the stored vector is not returned yet */
      break;
    }
    case VEC_EXPO_DISTANCE: {
      sqlite3_result_double(
          context, pCur->knn_data->distances[pCur->knn_data->current_idx]);
      break;
    }
    default:
      break;
    }
    return SQLITE_OK;
  }
  default:
    break;
  }
  return SQLITE_OK;
}
static sqlite3_module vec_expoModule = {
/* iVersion */ 3,
/* xCreate */ vec_expoCreate,
/* xConnect */ vec_expoConnect,
/* xBestIndex */ vec_expoBestIndex,
/* xDisconnect */ vec_expoDisconnect,
/* xDestroy */ vec_expoDisconnect,
/* xOpen */ vec_expoOpen,
/* xClose */ vec_expoClose,
/* xFilter */ vec_expoFilter,
/* xNext */ vec_expoNext,
/* xEof */ vec_expoEof,
/* xColumn */ vec_expoColumn,
/* xRowid */ vec_expoRowid,
/* xUpdate */ 0,
/* xBegin */ 0,
/* xSync */ 0,
/* xCommit */ 0,
/* xRollback */ 0,
/* xFindMethod */ 0,
/* xRename */ 0,
/* xSavepoint */ 0,
/* xRelease */ 0,
/* xRollbackTo */ 0,
/* xShadowName */ 0};
#pragma endregion
#endif
int sqlite3_mmap_warm(sqlite3 *db, const char *zDb) {
@ -5225,6 +5618,8 @@ __declspec(dllexport)
assert(rc == SQLITE_OK);
rc = sqlite3_create_module_v2(db, "vec_static_blob_entries", &vec_static_blob_entriesModule, static_blob_data, NULL);
assert(rc == SQLITE_OK);
rc = sqlite3_create_module_v2(db, "vec_expo", &vec_expoModule, NULL, NULL);
assert(rc == SQLITE_OK);
#endif
return SQLITE_OK;