diff --git a/TODO b/TODO index 179b290..c42ab8f 100644 --- a/TODO +++ b/TODO @@ -7,7 +7,6 @@ # auxiliary columns -- enforce column types, ie STRICT? - in xBestIndex, ensure there are no constraints on any aux column - DELETE and UPDATE support - later: diff --git a/sqlite-vec.c b/sqlite-vec.c index 172a270..1acf4d2 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -2074,7 +2074,7 @@ int vec0_parse_auxiliary_column_definition(const char *source, int source_length 0 || sqlite3_strnicmp(token.start, "double", token.end - token.start) == 0) { - column_type = SQLITE_INTEGER; + column_type = SQLITE_FLOAT; } else if (sqlite3_strnicmp(token.start, "blob", token.end - token.start) ==0) { column_type = SQLITE_BLOB; } else { @@ -6736,6 +6736,20 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, } int auxiliary_key_idx = p->user_column_idxs[i]; sqlite3_value * v = argv[2+VEC0_COLUMN_USERN_START + i]; + int v_type = sqlite3_value_type(v); + if(v_type != SQLITE_NULL && (v_type != p->auxiliary_columns[auxiliary_key_idx].type)) { + sqlite3_finalize(stmt); + rc = SQLITE_ERROR; + vtab_set_error( + pVTab, + "Auxiliary column type mismatch: The auxiliary column %.*s has type %s, but %s was provided.", + p->auxiliary_columns[auxiliary_key_idx].name_length, + p->auxiliary_columns[auxiliary_key_idx].name, + type_name(p->auxiliary_columns[auxiliary_key_idx].type), + type_name(v_type) + ); + goto cleanup; + } sqlite3_bind_value(stmt, 1 + auxiliary_key_idx, v); } diff --git a/tests/__snapshots__/test-auxiliary.ambr b/tests/__snapshots__/test-auxiliary.ambr index e706f6a..8cba12a 100644 --- a/tests/__snapshots__/test-auxiliary.ambr +++ b/tests/__snapshots__/test-auxiliary.ambr @@ -161,3 +161,86 @@ ]), }) # --- +# name: test_types + OrderedDict({ + 'sql': 'select * from v', + 'rows': list([ + ]), + }) +# --- +# name: test_types.1 + OrderedDict({ + 'sql': 'insert into v(vector, aux_int, aux_float, aux_text, aux_blob) values (?, ?, ?, ?, ?)', + 'rows': list([ + ]), + }) +# --- +# name: test_types.2 + OrderedDict({ + 'sql': 'select * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x11\x11\x11\x11', + 'aux_int': 1, + 'aux_float': 1.22, + 'aux_text': 'text', + 'aux_blob': b'blob', + }), + ]), + }) +# --- +# name: test_types.3 + dict({ + 'error': 'OperationalError', + 'message': 'Auxiliary column type mismatch: The auxiliary column aux_int has type INTEGER, but TEXT was provided.', + }) +# --- +# name: test_types.4 + dict({ + 'error': 'OperationalError', + 'message': 'Auxiliary column type mismatch: The auxiliary column aux_float has type FLOAT, but TEXT was provided.', + }) +# --- +# name: test_types.5 + dict({ + 'error': 'OperationalError', + 'message': 'Auxiliary column type mismatch: The auxiliary column aux_text has type TEXT, but INTEGER was provided.', + }) +# --- +# name: test_types.6 + dict({ + 'error': 'OperationalError', + 'message': 'Auxiliary column type mismatch: The auxiliary column aux_blob has type BLOB, but INTEGER was provided.', + }) +# --- +# name: test_types.7 + OrderedDict({ + 'sql': 'insert into v(vector, aux_int, aux_float, aux_text, aux_blob) values (?, ?, ?, ?, ?)', + 'rows': list([ + ]), + }) +# --- +# name: test_types.8 + OrderedDict({ + 'sql': 'select * from v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'vector': b'\x11\x11\x11\x11', + 'aux_int': 1, + 'aux_float': 1.22, + 'aux_text': 'text', + 'aux_blob': b'blob', + }), + OrderedDict({ + 'rowid': 2, + 'vector': b'\x11\x11\x11\x11', + 'aux_int': None, + 'aux_float': None, + 'aux_text': None, + 'aux_blob': None, + }), + ]), + }) +# --- diff --git a/tests/test-auxiliary.py b/tests/test-auxiliary.py index 8e210a0..f09de96 100644 --- a/tests/test-auxiliary.py +++ b/tests/test-auxiliary.py @@ -36,7 +36,40 @@ def test_normal(db, snapshot): def test_types(db, snapshot): - pass + db.execute( + """ + create virtual table v using vec0( + vector float[1], + +aux_int integer, + +aux_float float, + +aux_text text, + +aux_blob blob + ) + """ + ) + assert exec(db, "select * from v") == snapshot() + INSERT = "insert into v(vector, aux_int, aux_float, aux_text, aux_blob) values (?, ?, ?, ?, ?)" + + assert ( + exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1.22, "text", b"blob"]) == snapshot() + ) + assert exec(db, "select * from v") == snapshot() + + # bad types + assert ( + exec(db, INSERT, [b"\x11\x11\x11\x11", "not int", 1.2, "text", b"blob"]) + == snapshot() + ) + assert ( + exec(db, INSERT, [b"\x11\x11\x11\x11", 1, "not float", "text", b"blob"]) + == snapshot() + ) + assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1.2, 1, b"blob"]) == snapshot() + assert exec(db, INSERT, [b"\x11\x11\x11\x11", 1, 1.2, "text", 1]) == snapshot() + + # NULLs are totally chill + assert exec(db, INSERT, [b"\x11\x11\x11\x11", None, None, None, None]) == snapshot() + assert exec(db, "select * from v") == snapshot() def test_updates(db, snapshot):