static updates

This commit is contained in:
Alex Garcia 2024-07-31 12:56:09 -07:00
parent e91ccf38ff
commit a0bc9404ce
3 changed files with 323 additions and 231 deletions

View file

@ -736,7 +736,7 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector,
if (source[offset] == ',') {
offset++;
continue;
} // TODO multiple commas in a row without digits?
}
if (source[offset] == ']')
goto done;
break;
@ -2006,7 +2006,7 @@ int parse_vector_column(const char *source, int source_length,
struct VectorColumnDefinition *outColumn) {
// parses a vector column definition like so:
// "abc float[123]", "abc_123 bit[1234]", etc.
// TODO: more precise error messages here.
// https://github.com/asg017/sqlite-vec/issues/46
int rc;
struct Vec0Scanner scanner;
struct Vec0Token token;
@ -2989,7 +2989,7 @@ static int vec_npy_eachColumnBuffer(vec_npy_each_cursor *pCur,
context,
&((unsigned char *)
pCur->vector)[pCur->iRowid * pCur->nDimensions * sizeof(f32)],
pCur->nDimensions * sizeof(f32), SQLITE_STATIC);
pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT);
break;
}
@ -3361,7 +3361,6 @@ cleanup:
/**
 * Thin wrapper around vec0_get_chunk_position(): forwards `pVtab`, `rowid`
 * and `out`, and passes NULL for the two trailing out-parameters.
 *
 * Returns whatever vec0_get_chunk_position() returns.
 */
