fixes and tests

This commit is contained in:
Alex Garcia 2024-11-19 21:46:50 -08:00
parent 4039328eda
commit a657b3a216
4 changed files with 121 additions and 32 deletions

View file

@ -5508,7 +5508,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
switch(p->metadata_columns[metadata_idx].kind) { switch(p->metadata_columns[metadata_idx].kind) {
case VEC0_METADATA_COLUMN_KIND_FLOAT: case VEC0_METADATA_COLUMN_KIND_FLOAT:
case VEC0_METADATA_COLUMN_KIND_BOOLEAN: { case VEC0_METADATA_COLUMN_KIND_BOOLEAN: {
// IMP: TODO // IMP: V15248_32086
rc = SQLITE_ERROR; rc = SQLITE_ERROR;
vtab_set_error(pVTab, "'xxx in (...)' is only available on INTEGER or TEXT metadata columns."); vtab_set_error(pVTab, "'xxx in (...)' is only available on INTEGER or TEXT metadata columns.");
goto done; goto done;
@ -6142,7 +6142,8 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void *
} }
} }
if(metadataInIdx < 0) { if(metadataInIdx < 0) {
abort(); // TODO rc = SQLITE_ERROR;
goto done;
} }
struct Vec0MetadataIn * metadataIn = &((struct Vec0MetadataIn *) aMetadataIn->z)[metadataInIdx]; struct Vec0MetadataIn * metadataIn = &((struct Vec0MetadataIn *) aMetadataIn->z)[metadataInIdx];
@ -6313,7 +6314,8 @@ int vec0_set_metadata_filter_bitmap(
} }
} }
if(metadataInIdx < 0) { if(metadataInIdx < 0) {
abort(); // TODO rc = SQLITE_ERROR;
goto done;
} }
struct Vec0MetadataIn * metadataIn = &((struct Vec0MetadataIn *) aMetadataIn->z)[metadataInIdx]; struct Vec0MetadataIn * metadataIn = &((struct Vec0MetadataIn *) aMetadataIn->z)[metadataInIdx];
struct Array * aTarget = &(metadataIn->array); struct Array * aTarget = &(metadataIn->array);
@ -6916,7 +6918,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
} }
if (rc != SQLITE_DONE) { if (rc != SQLITE_DONE) {
vtab_set_error(&p->base, "fuck"); // TODO vtab_set_error(&p->base, "Error fetching next value in `x in (...)` integer expression");
goto cleanup; goto cleanup;
} }
@ -6933,7 +6935,6 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
int n = sqlite3_value_bytes(entry); int n = sqlite3_value_bytes(entry);
struct Vec0MetadataInTextEntry entry; struct Vec0MetadataInTextEntry entry;
// TODO if this exits early, does it get properly cleaned up
entry.zString = sqlite3_mprintf("%.*s", n, s); entry.zString = sqlite3_mprintf("%.*s", n, s);
if(!entry.zString) { if(!entry.zString) {
rc = SQLITE_NOMEM; rc = SQLITE_NOMEM;
@ -6947,20 +6948,21 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
} }
if (rc != SQLITE_DONE) { if (rc != SQLITE_DONE) {
vtab_set_error(&p->base, "fuck"); // TODO vtab_set_error(&p->base, "Error fetching next value in `x in (...)` text expression");
goto cleanup; goto cleanup;
} }
break; break;
} }
default: { default: {
abort(); vtab_set_error(&p->base, "Internal sqlite-vec error");
goto cleanup;
} }
} }
rc = array_append(aMetadataIn, &item); rc = array_append(aMetadataIn, &item);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
abort(); // TODO goto cleanup;
} }
} }
#endif #endif
@ -7291,8 +7293,18 @@ static int vec0Column_fullscan(vec0_vtab *pVtab, vec0_cursor *pCur,
int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i); int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i);
int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
// TODO handle // IMP: V15466_32305
sqlite3_result_error(context, "fuck", -1); const char * zErr = sqlite3_mprintf(
"Could not extract metadata value for column %.*s at rowid %lld",
pVtab->metadata_columns[metadata_idx].name_length,
pVtab->metadata_columns[metadata_idx].name, rowid
);
if(zErr) {
sqlite3_result_error(context, zErr, -1);
sqlite3_free((void *) zErr);
}else {
sqlite3_result_error_nomem(context);
}
} }
} }
return SQLITE_OK; return SQLITE_OK;
@ -7364,8 +7376,17 @@ static int vec0Column_point(vec0_vtab *pVtab, vec0_cursor *pCur,
int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i); int metadata_idx = vec0_column_idx_to_metadata_idx(pVtab, i);
int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
// TODO handle const char * zErr = sqlite3_mprintf(
sqlite3_result_error(context, "fuck", -1); "Could not extract metadata value for column %.*s at rowid %lld",
pVtab->metadata_columns[metadata_idx].name_length,
pVtab->metadata_columns[metadata_idx].name, rowid
);
if(zErr) {
sqlite3_result_error(context, zErr, -1);
sqlite3_free((void *) zErr);
}else {
sqlite3_result_error_nomem(context);
}
} }
} }
@ -7432,8 +7453,17 @@ static int vec0Column_knn(vec0_vtab *pVtab, vec0_cursor *pCur,
i64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx]; i64 rowid = pCur->knn_data->rowids[pCur->knn_data->current_idx];
int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context); int rc = vec0_result_metadata_value_for_rowid(pVtab, rowid, metadata_idx, context);
if(rc != SQLITE_OK) { if(rc != SQLITE_OK) {
// TODO: handle const char * zErr = sqlite3_mprintf(
sqlite3_result_error(context, "fuck", -1); "Could not extract metadata value for column %.*s at rowid %lld",
pVtab->metadata_columns[metadata_idx].name_length,
pVtab->metadata_columns[metadata_idx].name, rowid
);
if(zErr) {
sqlite3_result_error(context, zErr, -1);
sqlite3_free((void *) zErr);
}else {
sqlite3_result_error_nomem(context);
}
} }
} }
@ -8199,6 +8229,7 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
); );
goto cleanup; goto cleanup;
} }
// first 1 is for 1-based indexing on sqlite3_bind_*, second 1 is to account for initial rowid parameter
sqlite3_bind_value(stmt, 1 + 1 + auxiliary_key_idx, v); sqlite3_bind_value(stmt, 1 + 1 + auxiliary_key_idx, v);
} }

