diff --git a/sqlite-vec.c b/sqlite-vec.c index 83b4006..dc33c67 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -9161,6 +9161,9 @@ int vec0_write_metadata_value(vec0_vtab *p, int metadata_column_idx, i64 rowid, * * @return int SQLITE_OK on success, otherwise error code on failure */ +// Forward declaration: needed for INSERT OR REPLACE handling in vec0Update_Insert +int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue); + int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, sqlite_int64 *pRowid) { UNUSED_PARAMETER(argc); @@ -9281,6 +9284,44 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, goto cleanup; } + // Handle INSERT OR REPLACE: if the conflict resolution is REPLACE and the + // row already exists, delete the existing row first before inserting. + if (sqlite3_vtab_on_conflict(p->db) == SQLITE_REPLACE) { + sqlite3_value *idValue = argv[2 + VEC0_COLUMN_ID]; + int idType = sqlite3_value_type(idValue); + int existingRowExists = 0; + + if (p->pkIsText && idType == SQLITE_TEXT) { + i64 existingRowid; + rc = vec0_rowid_from_id(p, idValue, &existingRowid); + if (rc == SQLITE_OK) { + existingRowExists = 1; + } else if (rc == SQLITE_EMPTY) { + rc = SQLITE_OK; // row doesn't exist, proceed with normal insert + } else { + goto cleanup; + } + } else if (!p->pkIsText && idType == SQLITE_INTEGER) { + i64 existingRowid = sqlite3_value_int64(idValue); + i64 chunk_id_tmp, chunk_offset_tmp; + rc = vec0_get_chunk_position(p, existingRowid, NULL, &chunk_id_tmp, &chunk_offset_tmp); + if (rc == SQLITE_OK) { + existingRowExists = 1; + } else if (rc == SQLITE_EMPTY) { + rc = SQLITE_OK; // row doesn't exist, proceed with normal insert + } else { + goto cleanup; + } + } + + if (existingRowExists) { + rc = vec0Update_Delete(pVTab, idValue); + if (rc != SQLITE_OK) { + goto cleanup; + } + } + } + // Step #1: Insert/get a rowid for this row, from the _rowids table. rc = vec0Update_InsertRowidStep(p, argv[2 + VEC0_COLUMN_ID], &rowid); if (rc != SQLITE_OK) { diff --git a/tests/test-insert-delete.py b/tests/test-insert-delete.py index 7e97ea2..74a2093 100644 --- a/tests/test-insert-delete.py +++ b/tests/test-insert-delete.py @@ -537,3 +537,117 @@ def test_wal_concurrent_reader_during_write(tmp_path): writer.close() reader.close() + + +def test_insert_or_replace_integer_pk(db): + """INSERT OR REPLACE should update vector when rowid already exists.""" + db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)") + + db.execute( + "insert into v(rowid, emb) values (1, ?)", [_f32([1.0, 2.0, 3.0, 4.0])] + ) + # Replace with new vector + db.execute( + "insert or replace into v(rowid, emb) values (1, ?)", + [_f32([10.0, 20.0, 30.0, 40.0])], + ) + + # Should still have exactly 1 row + count = db.execute("select count(*) from v").fetchone()[0] + assert count == 1 + + # Vector should be the replaced value + row = db.execute("select emb from v where rowid = 1").fetchone() + assert row[0] == _f32([10.0, 20.0, 30.0, 40.0]) + + +def test_insert_or_replace_new_row(db): + """INSERT OR REPLACE with a new rowid should just insert normally.""" + db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)") + + db.execute( + "insert or replace into v(rowid, emb) values (1, ?)", + [_f32([1.0, 2.0, 3.0, 4.0])], + ) + + count = db.execute("select count(*) from v").fetchone()[0] + assert count == 1 + + row = db.execute("select emb from v where rowid = 1").fetchone() + assert row[0] == _f32([1.0, 2.0, 3.0, 4.0]) + + +def test_insert_or_replace_text_pk(db): + """INSERT OR REPLACE should work with text primary keys.""" + db.execute( + "create virtual table v using vec0(" + "id text primary key, emb float[4], chunk_size=8" + ")" + ) + + db.execute( + "insert into v(id, emb) values ('doc_a', ?)", + [_f32([1.0, 2.0, 3.0, 4.0])], + ) + db.execute( + "insert or replace into v(id, emb) values ('doc_a', ?)", + [_f32([10.0, 20.0, 30.0, 40.0])], + ) + + count = db.execute("select count(*) from v").fetchone()[0] + assert count == 1 + + row = db.execute("select emb from v where id = 'doc_a'").fetchone() + assert row[0] == _f32([10.0, 20.0, 30.0, 40.0]) + + +def test_insert_or_replace_with_auxiliary(db): + """INSERT OR REPLACE should also replace auxiliary column values.""" + db.execute( + "create virtual table v using vec0(" + "emb float[4], +label text, chunk_size=8" + ")" + ) + + db.execute( + "insert into v(rowid, emb, label) values (1, ?, 'old')", + [_f32([1.0, 2.0, 3.0, 4.0])], + ) + db.execute( + "insert or replace into v(rowid, emb, label) values (1, ?, 'new')", + [_f32([10.0, 20.0, 30.0, 40.0])], + ) + + count = db.execute("select count(*) from v").fetchone()[0] + assert count == 1 + + row = db.execute("select emb, label from v where rowid = 1").fetchone() + assert row[0] == _f32([10.0, 20.0, 30.0, 40.0]) + assert row[1] == "new" + + +def test_insert_or_replace_knn_uses_new_vector(db): + """After INSERT OR REPLACE, KNN should find the new vector, not the old one.""" + db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)") + + db.execute( + "insert into v(rowid, emb) values (1, ?)", [_f32([1.0, 0.0, 0.0, 0.0])] + ) + db.execute( + "insert into v(rowid, emb) values (2, ?)", [_f32([0.0, 1.0, 0.0, 0.0])] + ) + + # Replace row 1's vector to be very close to row 2 + db.execute( + "insert or replace into v(rowid, emb) values (1, ?)", + [_f32([0.0, 0.9, 0.0, 0.0])], + ) + + # KNN for [0, 1, 0, 0] should return row 2 first (exact), then row 1 (close) + rows = db.execute( + "select rowid, distance from v where emb match ? and k = 2", + [_f32([0.0, 1.0, 0.0, 0.0])], + ).fetchall() + assert rows[0][0] == 2 + assert rows[1][0] == 1 + assert rows[1][1] < 0.11 # should be close (L2 distance ≈ 0.1)