int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, i64 rowid,
                                 sqlite3_value **out) {
  // PERF: different strategy than get_chunk_position?
  // TODO: test / evidence-of
  int rc = vec0_get_chunk_position(pVtab, rowid, out, NULL, NULL);
  return rc;
}
@ -3461,7 +3460,6 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
"vectors", chunk_id, 0, &vectorBlob);
if (rc != SQLITE_OK) {
// TODO evidence-of
vtab_set_error(&pVtab->base,
"Could not fetch vector data for %lld, opening blob failed",
rowid);
@ -4490,7 +4488,11 @@ void bitmap_and_inplace(u8 *base, u8 *other, i32 n) {
}
/**
 * Sets or clears a single bit in a bitmap.
 *
 * @param bitmap   Byte array holding the bits; must be large enough to
 *                 contain `position`.
 * @param position Zero-based bit index. The bit lives in byte
 *                 `position / CHAR_BIT`, at bit `position % CHAR_BIT`.
 *                 Must be non-negative.
 * @param value    Any non-zero value sets the bit, zero clears it.
 */
void bitmap_set(u8 *bitmap, i32 position, int value) {
  // Build the mask from a constant 1 rather than from `value`, so that a
  // truthy value other than 1 cannot spill into neighbouring bits.
  u8 mask = (u8)(1u << (position % CHAR_BIT));
  if (value) {
    bitmap[position / CHAR_BIT] |= mask;
  } else {
    bitmap[position / CHAR_BIT] &= (u8)~mask;
  }
}
int bitmap_get(u8 *bitmap, i32 position) {
@ -4502,6 +4504,11 @@ void bitmap_clear(u8 *bitmap, i32 n) {
memset(bitmap, 0, n / CHAR_BIT);
}
/**
 * Sets the first `n` bits of `bitmap` to 1.
 *
 * @param bitmap Byte array of at least `n / CHAR_BIT` bytes.
 * @param n      Number of bits to set. Must be a multiple of CHAR_BIT because
 *               whole bytes are written.
 */
void bitmap_fill(u8 *bitmap, i32 n) {
  // Check divisibility with the same unit the byte count below is computed in.
  assert((n % CHAR_BIT) == 0);
  memset(bitmap, 0xFF, n / CHAR_BIT);
}
/**
* @brief Finds the minimum k items in distances, and writes the indices to
* out.
@ -4885,7 +4892,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
if (dimensions != vector_column->dimensions) {
vtab_set_error(
&p->base,
"Dimension mismatch for inserted vector for the \"%.*s\" column. "
"Dimension mismatch for query vector for the \"%.*s\" column. "
"Expected %d dimensions but received %d.",
vector_column->name_length, vector_column->name,
vector_column->dimensions, dimensions);
@ -5728,8 +5735,7 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
// a write-able blob of the validity column for the given chunk. Used to mark
// validity bit
sqlite3_blob *blobChunksValidity = NULL;
// buffer for the validity column for the given chunk. TODO maybe not needed
// here?
// buffer for the validity column for the given chunk. Maybe not needed here?
const unsigned char *bufferChunksValidity = NULL;
int numReadVectors = 0;
@ -5951,9 +5957,9 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite_int64 rowid) {
return rc;
}
// 3. zero out rowid in chunks.rowids TODO
// 3. zero out rowid in chunks.rowids https://github.com/asg017/sqlite-vec/issues/54
// 4. zero out any data in vector chunks tables TODO
// 4. zero out any data in vector chunks tables https://github.com/asg017/sqlite-vec/issues/54
// 5. delete from _rowids table
rc = vec0Update_Delete_DeleteRowids(p, rowid);
@ -5975,9 +5981,7 @@ int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset,
enum VectorElementType elementType;
void *vector;
vector_cleanup cleanup = vector_cleanup_noop;
// TODO: Can't update non f32, bc subtypes are stripped from UPDATEs.
// Need to 1) create a less strict vector_from_value, or 2) wait
// for this to resolve: https://sqlite.org/forum/forumpost/65317ce9c6
// https://github.com/asg017/sqlite-vec/issues/53
rc = vector_from_value(valueVector, &vector, &dimensions, &elementType,
&cleanup, &pzError);
if (rc != SQLITE_OK) {
@ -6449,12 +6453,36 @@ typedef enum {
VEC_SBE__QUERYPLAN_KNN = 2
} vec_sbe_query_plan;
// Result state of a KNN query over a static blob. A
// vec_static_blob_entries_cursor holds a pointer to one of these; release the
// arrays with sbe_query_knn_data_clear().
struct sbe_query_knn_data {
// Number of results requested. vec_static_blob_entriesFilter stores the
// requested k here after capping it at the blob's vector count.
i64 k;
// NOTE(review): presumably the number of results actually produced; the
// assignment is not visible in this excerpt - confirm.
i64 k_used;
// Array of k i32 row indexes into the blob (allocated as k * sizeof(i32)).
// Must be freed with sqlite3_free().
i32 *rowids;
// Array of distances. NOTE: vec_static_blob_entriesFilter allocates one f32
// per vector in the blob (count rounded up to a multiple of 8), not k
// entries. Must be freed with sqlite3_free().
f32 *distances;
// Position in `rowids` of the row the cursor currently points at; read by the
// Rowid and Column callbacks.
i64 current_idx;
};
/**
 * Releases the arrays owned by `knn_data` and resets the pointers to NULL.
 *
 * The struct itself is not freed. Passing NULL is allowed and does nothing.
 */
void sbe_query_knn_data_clear(struct sbe_query_knn_data *knn_data) {
  if (knn_data == NULL) {
    return;
  }

  // sqlite3_free(NULL) is a harmless no-op, so no per-pointer guard is needed.
  sqlite3_free(knn_data->rowids);
  knn_data->rowids = NULL;

  sqlite3_free(knn_data->distances);
  knn_data->distances = NULL;
}
typedef struct vec_static_blob_entries_cursor vec_static_blob_entries_cursor;
// Cursor over the rows of a vec_static_blob_entries virtual table.
struct vec_static_blob_entries_cursor {
  // Must remain the first member: the cursor callbacks cast between
  // sqlite3_vtab_cursor* and this struct.
  sqlite3_vtab_cursor base;
  // Current row index when the plan is VEC_SBE__QUERYPLAN_FULLSCAN.
  sqlite3_int64 iRowid;
  // Which plan the most recent xFilter call selected.
  vec_sbe_query_plan query_plan;
  // KNN results when the plan is VEC_SBE__QUERYPLAN_KNN; freed in xClose.
  // Declared exactly once, with the static-blob KNN type that xFilter
  // allocates (a second declaration with a different type does not compile).
  struct sbe_query_knn_data *knn_data;
};
static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc,
@ -6519,6 +6547,7 @@ static int vec_static_blob_entriesOpen(sqlite3_vtab *p,
/**
 * xClose: destroys a vec_static_blob_entries cursor and everything it owns.
 *
 * @param cur Cursor previously returned by xOpen.
 * @return SQLITE_OK always.
 */
static int vec_static_blob_entriesClose(sqlite3_vtab_cursor *cur) {
  vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
  // Release the rowids/distances arrays before the struct that points at
  // them; freeing only the struct would leak both arrays after a KNN query.
  // NOTE(review): assumes knn_data is either NULL or has its array pointers
  // initialised (valid or NULL) by xOpen/xFilter - confirm.
  sbe_query_knn_data_clear(pCur->knn_data);
  sqlite3_free(pCur->knn_data);
  sqlite3_free(pCur);
  return SQLITE_OK;
}
@ -6528,7 +6557,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVTab;
int iMatchTerm = -1;
int iLimitTerm = -1;
// int iRowidTerm = -1; // TODO point query
// int iRowidTerm = -1; // https://github.com/asg017/sqlite-vec/issues/47
int iKTerm = -1;
for (int i = 0; i < pIdxInfo->nConstraint; i++) {
@ -6540,7 +6569,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
if (op == SQLITE_INDEX_CONSTRAINT_MATCH &&
iColumn == VEC_STATIC_BLOB_ENTRIES_VECTOR) {
if (iMatchTerm > -1) {
// TODO only 1 match operator at a time
// https://github.com/asg017/sqlite-vec/issues/51
return SQLITE_ERROR;
}
iMatchTerm = i;
@ -6555,7 +6584,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
}
if (iMatchTerm >= 0) {
if (iLimitTerm < 0 && iKTerm < 0) {
// TODO: error, match on vector1 should require a limit for KNN
// https://github.com/asg017/sqlite-vec/issues/51
return SQLITE_ERROR;
}
if (iLimitTerm >= 0 && iKTerm >= 0) {
@ -6566,7 +6595,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
return SQLITE_CONSTRAINT;
}
if (pIdxInfo->nOrderBy > 1) {
// TODO error, orderByConsumed is all or nothing, only 1 order by allowed
// https://github.com/asg017/sqlite-vec/issues/51
vtab_set_error(pVTab, "more than 1 ORDER BY clause provided");
return SQLITE_CONSTRAINT;
}
@ -6615,8 +6644,9 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
(vec_static_blob_entries_vtab *)pCur->base.pVtab;
if (idxNum == VEC_SBE__QUERYPLAN_KNN) {
assert(argc == 2);
pCur->query_plan = VEC_SBE__QUERYPLAN_KNN;
struct vec0_query_knn_data *knn_data;
struct sbe_query_knn_data *knn_data;
knn_data = sqlite3_malloc(sizeof(*knn_data));
if (!knn_data) {
return SQLITE_NOMEM;
@ -6642,7 +6672,7 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
i64 k = min(sqlite3_value_int64(argv[1]), (i64)p->blob->nvectors);
if (k < 0) {
// TODO handle
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
return SQLITE_ERROR;
}
if (k == 0) {
@ -6651,24 +6681,38 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
return SQLITE_OK;
}
i64 *topk_rowids = sqlite3_malloc(k * sizeof(i64));
size_t bsize = (p->blob->nvectors + 7) & ~7;
i32 *topk_rowids = sqlite3_malloc(k * sizeof(i32));
if (!topk_rowids) {
// TODO handle
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
return SQLITE_ERROR;
}
f32 *distances = sqlite3_malloc(p->blob->nvectors * sizeof(f32));
f32 *distances = sqlite3_malloc(bsize * sizeof(f32));
if (!distances) {
// TODO handle
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
return SQLITE_ERROR;
}
for (size_t i = 0; i < p->blob->nvectors; i++) {
// https://github.com/asg017/sqlite-vec/issues/52
float *v = ((float *)p->blob->p) + (i * p->blob->dimensions);
distances[i] =
distance_l2_sqr_float(v, (float *)queryVector, &p->blob->dimensions);
// TODO other metrics
}
// min_idx(distances, k, topk_rowids, k); // TODO
u8 * candidates = bitmap_new(bsize);
assert(candidates);
u8 * taken = bitmap_new(bsize);
assert(taken);
bitmap_fill(candidates, bsize);
for(size_t i = bsize; i >= p->blob->nvectors; i--) {
bitmap_set(candidates, i, 0);
}
i32 k_used = 0;
min_idx(distances, bsize, candidates, topk_rowids, k, taken, &k_used);
knn_data->current_idx = 0;
knn_data->distances = distances;
knn_data->k = k;
@ -6686,8 +6730,18 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
/**
 * xRowid: reports the rowid of the row the cursor currently points at.
 *
 * For a full scan this is the cursor's own row counter. For a KNN query it is
 * the entry of the result array at the cursor's current result position.
 *
 * @param cur    The cursor.
 * @param pRowid Receives the rowid.
 * @return SQLITE_OK, or SQLITE_ERROR if the cursor's query plan is not one
 *         this function knows how to handle.
 */
static int vec_static_blob_entriesRowid(sqlite3_vtab_cursor *cur,
                                        sqlite_int64 *pRowid) {
  vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
  // Dispatch on the plan first; an unconditional "*pRowid = iRowid; return"
  // ahead of this switch would make the KNN branch unreachable.
  switch (pCur->query_plan) {
  case VEC_SBE__QUERYPLAN_FULLSCAN: {
    *pRowid = pCur->iRowid;
    return SQLITE_OK;
  }
  case VEC_SBE__QUERYPLAN_KNN: {
    i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx];
    *pRowid = (sqlite3_int64)rowid;
    return SQLITE_OK;
  }
  }
  // Unknown plan: never fall off the end of a non-void function.
  return SQLITE_ERROR;
}
static int vec_static_blob_entriesNext(sqlite3_vtab_cursor *cur) {
@ -6732,7 +6786,7 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur,
context,
((unsigned char *)p->blob->p) +
(pCur->iRowid * p->blob->dimensions * sizeof(float)),
p->blob->dimensions * sizeof(float), SQLITE_STATIC);
p->blob->dimensions * sizeof(float), SQLITE_TRANSIENT);
sqlite3_result_subtype(context, p->blob->element_type);
break;
}
@ -6742,11 +6796,10 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur,
switch (i) {
case VEC_STATIC_BLOB_ENTRIES_VECTOR: {
i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx];
sqlite3_result_blob(context,
((unsigned char *)p->blob->p) +
(rowid * p->blob->dimensions * sizeof(float)),
p->blob->dimensions * sizeof(float), SQLITE_STATIC);
p->blob->dimensions * sizeof(float), SQLITE_TRANSIENT);
sqlite3_result_subtype(context, p->blob->element_type);
break;
}
@ -6758,7 +6811,7 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur,
static sqlite3_module vec_static_blob_entriesModule = {
/* iVersion */ 3,
/* xCreate */ vec_static_blob_entriesCreate, // TODO rm?
/* xCreate */ vec_static_blob_entriesCreate, // handle rm? https://github.com/asg017/sqlite-vec/issues/55
/* xConnect */ vec_static_blob_entriesConnect,
/* xBestIndex */ vec_static_blob_entriesBestIndex,
/* xDisconnect */ vec_static_blob_entriesDisconnect,
@ -6787,91 +6840,6 @@ static sqlite3_module vec_static_blob_entriesModule = {
};
#pragma endregion
/**
 * Touches every memory-mapped page of a database file by reading the first
 * and last byte of each page obtained through the VFS xFetch method.
 *
 * The work happens inside a transaction opened by this function, which is
 * ended again before returning, including on error paths.
 *
 * @param db  Database connection; must be in autocommit mode.
 * @param zDb Schema name to warm, or NULL to leave the SQL unqualified.
 * @return SQLITE_OK on success, SQLITE_MISUSE if `db` is not in autocommit
 *         mode, SQLITE_NOMEM if building a SQL string fails, SQLITE_ERROR if
 *         the page size could not be read, or the error code of the failing
 *         SQLite call.
 */
int sqlite3_mmap_warm(sqlite3 *db, const char *zDb) {
  int rc = SQLITE_OK;
  char *zSql = 0;
  int pgsz = 0;
  unsigned int nTotal = 0;

  if (0 == sqlite3_get_autocommit(db))
    return SQLITE_MISUSE;

  /* Open a read-only transaction on the file in question */
  zSql = sqlite3_mprintf("BEGIN; SELECT * FROM %s%q%ssqlite_schema",
                         (zDb ? "'" : ""), (zDb ? zDb : ""), (zDb ? "'." : ""));
  if (zSql == 0)
    return SQLITE_NOMEM;
  rc = sqlite3_exec(db, zSql, 0, 0, 0);
  sqlite3_free(zSql);

  /* Find the SQLite page size of the file */
  if (rc == SQLITE_OK) {
    zSql = sqlite3_mprintf("PRAGMA %s%q%spage_size", (zDb ? "'" : ""),
                           (zDb ? zDb : ""), (zDb ? "'." : ""));
    if (zSql == 0) {
      rc = SQLITE_NOMEM;
    } else {
      sqlite3_stmt *pPgsz = 0;
      rc = sqlite3_prepare_v2(db, zSql, -1, &pPgsz, 0);
      sqlite3_free(zSql);
      if (rc == SQLITE_OK) {
        if (sqlite3_step(pPgsz) == SQLITE_ROW) {
          pgsz = sqlite3_column_int(pPgsz, 0);
        }
        rc = sqlite3_finalize(pPgsz);
      }
      if (rc == SQLITE_OK && pgsz == 0) {
        rc = SQLITE_ERROR;
      }
    }
  }

  /* Touch each mmap'd page of the file */
  if (rc == SQLITE_OK) {
    sqlite3_file *pFd = 0;
    rc = sqlite3_file_control(db, zDb, SQLITE_FCNTL_FILE_POINTER, &pFd);
    if (rc == SQLITE_OK && pFd->pMethods && pFd->pMethods->iVersion >= 3) {
      i64 iPg = 1;
      sqlite3_io_methods const *p = pFd->pMethods;
      while (1) {
        unsigned char *pMap;
        rc = p->xFetch(pFd, pgsz * iPg, pgsz, (void **)&pMap);
        if (rc != SQLITE_OK || pMap == 0)
          break;

        /* Reading both ends of the page forces it to be paged in; the sum
        ** only exists so the reads are not optimized away. */
        nTotal += (unsigned int)pMap[0];
        nTotal += (unsigned int)pMap[pgsz - 1];

        rc = p->xUnfetch(pFd, pgsz * iPg, (void *)pMap);
        if (rc != SQLITE_OK)
          break;
        iPg++;
      }
      /* iPg is 64-bit, so it must be formatted with %lld rather than %d. */
      sqlite3_log(SQLITE_OK,
                  "sqlite3_mmap_warm_cache: Warmed up %lld pages of %s",
                  (sqlite3_int64)(iPg == 1 ? 0 : iPg),
                  sqlite3_db_filename(db, zDb));
    }
  }

  /* Autocommit was on at entry, so if it is off now the BEGIN above is still
  ** open. End it on every path, not only when all earlier steps succeeded. */
  if (0 == sqlite3_get_autocommit(db)) {
    int rc2 = sqlite3_exec(db, "END", 0, 0, 0);
    if (rc == SQLITE_OK)
      rc = rc2;
  }

  (void)nTotal;
  return rc;
}
/**
 * Loadable-extension entry point that runs sqlite3_mmap_warm() on `db` with a
 * NULL schema name and returns its result code. `pzErrMsg` is not used.
 */
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_vec_warm_mmap(sqlite3 *db, char **pzErrMsg,
                          const sqlite3_api_routines *pApi) {
  UNUSED_PARAMETER(pzErrMsg);
  SQLITE_EXTENSION_INIT2(pApi);
  int rc = sqlite3_mmap_warm(db, NULL);
  return rc;
}
#ifdef SQLITE_VEC_ENABLE_AVX
#define SQLITE_VEC_DEBUG_BUILD_AVX "avx"
#else
@ -6907,12 +6875,6 @@ __declspec(dllexport)
const sqlite3_api_routines *pApi) {
SQLITE_EXTENSION_INIT2(pApi);
int rc = SQLITE_OK;
vec_static_blob_data *static_blob_data;
static_blob_data = sqlite3_malloc(sizeof(*static_blob_data));
if (!static_blob_data) {
return SQLITE_NOMEM;
}
memset(static_blob_data, 0, sizeof(*static_blob_data));
#define DEFAULT_FLAGS (SQLITE_UTF8 | SQLITE_INNOCUOUS | SQLITE_DETERMINISTIC)
@ -6953,7 +6915,6 @@ __declspec(dllexport)
{"vec_int8", vec_int8, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_quantize_int8", vec_quantize_int8, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_quantize_binary", vec_quantize_binary, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
{"vec_static_blob_from_raw", vec_static_blob_from_raw, 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE },
// clang-format on
};
@ -6989,13 +6950,7 @@ __declspec(dllexport)
return rc;
}
}
rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule,
static_blob_data, sqlite3_free);
assert(rc == SQLITE_OK);
rc = sqlite3_create_module_v2(db, "vec_static_blob_entries",
&vec_static_blob_entriesModule,
static_blob_data, NULL);
assert(rc == SQLITE_OK);
return SQLITE_OK;
}
@ -7014,23 +6969,33 @@ __declspec(dllexport)
}
#endif
#ifdef SQLITE_VEC_ENABLE_TRACE_ENTRYPOINT
/**
 * sqlite3_trace_v2() callback that prints the expanded SQL text of each
 * statement to stdout.
 *
 * @param x  Trace event code; only SQLITE_TRACE_STMT is handled.
 * @param p1 Unused.
 * @param p2 For SQLITE_TRACE_STMT, the sqlite3_stmt being run.
 * @param p3 Unused.
 * @return Always 0.
 */
