This commit is contained in:
Alex Garcia 2024-06-28 10:51:59 -07:00
parent b923c596df
commit 2fdd760dd1
2 changed files with 91 additions and 53 deletions

View file

@ -613,7 +613,8 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector,
} }
*pzErr = sqlite3_mprintf( *pzErr = sqlite3_mprintf(
"Input must have type BLOB (compact format) or TEXT (JSON), found %s", type_name(value_type)); "Input must have type BLOB (compact format) or TEXT (JSON), found %s",
type_name(value_type));
return SQLITE_ERROR; return SQLITE_ERROR;
} }
@ -799,7 +800,6 @@ int vector_from_value(sqlite3_value *value, void **vector, size_t *dimensions,
return SQLITE_ERROR; return SQLITE_ERROR;
} }
int ensure_vector_match(sqlite3_value *aValue, sqlite3_value *bValue, void **a, int ensure_vector_match(sqlite3_value *aValue, sqlite3_value *bValue, void **a,
void **b, enum VectorElementType *element_type, void **b, enum VectorElementType *element_type,
size_t *dimensions, vector_cleanup *outACleanup, size_t *dimensions, vector_cleanup *outACleanup,
@ -3044,7 +3044,7 @@ int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, i64 rowid,
*out = sqlite3_value_dup(value); *out = sqlite3_value_dup(value);
rc = SQLITE_OK; rc = SQLITE_OK;
cleanup: cleanup:
sqlite3_reset(pVtab->stmtRowidsGetChunkPosition); sqlite3_reset(pVtab->stmtRowidsGetChunkPosition);
sqlite3_clear_bindings(pVtab->stmtRowidsGetChunkPosition); sqlite3_clear_bindings(pVtab->stmtRowidsGetChunkPosition);
return rc; return rc;
@ -4471,7 +4471,7 @@ static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur,
} }
// TODO only have 1st vector data // TODO only have 1st vector data
if (vec0_column_idx_is_vector(pVtab, i)) { if (vec0_column_idx_is_vector(pVtab, i)) {
if(sqlite3_vtab_nochange(context)) { if (sqlite3_vtab_nochange(context)) {
sqlite3_result_null(context); sqlite3_result_null(context);
return SQLITE_OK; return SQLITE_OK;
} }
@ -5231,17 +5231,17 @@ int vec0Update_Delete_DeleteRowids(vec0_vtab *p, i64 rowid) {
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
sqlite3_free(zSql); sqlite3_free(zSql);
if(rc != SQLITE_OK ) { if (rc != SQLITE_OK) {
goto cleanup; goto cleanup;
} }
sqlite3_bind_int64(stmt, 1, rowid); sqlite3_bind_int64(stmt, 1, rowid);
rc = sqlite3_step(stmt); rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE) { if (rc != SQLITE_DONE) {
goto cleanup; goto cleanup;
} }
rc = SQLITE_OK; rc = SQLITE_OK;
cleanup: cleanup:
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
return rc; return rc;
} }
@ -5374,7 +5374,7 @@ int vec0Update_UpdateOnRowid(sqlite3_vtab *pVTab, int argc,
// 1. get chunk_id and chunk_offset from _rowids // 1. get chunk_id and chunk_offset from _rowids
rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset); rc = vec0_get_chunk_position(p, rowid, &chunk_id, &chunk_offset);
if(rc != SQLITE_OK) { if (rc != SQLITE_OK) {
return rc; return rc;
} }
@ -5392,13 +5392,13 @@ int vec0Update_UpdateOnRowid(sqlite3_vtab *pVTab, int argc,
// but subtypes don't appear to survive xColumn -> xUpdate, it's always 0. // but subtypes don't appear to survive xColumn -> xUpdate, it's always 0.
// So for now, we'll just use NULL and warn people to not SET X = NULL // So for now, we'll just use NULL and warn people to not SET X = NULL
// in the docs. // in the docs.
if(sqlite3_value_type(valueVector) == SQLITE_NULL) { if (sqlite3_value_type(valueVector) == SQLITE_NULL) {
continue; continue;
} }
rc = vec0Update_UpdateVectorColumn(p, chunk_id, chunk_offset, i, rc = vec0Update_UpdateVectorColumn(p, chunk_id, chunk_offset, i,
valueVector); valueVector);
if(rc != SQLITE_OK){ if (rc != SQLITE_OK) {
return SQLITE_ERROR; return SQLITE_ERROR;
} }
} }
@ -5488,7 +5488,7 @@ static void vec_static_blob_from_raw(sqlite3_context *context, int argc,
sqlite3_value **argv) { sqlite3_value **argv) {
struct static_blob_definition *p; struct static_blob_definition *p;
p = sqlite3_malloc(sizeof(*p)); p = sqlite3_malloc(sizeof(*p));
if(!p) { if (!p) {
sqlite3_result_error_nomem(context); sqlite3_result_error_nomem(context);
return; return;
} }
@ -6179,7 +6179,7 @@ int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg,
#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL #ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL
vec_static_blob_data *static_blob_data; vec_static_blob_data *static_blob_data;
static_blob_data = sqlite3_malloc(sizeof(*static_blob_data)); static_blob_data = sqlite3_malloc(sizeof(*static_blob_data));
if(!static_blob_data) { if (!static_blob_data) {
return SQLITE_NOMEM; return SQLITE_NOMEM;
} }
memset(static_blob_data, 0, sizeof(*static_blob_data)); memset(static_blob_data, 0, sizeof(*static_blob_data));

View file

@ -538,18 +538,18 @@ def test_vec0_inserts():
"ccc": bitmap_full(128), "ccc": bitmap_full(128),
} }
] ]
#db.execute( # db.execute(
# "update t set aaa = ? where rowid = ?", # "update t set aaa = ? where rowid = ?",
# [np.full((128,), 0.00011, dtype="float32"), 1], # [np.full((128,), 0.00011, dtype="float32"), 1],
#) # )
#assert execute_all(db, "select * from t") == [ # assert execute_all(db, "select * from t") == [
# { # {
# "rowid": 1, # "rowid": 1,
# "aaa": _f32([0.00011] * 128), # "aaa": _f32([0.00011] * 128),
# "bbb": _int8([4] * 128), # "bbb": _int8([4] * 128),
# "ccc": bitmap_full(128), # "ccc": bitmap_full(128),
# } # }
#] # ]
db.execute("create virtual table t1 using vec0(aaa float[4], chunk_size=8)") db.execute("create virtual table t1 using vec0(aaa float[4], chunk_size=8)")
db.execute( db.execute(
@ -998,7 +998,7 @@ def test_vec0_updates():
{ {
"x": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]", "x": "[.1, .1, .1, .1, -.1, -.1, -.1, -.1]",
"y": "[-.2, .2, .2, .2, .2, .2, -.2, .2]", "y": "[-.2, .2, .2, .2, .2, .2, -.2, .2]",
"z": "[.3, .3, .3, .3, .3, .3, .3, .3]" "z": "[.3, .3, .3, .3, .3, .3, .3, .3]",
}, },
) )
assert execute_all(db, "select * from t3") == [ assert execute_all(db, "select * from t3") == [
@ -1017,12 +1017,23 @@ def test_vec0_updates():
{ {
"rowid": 3, "rowid": 3,
"aaa": _f32([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]), "aaa": _f32([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]),
"bbb": _int8([37, 37, 37, 37, 37, 37, 37, 37, ]), "bbb": _int8(
[
37,
37,
37,
37,
37,
37,
37,
37,
]
),
"ccc": bitmap("11111111"), "ccc": bitmap("11111111"),
}, },
] ]
db.execute("UPDATE t3 SET aaa = ? WHERE rowid = 1", ['[.9,.9,.9,.9,.9,.9,.9,.9]']) db.execute("UPDATE t3 SET aaa = ? WHERE rowid = 1", ["[.9,.9,.9,.9,.9,.9,.9,.9]"])
assert execute_all(db, "select * from t3") == [ assert execute_all(db, "select * from t3") == [
{ {
"rowid": 1, "rowid": 1,
@ -1039,39 +1050,64 @@ def test_vec0_updates():
{ {
"rowid": 3, "rowid": 3,
"aaa": _f32([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]), "aaa": _f32([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]),
"bbb": _int8([37, 37, 37, 37, 37, 37, 37, 37, ]), "bbb": _int8(
[
37,
37,
37,
37,
37,
37,
37,
37,
]
),
"ccc": bitmap("11111111"), "ccc": bitmap("11111111"),
}, },
] ]
# EVIDENCE-OF: V15203_32042 vec0 UPDATE validates vector # EVIDENCE-OF: V15203_32042 vec0 UPDATE validates vector
with _raises('Updated vector for the "aaa" column is invalid: invalid float32 vector BLOB length. Must be divisible by 4, found 1'): with _raises(
'Updated vector for the "aaa" column is invalid: invalid float32 vector BLOB length. Must be divisible by 4, found 1'
):
db.execute("UPDATE t3 SET aaa = X'AB' WHERE rowid = 1") db.execute("UPDATE t3 SET aaa = X'AB' WHERE rowid = 1")
# EVIDENCE-OF: V25739_09810 vec0 UPDATE validates dimension length # EVIDENCE-OF: V25739_09810 vec0 UPDATE validates dimension length
with _raises('Dimension mismatch for new updated vector for the "aaa" column. Expected 8 dimensions but received 1.'): with _raises(
'Dimension mismatch for new updated vector for the "aaa" column. Expected 8 dimensions but received 1.'
):
db.execute("UPDATE t3 SET aaa = vec_bit(X'AABBCCDD') WHERE rowid = 1") db.execute("UPDATE t3 SET aaa = vec_bit(X'AABBCCDD') WHERE rowid = 1")
# EVIDENCE-OF: V03643_20481 vec0 UPDATE validates vector column type # EVIDENCE-OF: V03643_20481 vec0 UPDATE validates vector column type
with _raises('Updated vector for the "bbb" column is expected to be of type int8, but a float32 vector was provided.'): with _raises(
'Updated vector for the "bbb" column is expected to be of type int8, but a float32 vector was provided.'
):
db.execute("UPDATE t3 SET bbb = X'ABABABAB' WHERE rowid = 1") db.execute("UPDATE t3 SET bbb = X'ABABABAB' WHERE rowid = 1")
db.execute("CREATE VIRTUAL TABLE t2 USING vec0(a float[2], b float[2])") db.execute("CREATE VIRTUAL TABLE t2 USING vec0(a float[2], b float[2])")
db.execute("INSERT INTO t2(rowid, a, b) VALUES (1, '[.1, .1]', '[.2, .2]')") db.execute("INSERT INTO t2(rowid, a, b) VALUES (1, '[.1, .1]', '[.2, .2]')")
assert execute_all(db, "select * from t2") == [{ assert execute_all(db, "select * from t2") == [
'rowid': 1, {
'a': _f32([.1, .1]), "rowid": 1,
'b': _f32([.2, .2]), "a": _f32([0.1, 0.1]),
}] "b": _f32([0.2, 0.2]),
}
]
# sanity check: the 1st column UPDATE "works", but since the 2nd one fails, # sanity check: the 1st column UPDATE "works", but since the 2nd one fails,
# then aaa should remain unchanged. # then aaa should remain unchanged.
with _raises('Dimension mismatch for new updated vector for the "b" column. Expected 2 dimensions but received 3.'): with _raises(
db.execute("UPDATE t2 SET a = '[.11, .11]', b = '[.22, .22, .22]' WHERE rowid = 1") 'Dimension mismatch for new updated vector for the "b" column. Expected 2 dimensions but received 3.'
assert execute_all(db, "select * from t2") == [{ ):
'rowid': 1, db.execute(
'a': _f32([.1, .1]), "UPDATE t2 SET a = '[.11, .11]', b = '[.22, .22, .22]' WHERE rowid = 1"
'b': _f32([.2, .2]), )
}] assert execute_all(db, "select * from t2") == [
{
"rowid": 1,
"a": _f32([0.1, 0.1]),
"b": _f32([0.2, 0.2]),
}
]
# TODO: set UPDATEs on int8/bit columns # TODO: set UPDATEs on int8/bit columns
# db.execute("UPDATE t3 SET ccc = vec_bit(?) WHERE rowid = 3", [bitmap('01010101')]) # db.execute("UPDATE t3 SET ccc = vec_bit(?) WHERE rowid = 3", [bitmap('01010101')])
@ -1108,7 +1144,8 @@ def test_vec0_text_pk():
); );
""" """
) )
db.executemany("INSERT INTO t VALUES (:t_id, :aaa, :bbb)", db.executemany(
"INSERT INTO t VALUES (:t_id, :aaa, :bbb)",
[ [
{ {
"t_id": "t_1", "t_id": "t_1",
@ -1129,6 +1166,7 @@ def test_vec0_text_pk():
) )
assert execute_all(db, "select * from t") == [] assert execute_all(db, "select * from t") == []
def authorizer_deny_on(operation, x1, x2=None): def authorizer_deny_on(operation, x1, x2=None):
def _auth(op, p1, p2, p3, p4): def _auth(op, p1, p2, p3, p4):
if op == operation and p1 == x1 and p2 == x2: if op == operation and p1 == x1 and p2 == x2: