This commit is contained in:
Alex Garcia 2024-07-23 23:57:42 -07:00
parent 7a1b14976a
commit 0f5bc2f254
2 changed files with 258 additions and 238 deletions

View file

@ -1277,8 +1277,7 @@ char * vec_type_name(enum VectorElementType elementType) {
} }
} }
static void vec_type(sqlite3_context *context, int argc, static void vec_type(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_value **argv) {
assert(argc == 1); assert(argc == 1);
void *vector; void *vector;
size_t dimensions; size_t dimensions;
@ -1294,7 +1293,6 @@ static void vec_type(sqlite3_context *context, int argc,
} }
sqlite3_result_text(context, vec_type_name(elementType), -1, SQLITE_STATIC); sqlite3_result_text(context, vec_type_name(elementType), -1, SQLITE_STATIC);
cleanup(vector); cleanup(vector);
} }
static void vec_quantize_binary(sqlite3_context *context, int argc, static void vec_quantize_binary(sqlite3_context *context, int argc,
sqlite3_value **argv) { sqlite3_value **argv) {
@ -1318,7 +1316,10 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
return; return;
} }
if ((dimensions % CHAR_BIT) != 0) { if ((dimensions % CHAR_BIT) != 0) {
sqlite3_result_error(context, "Binary quantization requires vectors with a length divisible by 8", -1); sqlite3_result_error(
context,
"Binary quantization requires vectors with a length divisible by 8",
-1);
goto cleanup; goto cleanup;
return; return;
} }
@ -1349,7 +1350,8 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
break; break;
} }
case SQLITE_VEC_ELEMENT_TYPE_BIT: { case SQLITE_VEC_ELEMENT_TYPE_BIT: {
sqlite3_result_error(context, "Can only binary quantize float or int8 vectors", -1); sqlite3_result_error(context,
"Can only binary quantize float or int8 vectors", -1);
sqlite3_free(out); sqlite3_free(out);
return; return;
} }
@ -1357,7 +1359,6 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
sqlite3_result_blob(context, out, sz, sqlite3_free); sqlite3_result_blob(context, out, sz, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);
cleanup: cleanup:
vectorCleanup(vector); vectorCleanup(vector);
} }
@ -1389,7 +1390,8 @@ static void vec_quantize_int8(sqlite3_context *context, int argc,
(sqlite3_value_bytes(argv[1]) != strlen("unit")) || (sqlite3_value_bytes(argv[1]) != strlen("unit")) ||
(sqlite3_stricmp((const char *)sqlite3_value_text(argv[1]), "unit") != (sqlite3_stricmp((const char *)sqlite3_value_text(argv[1]), "unit") !=
0)) { 0)) {
sqlite3_result_error(context, "2nd argument to vec_quantize_i8() must be 'unit'.", -1); sqlite3_result_error(
context, "2nd argument to vec_quantize_i8() must be 'unit'.", -1);
sqlite3_free(out); sqlite3_free(out);
goto cleanup; goto cleanup;
} }
@ -1405,7 +1407,6 @@ cleanup:
srcCleanup(srcVector); srcCleanup(srcVector);
} }
static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) { static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) {
assert(argc == 2); assert(argc == 2);
int rc; int rc;
@ -3287,11 +3288,13 @@ int vec0_column_idx_to_vector_idx(vec0_vtab *pVtab, int column_idx) {
return column_idx - VEC0_COLUMN_VECTORN_START; return column_idx - VEC0_COLUMN_VECTORN_START;
} }
int vec0_get_chunk_position(vec0_vtab * p, i64 rowid, sqlite3_value ** id, i64 *chunk_id, i64 * chunk_offset) { int vec0_get_chunk_position(vec0_vtab *p, i64 rowid, sqlite3_value **id,
i64 *chunk_id, i64 *chunk_offset) {
int rc; int rc;
if (!p->stmtRowidsGetChunkPosition) { if (!p->stmtRowidsGetChunkPosition) {
const char * zSql = sqlite3_mprintf("SELECT id, chunk_id, chunk_offset " const char *zSql =
sqlite3_mprintf("SELECT id, chunk_id, chunk_offset "
"FROM " VEC0_SHADOW_ROWIDS_NAME " WHERE rowid = ?", "FROM " VEC0_SHADOW_ROWIDS_NAME " WHERE rowid = ?",
p->schemaName, p->tableName); p->schemaName, p->tableName);
if (!zSql) { if (!zSql) {
@ -3301,8 +3304,8 @@ int vec0_get_chunk_position(vec0_vtab * p, i64 rowid, sqlite3_value ** id, i64 *
rc = sqlite3_prepare_v2(p->db, zSql, -1, &p->stmtRowidsGetChunkPosition, 0); rc = sqlite3_prepare_v2(p->db, zSql, -1, &p->stmtRowidsGetChunkPosition, 0);
sqlite3_free((void *)zSql); sqlite3_free((void *)zSql);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
vtab_set_error(&p->base, vtab_set_error(
VEC_INTERAL_ERROR &p->base, VEC_INTERAL_ERROR
"could not initialize 'rowids get chunk position' statement"); "could not initialize 'rowids get chunk position' statement");
goto cleanup; goto cleanup;
} }
@ -3310,7 +3313,8 @@ int vec0_get_chunk_position(vec0_vtab * p, i64 rowid, sqlite3_value ** id, i64 *
sqlite3_bind_int64(p->stmtRowidsGetChunkPosition, 1, rowid); sqlite3_bind_int64(p->stmtRowidsGetChunkPosition, 1, rowid);
rc = sqlite3_step(p->stmtRowidsGetChunkPosition); rc = sqlite3_step(p->stmtRowidsGetChunkPosition);
// special case: when no results, return SQLITE_EMPTY to convey "that chunk position doesn't exist" // special case: when no results, return SQLITE_EMPTY to convey "that chunk
// position doesn't exist"
if (rc == SQLITE_DONE) { if (rc == SQLITE_DONE) {
rc = SQLITE_EMPTY; rc = SQLITE_EMPTY;
goto cleanup; goto cleanup;
@ -3320,7 +3324,8 @@ int vec0_get_chunk_position(vec0_vtab * p, i64 rowid, sqlite3_value ** id, i64 *
} }
if (id) { if (id) {
sqlite3_value *value = sqlite3_column_value(p->stmtRowidsGetChunkPosition, 0); sqlite3_value *value =
sqlite3_column_value(p->stmtRowidsGetChunkPosition, 0);
*id = sqlite3_value_dup(value); *id = sqlite3_value_dup(value);
if (!*id) { if (!*id) {
rc = SQLITE_NOMEM; rc = SQLITE_NOMEM;
@ -3452,14 +3457,14 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
} }
rc = sqlite3_blob_open(p->db, p->schemaName, rc = sqlite3_blob_open(p->db, p->schemaName,
p->shadowVectorChunksNames[vector_column_idx], "vectors", chunk_id, 0, p->shadowVectorChunksNames[vector_column_idx],
&vectorBlob); "vectors", chunk_id, 0, &vectorBlob);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
// TODO evidence-of // TODO evidence-of
vtab_set_error( vtab_set_error(&pVtab->base,
&pVtab->base, "Could not fetch vector data for %lld, opening blob failed",
"Could not fetch vector data for %lld, opening blob failed", rowid); rowid);
rc = SQLITE_ERROR; rc = SQLITE_ERROR;
goto cleanup; goto cleanup;
} }
@ -3495,7 +3500,8 @@ cleanup:
brc = sqlite3_blob_close(vectorBlob); brc = sqlite3_blob_close(vectorBlob);
if ((rc == SQLITE_OK) && (brc != SQLITE_OK)) { if ((rc == SQLITE_OK) && (brc != SQLITE_OK)) {
vtab_set_error( vtab_set_error(
&p->base, VEC_INTERAL_ERROR "unknown error, could not close vector blob, please file an issue"); &p->base, VEC_INTERAL_ERROR
"unknown error, could not close vector blob, please file an issue");
return brc; return brc;
} }
@ -3517,7 +3523,8 @@ int vec0_get_latest_chunk_rowid(vec0_vtab *p, i64 *chunk_rowid) {
sqlite3_free((void *)zSql); sqlite3_free((void *)zSql);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
// IMP: V21406_05476 // IMP: V21406_05476
vtab_set_error(&p->base, VEC_INTERAL_ERROR "could not initialize 'latest chunk' statement"); vtab_set_error(&p->base, VEC_INTERAL_ERROR
"could not initialize 'latest chunk' statement");
goto cleanup; goto cleanup;
} }
} }
@ -3541,7 +3548,6 @@ int vec0_get_latest_chunk_rowid(vec0_vtab *p, i64 *chunk_rowid) {
} }
rc = SQLITE_OK; rc = SQLITE_OK;
cleanup: cleanup:
if (p->stmtLatestChunk) { if (p->stmtLatestChunk) {
sqlite3_reset(p->stmtLatestChunk); sqlite3_reset(p->stmtLatestChunk);
@ -3554,7 +3560,8 @@ int vec0_rowids_insert_rowid(vec0_vtab *p, i64 rowid) {
int entered = 0; int entered = 0;
UNUSED_PARAMETER(entered); // temporary UNUSED_PARAMETER(entered); // temporary
if (!p->stmtRowidsInsertRowid) { if (!p->stmtRowidsInsertRowid) {
const char * zSql = sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_ROWIDS_NAME "(rowid)" const char *zSql =
sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_ROWIDS_NAME "(rowid)"
"VALUES (?);", "VALUES (?);",
p->schemaName, p->tableName); p->schemaName, p->tableName);
if (!zSql) { if (!zSql) {
@ -3564,12 +3571,12 @@ int vec0_rowids_insert_rowid(vec0_vtab *p, i64 rowid) {
rc = sqlite3_prepare_v2(p->db, zSql, -1, &p->stmtRowidsInsertRowid, 0); rc = sqlite3_prepare_v2(p->db, zSql, -1, &p->stmtRowidsInsertRowid, 0);
sqlite3_free((void *)zSql); sqlite3_free((void *)zSql);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
vtab_set_error(&p->base, VEC_INTERAL_ERROR "could not initialize 'insert rowids' statement"); vtab_set_error(&p->base, VEC_INTERAL_ERROR
"could not initialize 'insert rowids' statement");
goto cleanup; goto cleanup;
} }
} }
#ifdef SQLITE_THREADSAFE #ifdef SQLITE_THREADSAFE
if (sqlite3_mutex_enter) { if (sqlite3_mutex_enter) {
sqlite3_mutex_enter(sqlite3_db_mutex(p->db)); sqlite3_mutex_enter(sqlite3_db_mutex(p->db));
@ -3586,15 +3593,14 @@ int vec0_rowids_insert_rowid(vec0_vtab *p, i64 rowid) {
p->tableName); p->tableName);
} else { } else {
// IMP: V04679_21517 // IMP: V04679_21517
vtab_set_error( vtab_set_error(&p->base,
&p->base, "Error inserting rowid into rowids shadow table: %s", "Error inserting rowid into rowids shadow table: %s",
sqlite3_errmsg(sqlite3_db_handle(p->stmtRowidsInsertId))); sqlite3_errmsg(sqlite3_db_handle(p->stmtRowidsInsertId)));
} }
rc = SQLITE_ERROR; rc = SQLITE_ERROR;
goto cleanup; goto cleanup;
} }
rc = SQLITE_OK; rc = SQLITE_OK;
cleanup: cleanup:
@ -3609,7 +3615,6 @@ int vec0_rowids_insert_rowid(vec0_vtab *p, i64 rowid) {
} }
#endif #endif
return rc; return rc;
} }
int vec0_rowids_insert_id(vec0_vtab *p, sqlite3_value *idValue, i64 *rowid) { int vec0_rowids_insert_id(vec0_vtab *p, sqlite3_value *idValue, i64 *rowid) {
@ -3617,7 +3622,8 @@ int vec0_rowids_insert_id(vec0_vtab *p, sqlite3_value * idValue, i64 * rowid) {
int entered = 0; int entered = 0;
UNUSED_PARAMETER(entered); // temporary UNUSED_PARAMETER(entered); // temporary
if (!p->stmtRowidsInsertId) { if (!p->stmtRowidsInsertId) {
const char * zSql = sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_ROWIDS_NAME "(id)" const char *zSql =
sqlite3_mprintf("INSERT INTO " VEC0_SHADOW_ROWIDS_NAME "(id)"
"VALUES (?);", "VALUES (?);",
p->schemaName, p->tableName); p->schemaName, p->tableName);
if (!zSql) { if (!zSql) {
@ -3653,8 +3659,8 @@ int vec0_rowids_insert_id(vec0_vtab *p, sqlite3_value * idValue, i64 * rowid) {
} else { } else {
// IMP: V24016_08086 // IMP: V24016_08086
// IMP: V15177_32015 // IMP: V15177_32015
vtab_set_error( vtab_set_error(&p->base,
&p->base, "Error inserting id into rowids shadow table: %s", "Error inserting id into rowids shadow table: %s",
sqlite3_errmsg(sqlite3_db_handle(p->stmtRowidsInsertId))); sqlite3_errmsg(sqlite3_db_handle(p->stmtRowidsInsertId)));
} }
rc = SQLITE_ERROR; rc = SQLITE_ERROR;
@ -3678,7 +3684,8 @@ int vec0_rowids_insert_id(vec0_vtab *p, sqlite3_value * idValue, i64 * rowid) {
return rc; return rc;
} }
int vec0_rowids_update_position(vec0_vtab * p, i64 rowid, i64 chunk_rowid, i64 chunk_offset) { int vec0_rowids_update_position(vec0_vtab *p, i64 rowid, i64 chunk_rowid,
i64 chunk_offset) {
int rc = SQLITE_OK; int rc = SQLITE_OK;
if (!p->stmtRowidsUpdatePosition) { if (!p->stmtRowidsUpdatePosition) {
@ -3690,12 +3697,10 @@ const char * zSql = sqlite3_mprintf(" UPDATE " VEC0_SHADOW_ROWIDS_NAME
rc = SQLITE_NOMEM; rc = SQLITE_NOMEM;
goto cleanup; goto cleanup;
} }
rc = sqlite3_prepare_v2(p->db, zSql, -1, &p->stmtRowidsUpdatePosition, rc = sqlite3_prepare_v2(p->db, zSql, -1, &p->stmtRowidsUpdatePosition, 0);
0);
sqlite3_free((void *)zSql); sqlite3_free((void *)zSql);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
vtab_set_error(&p->base, vtab_set_error(&p->base, VEC_INTERAL_ERROR
VEC_INTERAL_ERROR
"could not initialize 'update rowids position' statement"); "could not initialize 'update rowids position' statement");
goto cleanup; goto cleanup;
} }
@ -3724,9 +3729,6 @@ const char * zSql = sqlite3_mprintf(" UPDATE " VEC0_SHADOW_ROWIDS_NAME
} }
return rc; return rc;
} }
/** /**
@ -4805,7 +4807,8 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
topk_distances[i] = tmp_topk_distances[i]; topk_distances[i] = tmp_topk_distances[i];
} }
k_used = used; k_used = used;
// blobVectors is always opened with read-only permissions, so this never fails. // blobVectors is always opened with read-only permissions, so this never
// fails.
sqlite3_blob_close(blobVectors); sqlite3_blob_close(blobVectors);
blobVectors = NULL; blobVectors = NULL;
} }
@ -4828,7 +4831,8 @@ cleanup:
sqlite3_free(bmRowids); sqlite3_free(bmRowids);
sqlite3_free(baseVectors); sqlite3_free(baseVectors);
sqlite3_free(chunk_distances); sqlite3_free(chunk_distances);
// blobVectors is always opened with read-only permissions, so this never fails. // blobVectors is always opened with read-only permissions, so this never
// fails.
sqlite3_blob_close(blobVectors); sqlite3_blob_close(blobVectors);
return rc; return rc;
} }
@ -5403,7 +5407,6 @@ int vec0Update_InsertNextAvailableStep(
goto cleanup; goto cleanup;
} }
rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "validity", rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "validity",
*chunk_rowid, 1, blobChunksValidity); *chunk_rowid, 1, blobChunksValidity);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
@ -5480,7 +5483,9 @@ done:
*blobChunksValidity = NULL; *blobChunksValidity = NULL;
*bufferChunksValidity = NULL; *bufferChunksValidity = NULL;
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
vtab_set_error(&p->base, VEC_INTERAL_ERROR "unknown error, blobChunksValidity could not be closed, please file an issue."); vtab_set_error(&p->base, VEC_INTERAL_ERROR
"unknown error, blobChunksValidity could not be closed, "
"please file an issue.");
rc = SQLITE_ERROR; rc = SQLITE_ERROR;
goto cleanup; goto cleanup;
} }
@ -5582,7 +5587,6 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid,
int rc, brc; int rc, brc;
sqlite3_blob *blobChunksRowids = NULL; sqlite3_blob *blobChunksRowids = NULL;
// mark the validity bit for this row in the chunk's validity bitmap // mark the validity bit for this row in the chunk's validity bitmap
// Get the byte offset of the bitmap // Get the byte offset of the bitmap
char unsigned bx = bufferChunksValidity[chunk_offset / CHAR_BIT]; char unsigned bx = bufferChunksValidity[chunk_offset / CHAR_BIT];
@ -5648,7 +5652,6 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid,
} }
} }
// write the new rowid to the rowids column of the _chunks table // write the new rowid to the rowids column of the _chunks table
rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "rowids", rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "rowids",
chunk_rowid, 1, &blobChunksRowids); chunk_rowid, 1, &blobChunksRowids);
@ -5826,8 +5829,9 @@ cleanup:
sqlite3_free((void *)bufferChunksValidity); sqlite3_free((void *)bufferChunksValidity);
int brc = sqlite3_blob_close(blobChunksValidity); int brc = sqlite3_blob_close(blobChunksValidity);
if ((rc == SQLITE_OK) && (brc != SQLITE_OK)) { if ((rc == SQLITE_OK) && (brc != SQLITE_OK)) {
vtab_set_error( vtab_set_error(&p->base,
&p->base, VEC_INTERAL_ERROR "unknown error, blobChunksValidity could not be closed, please file an issue"); VEC_INTERAL_ERROR "unknown error, blobChunksValidity could "
"not be closed, please file an issue");
return brc; return brc;
} }
return rc; return rc;

View file

@ -513,12 +513,14 @@ def test_vec_slice():
def test_vec_type(): def test_vec_type():
vec_type = lambda *args, a="?": db.execute(f"select vec_type({a})", args).fetchone()[0] vec_type = lambda *args, a="?": db.execute(
assert vec_type('[1]') == "float32" f"select vec_type({a})", args
).fetchone()[0]
assert vec_type("[1]") == "float32"
assert vec_type(b"\xaa\xbb\xcc\xdd") == "float32" assert vec_type(b"\xaa\xbb\xcc\xdd") == "float32"
assert vec_type('[1]', a='vec_f32(?)') == "float32" assert vec_type("[1]", a="vec_f32(?)") == "float32"
assert vec_type('[1]', a='vec_int8(?)') == "int8" assert vec_type("[1]", a="vec_int8(?)") == "int8"
assert vec_type(b"\xaa", a='vec_bit(?)') == "bit" assert vec_type(b"\xaa", a="vec_bit(?)") == "bit"
with _raises("invalid float32 vector"): with _raises("invalid float32 vector"):
vec_type(b"\xaa") vec_type(b"\xaa")
@ -697,7 +699,10 @@ def test_vec0_inserts():
db.commit() db.commit()
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_INSERT, "t1_rowids")) db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_INSERT, "t1_rowids"))
# EVIDENCE-OF: V04679_21517 vec0 INSERT failed on _rowid shadow insert raises error # EVIDENCE-OF: V04679_21517 vec0 INSERT failed on _rowid shadow insert raises error
with _raises("Internal sqlite-vec error: could not initialize 'insert rowids' statement", sqlite3.DatabaseError): with _raises(
"Internal sqlite-vec error: could not initialize 'insert rowids' statement",
sqlite3.DatabaseError,
):
db.execute("insert into t1 values (2, '[2,2,2,2]')") db.execute("insert into t1 values (2, '[2,2,2,2]')")
db.set_authorizer(None) db.set_authorizer(None)
db.rollback() db.rollback()
@ -1798,7 +1803,7 @@ def test_vec0_create_errors():
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_READ, "t1_chunks", "")) db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_READ, "t1_chunks", ""))
with _raises( with _raises(
"Internal sqlite-vec error: could not initialize 'latest chunk' statement", "Internal sqlite-vec error: could not initialize 'latest chunk' statement",
sqlite3.DatabaseError sqlite3.DatabaseError,
): ):
db.execute("create virtual table t1 using vec0(a float[1])") db.execute("create virtual table t1 using vec0(a float[1])")
db.execute("insert into t1(a) values (X'AABBCCDD')") db.execute("insert into t1(a) values (X'AABBCCDD')")
@ -1808,21 +1813,22 @@ def test_vec0_create_errors():
db.execute("BEGIN") db.execute("BEGIN")
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_INSERT, "t1_rowids")) db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_INSERT, "t1_rowids"))
with _raises( with _raises(
"Internal sqlite-vec error: could not initialize 'insert rowids id' statement", sqlite3.DatabaseError "Internal sqlite-vec error: could not initialize 'insert rowids id' statement",
sqlite3.DatabaseError,
): ):
db.execute("create virtual table t1 using vec0(a float[1])") db.execute("create virtual table t1 using vec0(a float[1])")
db.execute("insert into t1(a) values (X'AABBCCDD')") db.execute("insert into t1(a) values (X'AABBCCDD')")
db.set_authorizer(None) db.set_authorizer(None)
db.rollback() db.rollback()
db.commit() db.commit()
db.execute("BEGIN") db.execute("BEGIN")
db.set_authorizer( db.set_authorizer(
authorizer_deny_on(sqlite3.SQLITE_UPDATE, "t1_rowids", "chunk_id") authorizer_deny_on(sqlite3.SQLITE_UPDATE, "t1_rowids", "chunk_id")
) )
with _raises( with _raises(
"Internal sqlite-vec error: could not initialize 'update rowids position' statement", sqlite3.DatabaseError "Internal sqlite-vec error: could not initialize 'update rowids position' statement",
sqlite3.DatabaseError,
): ):
db.execute("create virtual table t1 using vec0(a float[1])") db.execute("create virtual table t1 using vec0(a float[1])")
db.execute("insert into t1(a) values (X'AABBCCDD')") db.execute("insert into t1(a) values (X'AABBCCDD')")
@ -2247,9 +2253,10 @@ def test_vec0_stress_small_chunks():
] ]
) )
def test_vec0_distance_metric(): def test_vec0_distance_metric():
base = "('[1, 2]'), ('[3, 4]'), ('[5, 6]')" base = "('[1, 2]'), ('[3, 4]'), ('[5, 6]')"
q = '[-1, -2]' q = "[-1, -2]"
db = connect(EXT_PATH) db = connect(EXT_PATH)
db.execute("create virtual table v1 using vec0( a float[2])") db.execute("create virtual table v1 using vec0( a float[2])")
@ -2265,26 +2272,34 @@ def test_vec0_distance_metric():
db.execute(f"insert into v4(a) values {base}") db.execute(f"insert into v4(a) values {base}")
# default (L2) # default (L2)
assert execute_all(db, "select rowid, distance from v1 where a match ? and k = 3", [q]) == [ assert execute_all(
db, "select rowid, distance from v1 where a match ? and k = 3", [q]
) == [
{"rowid": 1, "distance": 4.4721360206604}, {"rowid": 1, "distance": 4.4721360206604},
{"rowid": 2, "distance": 7.211102485656738}, {"rowid": 2, "distance": 7.211102485656738},
{"rowid": 3, "distance": 10.0}, {"rowid": 3, "distance": 10.0},
] ]
# l2 # l2
assert execute_all(db, "select rowid, distance from v2 where a match ? and k = 3", [q]) == [ assert execute_all(
db, "select rowid, distance from v2 where a match ? and k = 3", [q]
) == [
{"rowid": 1, "distance": 4.4721360206604}, {"rowid": 1, "distance": 4.4721360206604},
{"rowid": 2, "distance": 7.211102485656738}, {"rowid": 2, "distance": 7.211102485656738},
{"rowid": 3, "distance": 10.0}, {"rowid": 3, "distance": 10.0},
] ]
# l1 # l1
assert execute_all(db, "select rowid, distance from v3 where a match ? and k = 3", [q]) == [ assert execute_all(
db, "select rowid, distance from v3 where a match ? and k = 3", [q]
) == [
{"rowid": 1, "distance": 6}, {"rowid": 1, "distance": 6},
{"rowid": 2, "distance": 10}, {"rowid": 2, "distance": 10},
{"rowid": 3, "distance": 14}, {"rowid": 3, "distance": 14},
] ]
# cosine # cosine
assert execute_all(db, "select rowid, distance from v4 where a match ? and k = 3", [q]) == [ assert execute_all(
db, "select rowid, distance from v4 where a match ? and k = 3", [q]
) == [
{"rowid": 3, "distance": 1.9734171628952026}, {"rowid": 3, "distance": 1.9734171628952026},
{"rowid": 2, "distance": 1.9838699102401733}, {"rowid": 2, "distance": 1.9838699102401733},
{"rowid": 1, "distance": 2}, {"rowid": 1, "distance": 2},
@ -2293,12 +2308,13 @@ def test_vec0_distance_metric():
def test_vec0_vacuum(): def test_vec0_vacuum():
db = connect(EXT_PATH) db = connect(EXT_PATH)
db.execute('create virtual table vec_t using vec0(a float[1]);') db.execute("create virtual table vec_t using vec0(a float[1]);")
db.execute("begin") db.execute("begin")
db.execute("insert into vec_t(a) values (X'AABBCCDD')") db.execute("insert into vec_t(a) values (X'AABBCCDD')")
db.commit() db.commit()
db.execute("vacuum") db.execute("vacuum")
def rowids_value(buffer: bytearray) -> List[int]: def rowids_value(buffer: bytearray) -> List[int]:
assert (len(buffer) % 8) == 0 assert (len(buffer) % 8) == 0
n = int(len(buffer) / 8) n = int(len(buffer) / 8)