int trace(unsigned int x, void *p1, void *p2, void *p3) {
  (void)p1;
  (void)p3;
  if (x == SQLITE_TRACE_STMT) {
    sqlite3_stmt *stmt = (sqlite3_stmt *)p2;
    char *zSql = sqlite3_expanded_sql(stmt);
    // sqlite3_expanded_sql() may return NULL, and the string it returns must
    // be released with sqlite3_free().
    if (zSql) {
      printf("%s\n", zSql);
      sqlite3_free(zSql);
    }
  }
  // The function is declared int, so it must return a value on every path.
  return 0;
}
#ifdef _WIN32
__declspec(dllexport)
#endif
int trace_debug(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
int sqlite3_vec_static_blobs_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
UNUSED_PARAMETER(pzErrMsg);
SQLITE_EXTENSION_INIT2(pApi);
sqlite3_trace_v2(db, SQLITE_TRACE_STMT, trace, NULL);
return SQLITE_OK;
int rc = SQLITE_OK;
vec_static_blob_data *static_blob_data;
static_blob_data = sqlite3_malloc(sizeof(*static_blob_data));
if (!static_blob_data) {
return SQLITE_NOMEM;
}
memset(static_blob_data, 0, sizeof(*static_blob_data));
rc = sqlite3_create_function_v2(db, "vec_static_blob_from_raw", 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE,
NULL, vec_static_blob_from_raw, NULL, NULL, NULL);
if(rc != SQLITE_OK) return rc;
rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule,
static_blob_data, sqlite3_free);
if(rc != SQLITE_OK) return rc;
rc = sqlite3_create_module_v2(db, "vec_static_blob_entries",
&vec_static_blob_entriesModule,
static_blob_data, NULL);
if(rc != SQLITE_OK) return rc;
return rc;
}
#endif