View file

@ -363,6 +363,24 @@
]), ]),
}) })
# --- # ---
# name: test_errors
OrderedDict({
'sql': 'select * from v',
'rows': list([
OrderedDict({
'rowid': 1,
'vector': b'\x00\x00\x80?',
't': 'aaaaaaaaaaaax',
}),
]),
})
# ---
# name: test_errors.1
dict({
'error': 'OperationalError',
'message': 'Could not extract metadata value for column t at rowid 1',
})
# ---
# name: test_idxstr # name: test_idxstr
OrderedDict({ OrderedDict({
'sql': "select * from vec_movies where synopsis_embedding match '' and k = 0 and is_favorited = true", 'sql': "select * from vec_movies where synopsis_embedding match '' and k = 0 and is_favorited = true",

View file

@ -132,13 +132,11 @@ def tests_command(file_path):
conditions = test["conditions"] conditions = test["conditions"]
expected_closest_ids = test["closest_ids"] expected_closest_ids = test["closest_ids"]
expected_closest_scores = test["closest_scores"] expected_closest_scores = test["closest_scores"]
if "or" in conditions:
num_or_skips += 1
continue
sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?" sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
params = [serialize_float32(query), len(expected_closest_ids)] params = [serialize_float32(query), len(expected_closest_ids)]
if "and" in conditions:
for condition in conditions["and"]: for condition in conditions["and"]:
assert len(condition.keys()) == 1 assert len(condition.keys()) == 1
column = list(condition.keys())[0] column = list(condition.keys())[0]
@ -154,7 +152,25 @@ def tests_command(file_path):
params.append(condition[column]["range"]["lt"]) params.append(condition[column]["range"]["lt"])
else: else:
raise Exception(f"Unknown condition type: {condition_type}") raise Exception(f"Unknown condition type: {condition_type}")
elif "or" in conditions:
column = list(conditions["or"][0].keys())[0]
condition_type = list(conditions["or"][0][column].keys())[0]
assert condition_type == "match"
sql += f" and {column} in ("
for idx, condition in enumerate(conditions["or"]):
if condition_type == "match":
value = condition[column]["match"]["value"]
if idx != 0:
sql += ","
sql += "?"
params.append(value)
elif condition_type == "range":
breakpoint()
else:
raise Exception(f"Unknown condition type: {condition_type}")
sql += ")"
# print(sql, params[1:])
rows = db.execute(sql, params).fetchall() rows = db.execute(sql, params).fetchall()
actual_closest_ids = [row["rowid"] for row in rows] actual_closest_ids = [row["rowid"] for row in rows]
matches = expected_closest_ids == actual_closest_ids matches = expected_closest_ids == actual_closest_ids

View file

@ -309,6 +309,8 @@ def test_vtab_in(db, snapshot):
(8, "[8]", 555, "zzzz", 0, 1.1), (8, "[8]", 555, "zzzz", 0, 1.1),
], ],
) )
# EVIDENCE-OF: V15248_32086
assert exec( assert exec(
db, "select * from v where vector match '[0]' and k = 8 and b in (1, 0)" db, "select * from v where vector match '[0]' and k = 8 and b in (1, 0)"
) == snapshot(name="block-bool") ) == snapshot(name="block-bool")
@ -570,6 +572,28 @@ def test_stress(db, snapshot):
) == snapshot(name="bool-other-op") ) == snapshot(name="bool-other-op")
def test_errors(db, snapshot):
db.execute("create virtual table v using vec0(vector float[1], t text)")
db.execute("insert into v(vector, t) values ('[1]', 'aaaaaaaaaaaax')")
assert exec(db, "select * from v") == snapshot()
# EVIDENCE-OF: V15466_32305
db.set_authorizer(
authorizer_deny_on(sqlite3.SQLITE_READ, "v_metadata_text_data_00", "data")
)
assert exec(db, "select * from v") == snapshot()
def authorizer_deny_on(operation, x1, x2=None):
def _auth(op, p1, p2, p3, p4):
if op == operation and p1 == x1 and p2 == x2:
return sqlite3.SQLITE_DENY
return sqlite3.SQLITE_OK
return _auth
def exec(db, sql, parameters=[]): def exec(db, sql, parameters=[]):
try: try:
rows = db.execute(sql, parameters).fetchall() rows = db.execute(sql, parameters).fetchall()