From da29ace630a09eab057ec0f5761780862e4f4689 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Sun, 17 Nov 2024 14:56:31 -0800 Subject: [PATCH] updates --- ARCHITECTURE.md | 38 ++++++++++ sqlite-vec.c | 99 +++++++++++++++++++++++--- tests/__snapshots__/test-metadata.ambr | 26 +++++-- tests/afbd/test-afbd.py | 6 +- tests/test-metadata.py | 7 ++ 5 files changed, 158 insertions(+), 18 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 13660a1..9bc40ab 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -1,5 +1,43 @@ ## `vec0` +### Shadow Tables + +#### `xyz_chunks` + +- `chunk_id INTEGER` +- `size INTEGER` +- `validity BLOB` +- `rowids BLOB` + + +#### `xyz_rowids` + +- `rowid INTEGER` +- `id` +- `chunk_id INTEGER` +- `chunk_offset INTEGER` + +#### `xyz_vector_chunksNN` + +- `rowid INTEGER` +- `vector BLOB` + +#### `xyz_auxiliary` + +- `rowid INTEGER` +- `valueNN [type]` + +#### `xyz_metadata_chunksNN` + +- `rowid INTEGER` +- `data BLOB` + + +#### `xyz_metadata_text_data_00` + +- `rowid INTEGER` +- `data TEXT` + ### idxStr The `vec0` idxStr is a string composed of single "header" character and 0 or diff --git a/sqlite-vec.c b/sqlite-vec.c index 2687d79..7e49fe7 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -5719,6 +5719,41 @@ int min_idx(const f32 *distances, i32 n, u8 *candidates, i32 *out, i32 k, return SQLITE_OK; } +int vec0_get_metadata_text_long_value( + vec0_vtab * p, + sqlite3_stmt ** stmt, + int metadata_idx, + i64 rowid, + int *n, + char ** s) { + int rc; + if(!(*stmt)) { + const char * zSql = sqlite3_mprintf("select data from " VEC0_SHADOW_METADATA_TEXT_DATA_NAME " where rowid = ?", p->schemaName, p->tableName, metadata_idx); + if(!zSql) { + rc = SQLITE_NOMEM; + goto done; + } + rc = sqlite3_prepare_v2(p->db, zSql, -1, stmt, NULL); + sqlite3_free( (void *) zSql); + if(rc != SQLITE_OK) { + goto done; + } + } + + sqlite3_reset(*stmt); + sqlite3_bind_int64(*stmt, 1, rowid); + rc = sqlite3_step(*stmt); + if(rc != SQLITE_ROW) { + rc = SQLITE_ERROR; + goto done; + } + *s = (char *) sqlite3_column_text(*stmt, 0); + *n = sqlite3_column_bytes(*stmt, 0); + rc = SQLITE_OK; + done: + return rc; +} + /** * @brief Crete at "iterator" (sqlite3_stmt) of chunks with the given constraints * @@ -5845,6 +5880,28 @@ int vec0_set_metadata_filter_bitmap( if(rc != SQLITE_OK) { return rc; } + // TODO: only on text columns + sqlite3_blob * rowidsBlob; + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowChunksName, "rowids", chunk_rowid, 0, &rowidsBlob); + if(rc != SQLITE_OK) { + return rc; + } + assert(sqlite3_blob_bytes(rowidsBlob) % sizeof(i64) == 0); + assert((sqlite3_blob_bytes(rowidsBlob) / sizeof(i64)) == size); + i64 * rowids; + rowids = sqlite3_malloc(sqlite3_blob_bytes(rowidsBlob)); + if(!rowids) { + sqlite3_blob_close(rowidsBlob); + return SQLITE_NOMEM; + } + + rc = sqlite3_blob_read(rowidsBlob, rowids, sqlite3_blob_bytes(rowidsBlob), 0); + if(rc != SQLITE_OK) { + sqlite3_blob_close(rowidsBlob); + return rc; + } + sqlite3_blob_close(rowidsBlob); + vec0_metadata_column_kind kind = p->metadata_columns[metadata_idx].kind; int szMatch = 0; int blobSize = sqlite3_blob_bytes(blob); @@ -5951,25 +6008,41 @@ int vec0_set_metadata_filter_bitmap( break; } case VEC0_METADATA_COLUMN_KIND_TEXT: { - // TODO: handle large strings. For now just raise a generic error - for(int i = 0; i < size; i++) { - u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; - int n = ((int*) view)[0]; - if(n > 12) { - rc = SQLITE_ERROR; - goto done; - } - } const char * target = (const char *) sqlite3_value_text(value); + int targetn = sqlite3_value_bytes(value); switch(op) { case VEC0_METADATA_OPERATOR_EQ: { + sqlite3_stmt * stmt = NULL; for(int i = 0; i < size; i++) { u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; int n = ((int*) view)[0]; char * s = (char *) &view[4]; - bitmap_set(b, i, strncmp(s, target, n) == 0); + if(n != targetn) { + bitmap_set(b, i, 0); + continue; + } + int prefix_cmp = strncmp(s, target, min(n, 12)); + if(n <= 12) { + bitmap_set(b, i, prefix_cmp == 0); + } + // if the prefix doesnt match, the rest of the string wont match + else if(prefix_cmp) { + bitmap_set(b, i, 0); + } + // need to consult + else { + char *slong; + int slongn; + rc = vec0_get_metadata_text_long_value(p, &stmt, metadata_idx, rowids[i], &slongn, &slong); + if(rc != SQLITE_OK) { + goto done; + } + assert(n == slongn); + bitmap_set(b, i, strncmp(slong, target, n) == 0); + } } + sqlite3_finalize(stmt); break; } case VEC0_METADATA_OPERATOR_NE: { @@ -5977,6 +6050,7 @@ int vec0_set_metadata_filter_bitmap( u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; int n = ((int*) view)[0]; char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ bitmap_set(b, i, strncmp(s, target, n) != 0); } break; @@ -5986,6 +6060,7 @@ int vec0_set_metadata_filter_bitmap( u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; int n = ((int*) view)[0]; char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ bitmap_set(b, i, strncmp(s, target, n) > 0); } break; @@ -5995,6 +6070,7 @@ int vec0_set_metadata_filter_bitmap( u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; int n = ((int*) view)[0]; char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ bitmap_set(b, i, strncmp(s, target, n) >= 0); } break; @@ -6004,6 +6080,7 @@ int vec0_set_metadata_filter_bitmap( u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; int n = ((int*) view)[0]; char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ bitmap_set(b, i, strncmp(s, target, n) <= 0); } break; @@ -6013,6 +6090,7 @@ int vec0_set_metadata_filter_bitmap( u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH]; int n = ((int*) view)[0]; char * s = (char *) &view[4]; + if(n > 12) {rc = SQLITE_ERROR;goto done;} /* TODO */ bitmap_set(b, i, strncmp(s, target, n) < 0); } break; @@ -6024,6 +6102,7 @@ int vec0_set_metadata_filter_bitmap( } done: sqlite3_free(buffer); + sqlite3_free(rowids); return rc; } diff --git a/tests/__snapshots__/test-metadata.ambr b/tests/__snapshots__/test-metadata.ambr index f4f2379..7213724 100644 --- a/tests/__snapshots__/test-metadata.ambr +++ b/tests/__snapshots__/test-metadata.ambr @@ -617,10 +617,28 @@ ]), }) # --- +# name: test_long_text_knn[knn-eq-short] + OrderedDict({ + 'sql': "select * from v where vector match X'11111111' and k = 5 and name = ?", + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaaaaaaaaaa', + }), + ]), + }) +# --- # name: test_long_text_knn[knn-eq-true] - dict({ - 'error': 'OperationalError', - 'message': 'Could not filter metadata fields', + OrderedDict({ + 'sql': "select * from v where vector match X'11111111' and k = 5 and name = ?", + 'rows': list([ + OrderedDict({ + 'rowid': 3, + 'vector': b'\x11\x11\x11\x11', + 'name': 'aaaaaaaaaaaa_aaa', + }), + ]), }) # --- # name: test_long_text_updates @@ -1416,7 +1434,6 @@ # name: test_stress.1 OrderedDict({ 'sql': ''' - select movie_id, title, @@ -1431,7 +1448,6 @@ and num_reviews between 100 and 500 and mean_rating > 3.5 and k = 5; - ''', 'rows': list([ OrderedDict({ diff --git a/tests/afbd/test-afbd.py b/tests/afbd/test-afbd.py index 098af40..31f0a86 100644 --- a/tests/afbd/test-afbd.py +++ b/tests/afbd/test-afbd.py @@ -119,7 +119,7 @@ def tests_command(file_path): tests = [ json.loads(row["data"]) - for row in db.execute("select data from tests limit 2000").fetchall() + for row in db.execute("select data from tests").fetchall() ] num_or_skips = 0 @@ -179,8 +179,8 @@ def tests_command(file_path): == diff["values_changed"][bkey]["new_value"] ) elif len(keys_changed) == 1: - v = int(akey.lstrip("root[").rstrip("]")) - assert v == len(expected_closest_ids) + v = int(keys_changed[0].lstrip("root[").rstrip("]")) + assert (v + 1) == len(expected_closest_ids) else: raise Exception("fuck") num_1off_errors += 1 diff --git a/tests/test-metadata.py b/tests/test-metadata.py index bef088e..3e2fb95 100644 --- a/tests/test-metadata.py +++ b/tests/test-metadata.py @@ -148,10 +148,17 @@ def test_long_text_knn(db, snapshot): "create virtual table v using vec0(vector float[1], name text, chunk_size=8)" ) INSERT = "insert into v(vector, name) values (?, ?)" + exec(db, INSERT, [b"\x11\x11\x11\x11", "aaaaaaaaaaaa"]) + exec(db, INSERT, [b"\x11\x11\x11\x11", "bbbbbbbbbbbb"]) exec(db, INSERT, [b"\x11\x11\x11\x11", "aaaaaaaaaaaa_aaa"]) exec(db, INSERT, [b"\x11\x11\x11\x11", "aaaaaaaaaaaa_bbb"]) exec(db, INSERT, [b"\x11\x11\x11\x11", "aaaaaaaaaaaa_ccc"]) + assert exec( + db, + "select * from v where vector match X'11111111' and k = 5 and name = ?", + ["aaaaaaaaaaaa"], + ) == snapshot(name="knn-eq-short") assert exec( db, "select * from v where vector match X'11111111' and k = 5 and name = ?",