View file

@ -80,8 +80,6 @@ def connect(ext, path=":memory:", extra_entrypoint=None):
db = connect(EXT_PATH)
# db.load_extension(EXT_PATH, entrypoint="trace_debug")
def explain_query_plan(sql):
    """Return the ``detail`` field of the first EXPLAIN QUERY PLAN row for ``sql``.

    Runs against the module-level ``db`` connection.
    """
    first_row = db.execute("explain query plan " + sql).fetchone()
    return first_row["detail"]
@ -113,7 +111,6 @@ FUNCTIONS = [
"vec_quantize_binary",
"vec_quantize_int8",
"vec_slice",
"vec_static_blob_from_raw",
"vec_sub",
"vec_to_json",
"vec_type",
@ -123,24 +120,150 @@ MODULES = [
"vec0",
"vec_each",
"vec_npy_each",
"vec_static_blob_entries",
"vec_static_blobs",
#"vec_static_blob_entries",
#"vec_static_blobs",
]
@pytest.mark.skip(reason="TODO")
def test_vec_static_blob_from_raw():
pass
def register_numpy(db, name: str, array):
    """Expose a 2-D float32 array as a ``vec_static_blob_entries`` table.

    Inserts a row into ``temp.vec_static_blobs`` built from the array's raw
    data address, then creates a virtual table named ``name`` over it.

    Args:
        db: Connection with the static-blob entry point loaded.
        name: Blob name; also used as the virtual table name.
        array: Object exposing ``__array_interface__`` with a 2-D shape
            ``(nvectors, dimensions)``, typestr ``"<f4"`` and C-contiguous
            layout.

    Raises:
        AssertionError: If the element type is not ``"<f4"`` or the array is
            not C-contiguous.

    NOTE(review): only the raw address is handed to SQLite, so presumably the
    caller must keep ``array`` alive while the table is in use - confirm.
    """
    interface = array.__array_interface__
    ptr = interface["data"][0]
    nvectors, dimensions = interface["shape"]
    element_type = interface["typestr"]
    assert element_type == "<f4"
    # Only the address, dimensions and row count are passed on, so the rows
    # must be densely packed. In the array interface, strides is None exactly
    # when the data is C-contiguous; a strided view (transpose, slice) would be
    # read with the wrong layout.
    assert interface.get("strides") is None, "array must be C-contiguous"

    name_escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]

    db.execute(
        """
        insert into temp.vec_static_blobs(name, data)
        select ?, vec_static_blob_from_raw(?, ?, ?, ?)
        """,
        [name, ptr, element_type, dimensions, nvectors],
    )

    db.execute(
        f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
    )
@pytest.mark.skip(reason="TODO")
def test_vec_static_blobs():
pass
@pytest.mark.skip(reason="TODO")
def test_vec_static_blob_entries():
pass
db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_static_blobs_init")
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)
y = np.array([[0.2, 0.3], [0.9, 0.8], [0.6, 0.5]], dtype=np.float32)
z = np.array(
[
[0.1, 0.1, 0.1, 0.1],
[0.2, 0.2, 0.2, 0.2],
[0.3, 0.3, 0.3, 0.3],
[0.4, 0.4, 0.4, 0.4],
[0.5, 0.5, 0.5, 0.5],
],
dtype=np.float32,
)
register_numpy(db, "x", x)
register_numpy(db, "y", y)
register_numpy(db, "z", z)
assert execute_all(
db, "select *, dimensions, count from temp.vec_static_blobs;"
) == [
{
"count": 2,
"data": None,
"dimensions": 4,
"name": "x",
},
{
"count": 3,
"data": None,
"dimensions": 2,
"name": "y",
},
{
"count": 5,
"data": None,
"dimensions": 4,
"name": "z",
},
]
assert execute_all(db, "select vec_to_json(vector) from x;") == [
{
"vec_to_json(vector)": "[0.100000,0.200000,0.300000,0.400000]",
},
{
"vec_to_json(vector)": "[0.900000,0.800000,0.700000,0.600000]",
},
]
assert execute_all(db, "select (vector) from y limit 2;") == [
{
"vector": b"\xcd\xccL>\x9a\x99\x99>",
},
{
"vector": b"fff?\xcd\xccL?",
},
]
assert execute_all(db, "select rowid, (vector) from z") == [
{
"rowid": 0,
"vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=",
},
{
"rowid": 1,
"vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>",
},
{
"rowid": 2,
"vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>",
},
{
"rowid": 3,
"vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>",
},
{
"rowid": 4,
"vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?",
},
]
assert execute_all(
db,
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
[np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)],
) == [
{
"rowid": 2,
"v": "[0.300000,0.300000,0.300000,0.300000]",
},
{
"rowid": 3,
"v": "[0.400000,0.400000,0.400000,0.400000]",
},
{
"rowid": 1,
"v": "[0.200000,0.200000,0.200000,0.200000]",
},
]
assert execute_all(
db,
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
[np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)],
) == [
{
"rowid": 4,
"v": "[0.500000,0.500000,0.500000,0.500000]",
},
{
"rowid": 3,
"v": "[0.400000,0.400000,0.400000,0.400000]",
},
{
"rowid": 2,
"v": "[0.300000,0.300000,0.300000,0.300000]",
},
]
def test_funcs():
@ -1872,7 +1995,7 @@ def test_vec0_knn():
db.execute("select * from v where aaa match vec_bit(X'AA') and k = 10")
with _raises(
'Dimension mismatch for inserted vector for the "aaa" column. Expected 8 dimensions but received 1.'
'Dimension mismatch for query vector for the "aaa" column. Expected 8 dimensions but received 1.'
):
db.execute("select * from v where aaa match vec_f32('[.1]') and k = 10")
@ -2120,36 +2243,42 @@ def test_smoke():
db.execute("insert into vec_xyz(rowid, a) select 2, X'0000000000000040'")
chunk = db.execute("select * from vec_xyz_chunks").fetchone()
assert chunk[
"rowids"
] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + bytearray(
int(1024 * 8) - 8 * 2
assert (
chunk["rowids"]
== b"\x01\x00\x00\x00\x00\x00\x00\x00"
+ b"\x02\x00\x00\x00\x00\x00\x00\x00"
+ bytearray(int(1024 * 8) - 8 * 2)
)
assert chunk["chunk_id"] == 1
assert chunk["validity"] == b"\x03" + bytearray(int(1024 / 8) - 1)
vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone()
assert vchunk["rowid"] == 1
assert vchunk[
"vectors"
] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + bytearray(
int(1024 * 4 * 2) - (2 * 4 * 2)
assert (
vchunk["vectors"]
== b"\x00\x00\x00\x00\x00\x00\x80\x3f"
+ b"\x00\x00\x00\x00\x00\x00\x00\x40"
+ bytearray(int(1024 * 4 * 2) - (2 * 4 * 2))
)
db.execute("insert into vec_xyz(rowid, a) select 3, X'00000000000080bf'")
chunk = db.execute("select * from vec_xyz_chunks").fetchone()
assert chunk["chunk_id"] == 1
assert chunk["validity"] == b"\x07" + bytearray(int(1024 / 8) - 1)
assert chunk[
"rowids"
] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + b"\x03\x00\x00\x00\x00\x00\x00\x00" + bytearray(
int(1024 * 8) - 8 * 3
assert (
chunk["rowids"]
== b"\x01\x00\x00\x00\x00\x00\x00\x00"
+ b"\x02\x00\x00\x00\x00\x00\x00\x00"
+ b"\x03\x00\x00\x00\x00\x00\x00\x00"
+ bytearray(int(1024 * 8) - 8 * 3)
)
vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone()
assert vchunk["rowid"] == 1
assert vchunk[
"vectors"
] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + b"\x00\x00\x00\x00\x00\x00\x80\xbf" + bytearray(
int(1024 * 4 * 2) - (2 * 4 * 3)
assert (
vchunk["vectors"]
== b"\x00\x00\x00\x00\x00\x00\x80\x3f"
+ b"\x00\x00\x00\x00\x00\x00\x00\x40"
+ b"\x00\x00\x00\x00\x00\x00\x80\xbf"
+ bytearray(int(1024 * 4 * 2) - (2 * 4 * 3))
)
# db.execute("select * from vec_xyz")
@ -2192,66 +2321,63 @@ def test_vec0_stress_small_chunks():
{"rowid": 994, "a": _f32([99.4] * 8)},
{"rowid": 993, "a": _f32([99.3] * 8)},
]
assert (
execute_all(
db,
"""
assert execute_all(
db,
"""
select rowid, a, distance
from vec_small
where a match ?
and k = 9
order by distance
""",
[_f32([50.0] * 8)],
)
== [
{
"a": _f32([500 * 0.1] * 8),
"distance": 0.0,
"rowid": 500,
},
{
"a": _f32([501 * 0.1] * 8),
"distance": 0.2828384041786194,
"rowid": 501,
},
{
"a": _f32([499 * 0.1] * 8),
"distance": 0.2828384041786194,
"rowid": 499,
},
{
"a": _f32([502 * 0.1] * 8),
"distance": 0.5656875967979431,
"rowid": 502,
},
{
"a": _f32([498 * 0.1] * 8),
"distance": 0.5656875967979431,
"rowid": 498,
},
{
"a": _f32([503 * 0.1] * 8),
"distance": 0.8485260009765625,
"rowid": 503,
},
{
"a": _f32([497 * 0.1] * 8),
"distance": 0.8485260009765625,
"rowid": 497,
},
{
"a": _f32([496 * 0.1] * 8),
"distance": 1.1313751935958862,
"rowid": 496,
},
{
"a": _f32([504 * 0.1] * 8),
"distance": 1.1313751935958862,
"rowid": 504,
},
]
)
[_f32([50.0] * 8)],
) == [
{
"a": _f32([500 * 0.1] * 8),
"distance": 0.0,
"rowid": 500,
},
{
"a": _f32([501 * 0.1] * 8),
"distance": 0.2828384041786194,
"rowid": 501,
},
{
"a": _f32([499 * 0.1] * 8),
"distance": 0.2828384041786194,
"rowid": 499,
},
{
"a": _f32([502 * 0.1] * 8),
"distance": 0.5656875967979431,
"rowid": 502,
},
{
"a": _f32([498 * 0.1] * 8),
"distance": 0.5656875967979431,
"rowid": 498,
},
{
"a": _f32([503 * 0.1] * 8),
"distance": 0.8485260009765625,
"rowid": 503,
},
{
"a": _f32([497 * 0.1] * 8),
"distance": 0.8485260009765625,
"rowid": 497,
},
{
"a": _f32([496 * 0.1] * 8),
"distance": 1.1313751935958862,
"rowid": 496,
},
{
"a": _f32([504 * 0.1] * 8),
"distance": 1.1313751935958862,
"rowid": 504,
},
]
def test_vec0_distance_metric():

View file

@ -5,6 +5,7 @@ db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
db.load_extension("./dist/vec0")
db.execute("select load_extension('./dist/vec0', 'sqlite3_vec_raw_init')")
db.enable_load_extension(False)
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)