From 19b6e0a025480d3d7ae91baac184c177d1dd303e Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Mon, 4 May 2026 03:06:15 -0700 Subject: [PATCH] feat: moved chat persistence to Server Side --- .vscode/launch.json | 31 +- .../141_unique_chat_message_turn_role.py | 66 ++ .../142_token_usage_message_id_unique.py | 134 ++++ surfsense_backend/app/db.py | 33 + .../app/routes/new_chat_routes.py | 231 +++++- surfsense_backend/app/schemas/new_chat.py | 46 ++ .../app/tasks/chat/content_builder.py | 515 ++++++++++++ .../app/tasks/chat/persistence.py | 534 +++++++++++++ .../app/tasks/chat/stream_new_chat.py | 490 ++++++++++-- .../tests/integration/chat/__init__.py | 0 .../chat/test_append_message_recovery.py | 573 ++++++++++++++ .../integration/chat/test_message_id_sse.py | 332 ++++++++ .../integration/chat/test_persistence.py | 747 ++++++++++++++++++ .../unit/tasks/chat/test_content_builder.py | 526 ++++++++++++ .../unit/test_stream_new_chat_contract.py | 24 +- .../new-chat/[[...chat_id]]/page.tsx | 558 ++++++------- surfsense_web/lib/chat/stream-side-effects.ts | 23 + surfsense_web/lib/chat/streaming-state.ts | 31 + surfsense_web/lib/chat/thread-persistence.ts | 11 + 19 files changed, 4515 insertions(+), 390 deletions(-) create mode 100644 surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py create mode 100644 surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py create mode 100644 surfsense_backend/app/tasks/chat/content_builder.py create mode 100644 surfsense_backend/app/tasks/chat/persistence.py create mode 100644 surfsense_backend/tests/integration/chat/__init__.py create mode 100644 surfsense_backend/tests/integration/chat/test_append_message_recovery.py create mode 100644 surfsense_backend/tests/integration/chat/test_message_id_sse.py create mode 100644 surfsense_backend/tests/integration/chat/test_persistence.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_content_builder.py diff --git 
a/.vscode/launch.json b/.vscode/launch.json index 029e7c647..ad8f8f2a7 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -26,7 +26,16 @@ "pythonArgs": [ "run", "python" - ] + ], + // Mute LangGraph/Pydantic checkpoint serializer warnings + // (UserWarnings emitted from pydantic/main.py when the + // runtime snapshots a SurfSenseContextSchema into a field + // typed `None`) so the debugger's "Raised Exceptions" + // breakpoint doesn't pause on a known-harmless event. + // Production logs are unaffected. + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } }, { "name": "Backend: FastAPI (No Reload)", @@ -40,7 +49,10 @@ "pythonArgs": [ "run", "python" - ] + ], + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } }, { "name": "Backend: FastAPI (main.py)", @@ -54,7 +66,10 @@ "pythonArgs": [ "run", "python" - ] + ], + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } }, { "name": "Frontend: Next.js", @@ -104,7 +119,10 @@ "pythonArgs": [ "run", "python" - ] + ], + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } }, { "name": "Celery: Beat Scheduler", @@ -124,7 +142,10 @@ "pythonArgs": [ "run", "python" - ] + ], + "env": { + "PYTHONWARNINGS": "ignore::UserWarning:pydantic.main" + } } ], "compounds": [ diff --git a/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py b/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py new file mode 100644 index 000000000..9a27e7ed0 --- /dev/null +++ b/surfsense_backend/alembic/versions/141_unique_chat_message_turn_role.py @@ -0,0 +1,66 @@ +"""141_unique_chat_message_turn_role + +Revision ID: 141 +Revises: 140 +Create Date: 2026-05-04 + +Add a partial unique index on ``new_chat_messages(thread_id, turn_id, role)`` +where ``turn_id IS NOT NULL``. 
+ +Why +--- +The streaming chat path (`stream_new_chat` / `stream_resume_chat`) is being +moved to write its own ``new_chat_messages`` rows server-side instead of +relying on the frontend's later ``POST /threads/{id}/messages`` call. This +closes the "ghost-thread" abuse vector where authenticated callers got free +LLM completions while ``new_chat_messages`` stayed empty. + +For server-side and legacy frontend writes to coexist we need an idempotency +key. The natural triple is ``(thread_id, turn_id, role)``: the server issues +exactly one ``turn_id`` per turn, and a turn produces at most one user +message and one assistant message. Whichever side wins the race writes the +row; the loser hits ``IntegrityError`` and recovers gracefully. + +Partial — ``WHERE turn_id IS NOT NULL`` — so: + + * Legacy rows that predate the ``turn_id`` column (migration 136) keep + co-existing without de-dup. + * Clone / snapshot inserts in + ``app/services/public_chat_service.py`` that build ``NewChatMessage`` + without ``turn_id`` are unaffected (multiple snapshot copies of the same + user/assistant pair are intentional). + +This index coexists with the existing single-column ``ix_new_chat_messages_turn_id`` +from migration 136 — no collision. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "141" +down_revision: str | None = "140" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +INDEX_NAME = "uq_new_chat_messages_thread_turn_role" +TABLE_NAME = "new_chat_messages" + + +def upgrade() -> None: + op.create_index( + INDEX_NAME, + TABLE_NAME, + ["thread_id", "turn_id", "role"], + unique=True, + postgresql_where=sa.text("turn_id IS NOT NULL"), + ) + + +def downgrade() -> None: + op.drop_index(INDEX_NAME, table_name=TABLE_NAME) diff --git a/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py b/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py new file mode 100644 index 000000000..43b30a756 --- /dev/null +++ b/surfsense_backend/alembic/versions/142_token_usage_message_id_unique.py @@ -0,0 +1,134 @@ +"""142_token_usage_message_id_unique + +Revision ID: 142 +Revises: 141 +Create Date: 2026-05-04 + +Add a partial unique index on ``token_usage(message_id)`` where +``message_id IS NOT NULL``. + +Why +--- +Two writers can race on the same assistant turn's ``token_usage`` row: + + * ``finalize_assistant_turn`` (server-side, called from the streaming + finally block in ``stream_new_chat`` / ``stream_resume_chat``) + * ``append_message``'s recovery branch in + ``app/routes/new_chat_routes.py`` (legacy frontend round-trip) + +Both currently use ``SELECT ... THEN INSERT`` in separate sessions, so a +micro-second-aligned race could observe "no row" on each side and double +INSERT, producing duplicate ``token_usage`` rows for the same +``message_id``. + +A partial unique index on ``message_id`` (``WHERE message_id IS NOT NULL``) +turns both writes into ``INSERT ... ON CONFLICT (message_id) DO NOTHING`` +no-ops for the loser, hard-eliminating the race at the DB level. 
Partial +because non-chat usage rows (indexing, image generation, podcasts) keep +``message_id`` NULL — they're per-event, no de-dup needed. + +Pre-flight +---------- +Today's schema only has a non-unique index on ``message_id`` so a +duplicate population could already exist from any past race. We: + + * Detect duplicate ``message_id`` groups (``HAVING COUNT(*) > 1``). + * If the group count is at or below ``DUPLICATE_ABORT_THRESHOLD`` (50) + we dedupe by deleting all but the smallest ``id`` per group. + * If the count exceeds the threshold we abort with a descriptive + error rather than silently mutate prod data — operator must + investigate before retrying. + +Concurrency +----------- +``CREATE INDEX CONCURRENTLY`` is required on this hot table to avoid +stalling production writes during deploy (a regular ``CREATE INDEX`` +holds an ACCESS EXCLUSIVE lock for the duration of the build, which +would block ``token_usage`` INSERTs for every active streaming chat). +The trade-off is a slower migration (CONCURRENTLY scans the table +twice) and the ``CREATE`` statement cannot run inside alembic's default +transaction wrapper — ``autocommit_block()`` handles that. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "142" +down_revision: str | None = "141" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +INDEX_NAME = "uq_token_usage_message_id" +TABLE_NAME = "token_usage" + +# Refuse to silently mutate prod data if the duplicate population is +# unexpectedly large — operator should investigate the upstream cause +# before retrying. 50 is comfortably above any plausible duplicate +# count from the existing race window (the race is microseconds wide). 
+DUPLICATE_ABORT_THRESHOLD = 50 + + +def upgrade() -> None: + conn = op.get_bind() + + dup_groups = conn.execute( + sa.text( + "SELECT message_id, COUNT(*) AS n " + "FROM token_usage " + "WHERE message_id IS NOT NULL " + "GROUP BY message_id " + "HAVING COUNT(*) > 1" + ) + ).fetchall() + + if len(dup_groups) > DUPLICATE_ABORT_THRESHOLD: + raise RuntimeError( + f"token_usage has {len(dup_groups)} duplicate message_id groups " + f"(threshold={DUPLICATE_ABORT_THRESHOLD}). " + "Resolve the duplicates manually before re-running this migration." + ) + + if dup_groups: + # Delete all but the smallest-id row per duplicate group. The + # smallest id is by definition the earliest insert, so we keep + # the row most likely to reflect the actual stream's first + # successful write. + conn.execute( + sa.text( + """ + DELETE FROM token_usage + WHERE id IN ( + SELECT id FROM ( + SELECT + id, + row_number() OVER ( + PARTITION BY message_id ORDER BY id ASC + ) AS rn + FROM token_usage + WHERE message_id IS NOT NULL + ) ranked + WHERE rn > 1 + ) + """ + ) + ) + + # CREATE INDEX CONCURRENTLY cannot run inside a transaction. Drop + # alembic's auto-transaction for this op only. + with op.get_context().autocommit_block(): + op.execute( + f"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS {INDEX_NAME} " + f"ON {TABLE_NAME} (message_id) " + "WHERE message_id IS NOT NULL" + ) + + +def downgrade() -> None: + with op.get_context().autocommit_block(): + op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {INDEX_NAME}") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index aef959ec9..9fc27fb1f 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -675,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin): __tablename__ = "new_chat_messages" + # Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL. + # Mirrors alembic migration 141. 
Lets the streaming agent and the + # legacy frontend appendMessage call coexist idempotently — the second + # writer trips the unique and recovers without creating a duplicate row. + # Partial so legacy NULL turn_id rows and clone/snapshot inserts in + # app/services/public_chat_service.py (which omit turn_id) are unaffected. + __table_args__ = ( + Index( + "uq_new_chat_messages_thread_turn_role", + "thread_id", + "turn_id", + "role", + unique=True, + postgresql_where=text("turn_id IS NOT NULL"), + ), + ) + role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False) # Content stored as JSONB to support rich content (text, tool calls, etc.) content = Column(JSONB, nullable=False) @@ -728,6 +745,22 @@ class TokenUsage(BaseModel, TimestampMixin): __tablename__ = "token_usage" + # Partial unique index on (message_id) where message_id IS NOT NULL. + # Mirrors alembic migration 142. Lets the streaming agent's + # ``finalize_assistant_turn`` and the legacy frontend ``append_message`` + # recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without + # racing on a SELECT-then-INSERT window. Partial so non-chat usage rows + # (indexing, image generation, podcasts) — which keep ``message_id`` NULL + # because there is no per-message anchor — are unaffected. 
+ __table_args__ = ( + Index( + "uq_token_usage_message_id", + "message_id", + unique=True, + postgresql_where=text("message_id IS NOT NULL"), + ), + ) + prompt_tokens = Column(Integer, nullable=False, default=0) completion_tokens = Column(Integer, nullable=False, default=0) total_tokens = Column(Integer, nullable=False, default=0) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index d3bd51129..2ade207d4 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -17,7 +17,8 @@ from datetime import UTC, datetime from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import StreamingResponse -from sqlalchemy import func, or_ +from sqlalchemy import func, or_, text as sa_text +from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -44,6 +45,7 @@ from app.db import ( NewChatThread, Permission, SearchSpace, + TokenUsage, User, get_async_session, shielded_async_session, @@ -69,9 +71,9 @@ from app.schemas.new_chat import ( TokenUsageSummary, TurnStatusResponse, ) -from app.services.token_tracking_service import record_token_usage from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat from app.users import current_active_user +from app.utils.perf import get_perf_logger from app.utils.rbac import check_permission from app.utils.user_message_multimodal import ( split_langchain_human_content, @@ -79,6 +81,7 @@ from app.utils.user_message_multimodal import ( ) _logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() _background_tasks: set[asyncio.Task] = set() TURN_CANCELLING_INITIAL_DELAY_MS = 200 TURN_CANCELLING_BACKOFF_FACTOR = 2 @@ -1287,6 +1290,24 @@ async def append_message( user: User = Depends(current_active_user), ): 
""" + .. deprecated:: 2026-05 + Replaced by the **SSE-based message ID handshake**. The streaming + generator (`stream_new_chat` / `stream_resume_chat`) now persists + both the user and assistant rows server-side via + ``persist_user_turn`` / ``persist_assistant_shell`` and emits + ``data-user-message-id`` / ``data-assistant-message-id`` SSE events + so the frontend can rename its optimistic IDs in real time. The + new FE bundle no longer calls this route. + + This handler is retained as a **silent no-op for legacy / cached + FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING`` + pattern means a stale bundle hitting this route after the SSE + handshake already wrote the row simply returns the existing row + (200 OK) without raising or duplicating data. After a 2-week soak + (target: ``[persist_user_turn] outcome=race_recovered`` rate ~0) + this entire route — and the FE ``appendMessage`` function — is + earmarked for removal. + Append a message to a thread. This is used by ThreadHistoryAdapter.append() to persist messages. @@ -1297,6 +1318,22 @@ async def append_message( Requires CHATS_UPDATE permission. """ try: + # Capture ``user.id`` as a primitive UUID up front. The + # ``current_active_user`` dependency hands us a ``User`` ORM + # row bound to ``session``; if the outer ``except + # IntegrityError`` block below ever fires (an unexpected + # constraint like a foreign key violation — the common + # ``(thread_id, turn_id, role)`` race is now handled silently + # by ``ON CONFLICT DO NOTHING`` so it never raises) it calls + # ``session.rollback()``, which expires every attached ORM + # row including this user. Any later ``user.id`` access would + # then trigger a lazy PK reload — which on async SQLAlchemy + # fails with ``MissingGreenlet`` because the reload happens + # outside the awaitable greenlet boundary. Reading ``id`` + # once here pins the value as a plain UUID so all downstream + # uses (TokenUsage insert, response build) are immune. 
+ user_uuid = user.id + # Parse raw body - extract only role and content, ignoring extra fields raw_body = await request.json() role = raw_body.get("role") @@ -1351,42 +1388,166 @@ async def append_message( else None ) - db_message = NewChatMessage( - thread_id=thread_id, - role=message_role, - content=content, - author_id=user.id, - turn_id=turn_id_value, - ) - session.add(db_message) - - # Update thread's updated_at timestamp + # Update thread's updated_at timestamp (always — both insert + # and recovery paths represent thread activity). thread.updated_at = datetime.now(UTC) - # flush assigns the PK/defaults without a round-trip SELECT - await session.flush() + # Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING`` + # keyed on the ``(thread_id, turn_id, role)`` partial unique + # index from migration 141 (``WHERE turn_id IS NOT NULL``). + # + # Why ON CONFLICT instead of ``session.add() + flush() + except + # IntegrityError``: + # 1. The conflict between this legacy FE ``appendMessage`` + # round-trip and the server-side + # ``finalize_assistant_turn`` writer is a NORMAL, + # *expected* race — every assistant turn fires it. Using + # catch-and-recover means asyncpg raises + # ``UniqueViolationError`` -> SQLAlchemy wraps it as + # ``IntegrityError`` -> our handler catches and recovers. + # Functionally fine, but every ``raise`` event lights up + # VS Code's debugger (debugpy's ``justMyCode=false`` mode + # loses track of the catch frame across SQLAlchemy's + # async greenlet boundary, so even ``Raised Exceptions`` + # being unchecked doesn't reliably suppress the pause). + # ON CONFLICT pushes the conflict resolution into Postgres + # where no Python exception is constructed at all. + # 2. No ``session.rollback()`` -> no expiring of attached + # ORM rows -> no risk of ``MissingGreenlet`` from + # lazy-loading expired user/thread state later in the + # handler. + # 3. 
Cleaner production logs (no SQLAlchemy ``IntegrityError`` + # tracebacks emitted by uvicorn's logger between the + # ``raise`` and our ``except``). + # + # When ``turn_id_value`` is ``None`` the partial index doesn't + # apply and the INSERT proceeds normally. Other constraint + # violations (FK, NOT NULL, etc.) still raise ``IntegrityError`` + # and are caught by the outer ``except IntegrityError`` block + # to preserve the legacy 400 behavior. + # + # Note on ``content``: when we recover the existing row, we + # intentionally discard the FE's ``content`` payload from + # ``raw_body`` and return the row's existing ``content``. The + # streaming task is now the *authoritative writer* for + # assistant ``ContentPart[]`` shape (mid-stream + # ``AssistantContentBuilder`` -> ``finalize_assistant_turn``) + # so the FE's later ``appendMessage`` is just a stale snapshot + # of the same data — keeping the server-built rich content + # (with full tool-call args / argsText / langchainToolCallId) + # is correct, not lossy. + insert_stmt = ( + pg_insert(NewChatMessage) + .values( + thread_id=thread_id, + role=message_role, + content=content, + author_id=user_uuid, + turn_id=turn_id_value, + ) + .on_conflict_do_nothing( + index_elements=["thread_id", "turn_id", "role"], + index_where=sa_text("turn_id IS NOT NULL"), + ) + .returning(NewChatMessage.id) + ) + inserted_id = (await session.execute(insert_stmt)).scalar() + + if inserted_id is None: + # Conflict on partial unique index — server-side stream + # already wrote this row. Look it up and reuse it. + if turn_id_value is None: + # Defensive: ON CONFLICT only fires for ``turn_id IS + # NOT NULL`` rows, so this branch should be + # unreachable. Preserve the legacy 400 just in case + # Postgres ever surprises us. + raise HTTPException( + status_code=400, + detail="Database constraint violation. 
Please check your input data.", + ) from None + lookup = await session.execute( + select(NewChatMessage).filter( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id_value, + NewChatMessage.role == message_role, + ) + ) + existing_message = lookup.scalars().first() + if existing_message is None: + # Conflict reported but the row vanished between + # INSERT and SELECT — extremely unlikely (would + # require a concurrent DELETE within the same + # transaction visibility), but preserve safe + # behavior. + raise HTTPException( + status_code=400, + detail="Database constraint violation. Please check your input data.", + ) from None + db_message = existing_message + # Perf signal: counts how often the legacy FE round-trip + # races the server-side ``finalize_assistant_turn``. A + # rising rate after the rework is OK (it's exactly the + # ghost-thread fix's recovery path firing); a sudden drop + # to zero would mean the FE isn't posting appendMessage + # at all (different bug). + _perf_log.info( + "[append_message] outcome=recovered_via_unique_index " + "thread_id=%s turn_id=%s role=%s message_id=%s", + thread_id, + turn_id_value, + message_role.value, + db_message.id, + ) + else: + # INSERT succeeded — load the full ORM row so the + # response can include server-side-defaulted columns + # (``created_at``, etc.) and the relationship surface + # stays consistent with the recovery path. + inserted_row = await session.get(NewChatMessage, inserted_id) + if inserted_row is None: + # Should be impossible: we just inserted it in this + # same transaction. Fail loud if it happens. + raise HTTPException( + status_code=500, + detail="Inserted message could not be loaded.", + ) from None + db_message = inserted_row # Persist token usage if provided (for assistant messages). 
# ``cost_micros`` is the provider USD cost reported by LiteLLM, # forwarded by the FE through the appendMessage round-trip so # the historical TokenUsage row matches the credit debit applied # at finalize time. + # + # De-dup: ``finalize_assistant_turn`` may also race to write a + # token_usage row for this same ``message_id`` (cross-session, + # cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed + # on the ``uq_token_usage_message_id`` partial unique index + # (migration 142). The loser silently drops its insert; exactly + # one row results regardless of which writer commits first. token_usage_data = raw_body.get("token_usage") if token_usage_data and message_role == NewChatMessageRole.ASSISTANT: - await record_token_usage( - session, - usage_type="chat", - search_space_id=thread.search_space_id, - user_id=user.id, - prompt_tokens=token_usage_data.get("prompt_tokens", 0), - completion_tokens=token_usage_data.get("completion_tokens", 0), - total_tokens=token_usage_data.get("total_tokens", 0), - cost_micros=token_usage_data.get("cost_micros", 0), - model_breakdown=token_usage_data.get("usage"), - call_details=token_usage_data.get("call_details"), - thread_id=thread_id, - message_id=db_message.id, + insert_stmt = ( + pg_insert(TokenUsage) + .values( + usage_type="chat", + prompt_tokens=token_usage_data.get("prompt_tokens", 0), + completion_tokens=token_usage_data.get("completion_tokens", 0), + total_tokens=token_usage_data.get("total_tokens", 0), + cost_micros=token_usage_data.get("cost_micros", 0), + model_breakdown=token_usage_data.get("usage"), + call_details=token_usage_data.get("call_details"), + thread_id=thread_id, + message_id=db_message.id, + search_space_id=thread.search_space_id, + user_id=user_uuid, + ) + .on_conflict_do_nothing( + index_elements=["message_id"], + index_where=sa_text("message_id IS NOT NULL"), + ) ) + await session.execute(insert_stmt) await session.commit() @@ -1406,6 +1567,9 @@ async def append_message( except 
HTTPException: raise except IntegrityError: + # Any IntegrityError that escaped the inline handler above + # comes from a *different* constraint (foreign key, etc.) — + # preserve the legacy 400 path. await session.rollback() raise HTTPException( status_code=400, @@ -1599,6 +1763,12 @@ async def handle_new_chat( else None ) + mentioned_documents_payload = ( + [doc.model_dump() for doc in request.mentioned_documents] + if request.mentioned_documents + else None + ) + return StreamingResponse( stream_new_chat( user_query=request.user_query, @@ -1608,6 +1778,7 @@ async def handle_new_chat( llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, + mentioned_documents=mentioned_documents_payload, needs_history_bootstrap=thread.needs_history_bootstrap, thread_visibility=thread.visibility, current_user_display_name=user.display_name or "A team member", @@ -2078,6 +2249,11 @@ async def regenerate_response( "data": revert_results, } yield f"data: {json.dumps(envelope, default=str)}\n\n".encode() + mentioned_documents_payload = ( + [doc.model_dump() for doc in request.mentioned_documents] + if request.mentioned_documents + else None + ) try: async for chunk in stream_new_chat( user_query=str(user_query_to_use), @@ -2087,6 +2263,7 @@ async def regenerate_response( llm_config_id=llm_config_id, mentioned_document_ids=request.mentioned_document_ids, mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids, + mentioned_documents=mentioned_documents_payload, checkpoint_id=target_checkpoint_id, needs_history_bootstrap=thread.needs_history_bootstrap, thread_visibility=thread.visibility, diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 892ff9693..1a85484fa 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -200,6 +200,21 @@ class NewChatUserImagePart(BaseModel): return 
to_data_url(self.media_type, self.data) +class MentionedDocumentInfo(BaseModel): + """Display metadata for a single ``@``-mentioned document. + + The full triple ``{id, title, document_type}`` is forwarded by the + frontend mention chip so the server can embed it in the persisted + user message ``ContentPart[]`` (single ``mentioned-documents`` part). + The history loader then renders the chips on reload without an extra + fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape. + """ + + id: int + title: str = Field(..., min_length=1, max_length=500) + document_type: str = Field(..., min_length=1, max_length=100) + + class NewChatRequest(BaseModel): """Request schema for the deep agent chat endpoint.""" @@ -213,6 +228,17 @@ class NewChatRequest(BaseModel): mentioned_surfsense_doc_ids: list[int] | None = ( None # Optional SurfSense documentation IDs mentioned with @ in the chat ) + mentioned_documents: list[MentionedDocumentInfo] | None = Field( + default=None, + description=( + "Display metadata (id, title, document_type) for every " + "@-mentioned document. Persisted as a ``mentioned-documents`` " + "ContentPart on the user message so reload renders chips " + "without an extra fetch. Optional and additive — when None " + "the user message is persisted without a mentioned-documents " + "part." + ), + ) disabled_tools: list[str] | None = ( None # Optional list of tool names the user has disabled from the UI ) @@ -264,6 +290,16 @@ class RegenerateRequest(BaseModel): ) mentioned_document_ids: list[int] | None = None mentioned_surfsense_doc_ids: list[int] | None = None + mentioned_documents: list[MentionedDocumentInfo] | None = Field( + default=None, + description=( + "Display metadata (id, title, document_type) for every " + "@-mentioned document on the edited user turn. Only used " + "when ``user_query`` is non-None (edit). Persisted as a " + "``mentioned-documents`` ContentPart on the new user " + "message. None means no chip metadata." 
+ ), + ) disabled_tools: list[str] | None = None filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" @@ -334,6 +370,16 @@ class ResumeRequest(BaseModel): filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None + mentioned_documents: list[MentionedDocumentInfo] | None = Field( + default=None, + description=( + "Display metadata forwarded for symmetry with /new_chat and " + "/regenerate. Resume reuses the original interrupted user " + "turn so the server does not write a new user message. " + "Currently unused but accepted to keep request bodies " + "uniform across the three streaming entrypoints." + ), + ) class CancelActiveTurnResponse(BaseModel): diff --git a/surfsense_backend/app/tasks/chat/content_builder.py b/surfsense_backend/app/tasks/chat/content_builder.py new file mode 100644 index 000000000..041cab286 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/content_builder.py @@ -0,0 +1,515 @@ +"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection. + +Background +---------- +The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields +SSE events that the frontend folds into a ``ContentPartsState`` (see +``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in +``stream-pipeline.ts``). When a turn ends, the frontend calls +``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]`` +JSONB to ``POST /threads/{id}/messages``, which is what was historically +written to ``new_chat_messages.content``. + +After the ghost-thread fix moved persistence server-side, the assistant +row is written by ``finalize_assistant_turn`` in the streaming finally +block. 
The frontend's later ``appendMessage`` is now a no-op (recovers +via the ``(thread_id, turn_id, role)`` partial unique index added in +migration 141), which means the *server* is now responsible for +producing the rich ``ContentPart[]`` shape the FE expects on history +reload — text + reasoning + tool-call cards (with ``args``, ``argsText``, +``result``, ``langchainToolCallId``) + thinking-step buckets + +step-separators. + +This module is the in-memory accumulator that mirrors the FE state for +exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*`` +/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` / +``mark_interrupted`` at the same call sites it yields the matching +``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list +stays in lockstep with what the FE's pipeline would have produced live. +``snapshot()`` is then taken once in the ``finally`` block and persisted +in a single UPDATE. + +Pure synchronous state — no DB I/O, no async, no flush callbacks. The +streaming code is responsible for driving lifecycle methods; this class +is a thin projection helper. +""" + +from __future__ import annotations + +import copy +import json +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``: +# only text/reasoning/tool-call parts count as "meaningful". data-thinking-steps +# and data-step-separator decorate the meaningful parts but never stand alone +# in a successful turn. +_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"}) + + +class AssistantContentBuilder: + """Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``. 
+ + Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly + matches the FE ``ContentPart`` union:: + + | { type: "text"; text: string } + | { type: "reasoning"; text: string } + | { type: "tool-call"; toolCallId: str; toolName: str; + args: dict; result?: any; argsText?: str; langchainToolCallId?: str; + state?: "aborted" } + | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } } + | { type: "data-step-separator"; data: { stepIndex: int } } + + Order matches the wire order of the SSE events that drive the lifecycle + methods, with two FE-mirrored exceptions: + + 1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the + first time we see a ``data-thinking-step`` SSE event (the FE's + ``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent + thinking-step updates mutate that singleton in place. + 2. ``data-step-separator`` is appended only when the message already has + meaningful content and the previous part isn't itself a separator + (so the FIRST step of a turn doesn't generate a leading divider). + """ + + def __init__(self) -> None: + self.parts: list[dict[str, Any]] = [] + # Index of the active text/reasoning part within ``parts`` while + # streaming is open; -1 means "no active part" and the next delta + # opens a fresh one. Mirrors ``ContentPartsState.currentTextPartIndex``. + self._current_text_idx: int = -1 + self._current_reasoning_idx: int = -1 + # ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the + # synthetic ``call_`` (legacy) or the LangChain + # ``tool_call.id`` (parity_v2) — same key the streaming layer + # threads through every ``tool-input-*`` / ``tool-output-*`` event. + self._tool_call_idx_by_ui_id: dict[str, int] = {} + # Live argsText accumulator (concatenated ``tool-input-delta`` chunks) + # so we can reproduce the FE's ``appendToolInputDelta`` behaviour + # before ``tool-input-available`` overwrites it with the + # pretty-printed final JSON. 
+ self._args_text_by_ui_id: dict[str, str] = {} + + # ------------------------------------------------------------------ + # Text + # ------------------------------------------------------------------ + + def on_text_start(self, text_id: str) -> None: + """Begin a fresh text block. + + Symmetric to FE ``appendText``: opening text closes any active + reasoning so the renderer treats them as separate parts. The + actual text part isn't materialised here — it's lazily created + on the first ``on_text_delta`` so an empty start/end pair + leaves no trace. Matches the FE pipeline which has no explicit + ``text-start`` handler at all. + """ + if self._current_reasoning_idx >= 0: + self._current_reasoning_idx = -1 + + def on_text_delta(self, text_id: str, delta: str) -> None: + if not delta: + return + if self._current_reasoning_idx >= 0: + # FE behaviour: a text delta after reasoning implicitly + # closes the reasoning block (see ``appendText`` lines + # 178-180). + self._current_reasoning_idx = -1 + if ( + self._current_text_idx >= 0 + and 0 <= self._current_text_idx < len(self.parts) + and self.parts[self._current_text_idx].get("type") == "text" + ): + self.parts[self._current_text_idx]["text"] += delta + return + self.parts.append({"type": "text", "text": delta}) + self._current_text_idx = len(self.parts) - 1 + + def on_text_end(self, text_id: str) -> None: + """Close the active text block. + + Mirrors the wire-level ``text-end`` boundary the streaming layer + emits before tool calls / reasoning / step boundaries. The FE + pipeline implicitly closes via ``currentTextPartIndex = -1`` + in ``addToolCall`` / ``appendReasoning`` / ``addStepSeparator``; + our helper does the same explicitly so callers don't have to + maintain that invariant per call site. 
+ """ + self._current_text_idx = -1 + + # ------------------------------------------------------------------ + # Reasoning + # ------------------------------------------------------------------ + + def on_reasoning_start(self, reasoning_id: str) -> None: + if self._current_text_idx >= 0: + self._current_text_idx = -1 + + def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None: + if not delta: + return + if self._current_text_idx >= 0: + self._current_text_idx = -1 + if ( + self._current_reasoning_idx >= 0 + and 0 <= self._current_reasoning_idx < len(self.parts) + and self.parts[self._current_reasoning_idx].get("type") == "reasoning" + ): + self.parts[self._current_reasoning_idx]["text"] += delta + return + self.parts.append({"type": "reasoning", "text": delta}) + self._current_reasoning_idx = len(self.parts) - 1 + + def on_reasoning_end(self, reasoning_id: str) -> None: + self._current_reasoning_idx = -1 + + # ------------------------------------------------------------------ + # Tool calls + # ------------------------------------------------------------------ + + def on_tool_input_start( + self, + ui_id: str, + tool_name: str, + langchain_tool_call_id: str | None, + ) -> None: + """Register a tool-call card. Args are filled in by later events.""" + if not ui_id: + return + # Skip duplicate registration: parity_v2 may emit + # ``tool-input-start`` from both ``on_chat_model_stream`` + # (when tool_call_chunks register a name) and ``on_tool_start`` + # (the canonical path). The FE de-dupes via ``toolCallIndices``; + # we mirror that here. 
+ if ui_id in self._tool_call_idx_by_ui_id: + if langchain_tool_call_id: + idx = self._tool_call_idx_by_ui_id[ui_id] + part = self.parts[idx] + if not part.get("langchainToolCallId"): + part["langchainToolCallId"] = langchain_tool_call_id + return + + part: dict[str, Any] = { + "type": "tool-call", + "toolCallId": ui_id, + "toolName": tool_name, + "args": {}, + } + if langchain_tool_call_id: + part["langchainToolCallId"] = langchain_tool_call_id + self.parts.append(part) + self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1 + + self._current_text_idx = -1 + self._current_reasoning_idx = -1 + + def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None: + """Append a streamed args-delta chunk to the matching card's argsText. + + Mirrors FE ``appendToolInputDelta``: no-ops when no card has been + registered yet for the given ``ui_id`` — the deltas have nowhere + safe to land. + """ + if not ui_id or not args_chunk: + return + idx = self._tool_call_idx_by_ui_id.get(ui_id) + if idx is None: + return + if not (0 <= idx < len(self.parts)): + return + part = self.parts[idx] + if part.get("type") != "tool-call": + return + new_text = (part.get("argsText") or "") + args_chunk + part["argsText"] = new_text + self._args_text_by_ui_id[ui_id] = new_text + + def on_tool_input_available( + self, + ui_id: str, + tool_name: str, + args: dict[str, Any], + langchain_tool_call_id: str | None, + ) -> None: + """Finalize the tool-call card's input. + + Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText`` + with ``json.dumps(input, indent=2)`` so the post-stream card renders + pretty-printed JSON, sets the full ``args`` dict, and backfills + ``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time. + Also creates the card if no prior ``tool-input-start`` registered it + (legacy parity_v2-OFF / late-registration paths). 
+ """ + if not ui_id: + return + try: + final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False) + except (TypeError, ValueError): + # Defensive: ``args`` should already be JSON-safe (the + # streaming layer sanitizes it before emitting), but if a + # caller hands us a non-serializable value we still want + # to record the call without breaking the snapshot. + final_args_text = str(args) + + idx = self._tool_call_idx_by_ui_id.get(ui_id) + if idx is not None and 0 <= idx < len(self.parts): + part = self.parts[idx] + if part.get("type") == "tool-call": + part["args"] = args or {} + part["argsText"] = final_args_text + if langchain_tool_call_id and not part.get("langchainToolCallId"): + part["langchainToolCallId"] = langchain_tool_call_id + return + + # No prior tool-input-start: register the card now. + new_part: dict[str, Any] = { + "type": "tool-call", + "toolCallId": ui_id, + "toolName": tool_name, + "args": args or {}, + "argsText": final_args_text, + } + if langchain_tool_call_id: + new_part["langchainToolCallId"] = langchain_tool_call_id + self.parts.append(new_part) + self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1 + + self._current_text_idx = -1 + self._current_reasoning_idx = -1 + + def on_tool_output_available( + self, + ui_id: str, + output: Any, + langchain_tool_call_id: str | None, + ) -> None: + """Attach the tool's output (``result``) to the matching card. + + Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId`` + only if not already set (a NULL late-arriving value never blows + away an earlier known good one). 
+ """ + if not ui_id: + return + idx = self._tool_call_idx_by_ui_id.get(ui_id) + if idx is None or not (0 <= idx < len(self.parts)): + return + part = self.parts[idx] + if part.get("type") != "tool-call": + return + part["result"] = output + if langchain_tool_call_id and not part.get("langchainToolCallId"): + part["langchainToolCallId"] = langchain_tool_call_id + + # ------------------------------------------------------------------ + # Thinking steps & step separators + # ------------------------------------------------------------------ + + def on_thinking_step( + self, + step_id: str, + title: str, + status: str, + items: list[str] | None, + ) -> None: + """Update / insert the singleton ``data-thinking-steps`` part. + + Mirrors FE ``updateThinkingSteps``: maintain a single + ``data-thinking-steps`` part anchored at index 0, replacing or + unshifting on first sight. Each ``on_thinking_step`` call + replaces the entry in the steps list keyed by ``step_id`` (or + appends if new). + """ + if not step_id: + return + + new_step = { + "id": step_id, + "title": title or "", + "status": status or "in_progress", + "items": list(items) if items else [], + } + + # Find existing data-thinking-steps part. + existing_idx = -1 + for i, p in enumerate(self.parts): + if p.get("type") == "data-thinking-steps": + existing_idx = i + break + + if existing_idx >= 0: + current_steps = self.parts[existing_idx].get("data", {}).get("steps") or [] + replaced = False + for i, step in enumerate(current_steps): + if step.get("id") == step_id: + current_steps[i] = new_step + replaced = True + break + if not replaced: + current_steps.append(new_step) + self.parts[existing_idx] = { + "type": "data-thinking-steps", + "data": {"steps": current_steps}, + } + return + + # First sight: unshift to position 0 (FE parity). + self.parts.insert( + 0, + { + "type": "data-thinking-steps", + "data": {"steps": [new_step]}, + }, + ) + # Bump tracked indices since we inserted at the head. 
+ if self._current_text_idx >= 0: + self._current_text_idx += 1 + if self._current_reasoning_idx >= 0: + self._current_reasoning_idx += 1 + for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()): + self._tool_call_idx_by_ui_id[ui_id] = idx + 1 + + def on_step_separator(self) -> None: + """Append a ``data-step-separator`` between consecutive model steps. + + Mirrors FE ``addStepSeparator``: only emit when the message + already has meaningful content AND the previous part isn't + itself a separator. ``stepIndex`` is the running count of + separators already in ``parts``. + """ + has_content = any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts) + if not has_content: + return + if self.parts and self.parts[-1].get("type") == "data-step-separator": + return + step_index = sum( + 1 for p in self.parts if p.get("type") == "data-step-separator" + ) + self.parts.append( + { + "type": "data-step-separator", + "data": {"stepIndex": step_index}, + } + ) + self._current_text_idx = -1 + self._current_reasoning_idx = -1 + + # ------------------------------------------------------------------ + # Interruption handling + # ------------------------------------------------------------------ + + def mark_interrupted(self) -> None: + """Close any open text/reasoning and flip running tools to aborted. + + Called from the streaming ``finally`` block before ``snapshot()`` so + the persisted JSONB reflects a coherent end-state even when the + client disconnected mid-turn or the agent hit a fatal error. + + - Active text/reasoning blocks: simply lose their "active" + marker (no synthetic content appended). Whatever was streamed + stays as-is. + - Tool-call parts that never received a ``result`` get + ``state="aborted"`` so the FE history loader can render them + as "interrupted" rather than "still running". 
+ """ + self._current_text_idx = -1 + self._current_reasoning_idx = -1 + for part in self.parts: + if part.get("type") != "tool-call": + continue + if "result" in part: + continue + part["state"] = "aborted" + + # ------------------------------------------------------------------ + # Snapshot & introspection + # ------------------------------------------------------------------ + + def snapshot(self) -> list[dict[str, Any]]: + """Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps. + + Deep-copied so callers that finalize from the shielded ``finally`` + block can't accidentally mutate the persisted payload while the + SQL UPDATE is in flight (the streaming layer doesn't touch the + builder after this call, but defensive copies are cheap and cheap + is what we want in a finally block). + """ + return copy.deepcopy(self.parts) + + def is_empty(self) -> bool: + """True if no meaningful content was captured. + + ``data-thinking-steps`` and ``data-step-separator`` decorate + meaningful content but don't count on their own — a turn that + only emitted a thinking step before being interrupted should + still be treated as empty for the status-marker fallback. + """ + return not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts) + + def stats(self) -> dict[str, int]: + """Return counts of each part-type plus rough byte size. + + Used by the streaming layer's perf logger so an ops dashboard + can correlate finalize latency with payload size, and so a + regression that quietly stops emitting tool-call parts (or + starts emitting hundreds) shows up in [PERF] grep rather than + only as a "history reload looks weird" bug report. + + ``bytes`` is the JSON-serialised payload length — what actually + crosses the wire to PostgreSQL's JSONB column. We compute it + with ``ensure_ascii=False`` to match the JSONB encoder's UTF-8 + on-disk layout closely enough for back-of-the-envelope sizing. 
+ Reasoning/text/tool-call/thinking-step/step-separator counts are + independent so any one can spike without the others. + + Defensive: ``json.dumps`` failure (a non-serializable value + slipped past the streaming layer's sanitization) is reported as + ``bytes=-1`` rather than raised — perf logging must not be the + thing that breaks the streaming finally block. + """ + text_blocks = 0 + reasoning_blocks = 0 + tool_calls = 0 + tool_calls_completed = 0 + tool_calls_aborted = 0 + thinking_step_parts = 0 + step_separators = 0 + + for part in self.parts: + kind = part.get("type") + if kind == "text": + text_blocks += 1 + elif kind == "reasoning": + reasoning_blocks += 1 + elif kind == "tool-call": + tool_calls += 1 + if part.get("state") == "aborted": + tool_calls_aborted += 1 + elif "result" in part: + tool_calls_completed += 1 + elif kind == "data-thinking-steps": + thinking_step_parts += 1 + elif kind == "data-step-separator": + step_separators += 1 + + try: + byte_size = len(json.dumps(self.parts, ensure_ascii=False, default=str)) + except (TypeError, ValueError): + byte_size = -1 + + return { + "parts": len(self.parts), + "bytes": byte_size, + "text": text_blocks, + "reasoning": reasoning_blocks, + "tool_calls": tool_calls, + "tool_calls_completed": tool_calls_completed, + "tool_calls_aborted": tool_calls_aborted, + "thinking_step_parts": thinking_step_parts, + "step_separators": step_separators, + } diff --git a/surfsense_backend/app/tasks/chat/persistence.py b/surfsense_backend/app/tasks/chat/persistence.py new file mode 100644 index 000000000..b2b8b6a88 --- /dev/null +++ b/surfsense_backend/app/tasks/chat/persistence.py @@ -0,0 +1,534 @@ +"""Server-side message persistence helpers for the streaming chat agent. + +Historically the streaming task (``stream_new_chat``/``stream_resume_chat``) +left ``new_chat_messages`` empty and relied on the frontend to round-trip +``POST /threads/{id}/messages`` afterwards. 
That gave authenticated clients +a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens +without leaving an audit trail. These helpers move both writes (the user +turn that triggered the stream and the assistant turn the stream produced) +into the server itself, idempotent against the partial unique index +``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do* +keep posting via ``appendMessage`` simply hit the unique-index recovery +path on the second writer instead of creating duplicates. + +Assistant turn lifecycle +------------------------ +The assistant side is split into two helpers so we can capture the row id +*before* the stream produces any output: + +* ``persist_assistant_shell`` runs immediately after ``persist_user_turn`` + and INSERTs an empty assistant row anchored to ``(thread_id, turn_id, + ASSISTANT)``. Returns the row id so the streaming layer can correlate + later writes (token_usage, AgentActionLog future-correlation) against + a stable PK from the start of the turn. +* ``finalize_assistant_turn`` runs from the streaming ``finally`` block. + It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot + produced server-side by ``AssistantContentBuilder`` and writes the + ``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against + the ``uq_token_usage_message_id`` partial unique index from migration + 142, hard-eliminating any race against ``append_message``'s recovery + branch. + +Defensive contract +------------------ + +* Every helper runs inside ``shielded_async_session()`` so ``session.close()`` + survives starlette's mid-stream cancel scope on client disconnect. +* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON + CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id, + role)`` partial unique index. 
On conflict the insert silently no-ops at + the DB level — no Python ``IntegrityError`` is constructed, which + eliminates spurious debugger pauses and keeps logs clean. On conflict a + follow-up ``SELECT`` resolves the existing row id so the streaming layer + can correlate writes against a stable PK. +* ``finalize_assistant_turn`` is best-effort: it never raises. The + streaming ``finally`` block calls it from within + ``anyio.CancelScope(shield=True)`` and any raised exception there + would mask the real error. +""" + +from __future__ import annotations + +import logging +import time +from datetime import UTC, datetime +from typing import Any +from uuid import UUID + +from sqlalchemy import text as sa_text +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.future import select + +from app.db import ( + NewChatMessage, + NewChatMessageRole, + NewChatThread, + TokenUsage, + shielded_async_session, +) +from app.services.token_tracking_service import ( + TurnTokenAccumulator, +) +from app.utils.perf import get_perf_logger + +logger = logging.getLogger(__name__) +_perf_log = get_perf_logger() + + +# Empty initial assistant content. ``finalize_assistant_turn`` overwrites +# this in a single UPDATE at end-of-stream with the full ``ContentPart[]`` +# snapshot produced by ``AssistantContentBuilder``. We persist a one-element +# list with an empty text part so a crash between shell-INSERT and finalize +# leaves the row in a FE-renderable shape (blank bubble) instead of +# blowing up the history loader. +_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}] + +# Substituted content for genuinely empty turns (no text, no reasoning, +# no tool calls). The streaming layer flips to this when +# ``AssistantContentBuilder.is_empty()`` returns True so the persisted +# row is at least somewhat self-describing instead of an empty text +# bubble. 
The FE's ``ContentPart`` union doesn't include ``status`` +# yet, so the history loader will silently drop this part and render +# a blank bubble (matches today's behaviour for empty turns); a follow-up +# FE PR adds the explicit "no response" rendering. +_STATUS_NO_RESPONSE: list[dict[str, Any]] = [ + {"type": "status", "text": "(no text response)"} +] + + +def _build_user_content( + user_query: str, + user_image_data_urls: list[str] | None, + mentioned_documents: list[dict[str, Any]] | None = None, +) -> list[dict[str, Any]]: + """Build the persisted user-message ``content`` (assistant-ui v2 parts). + + Mirrors the shape the existing frontend posts via + ``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``): + + [{"type": "text", "text": "..."}, + {"type": "image", "image": "data:..."}, + {"type": "mentioned-documents", "documents": [{"id": int, + "title": str, "document_type": str}, ...]}] + + The companion reader is + ``app.utils.user_message_multimodal.split_persisted_user_content_parts`` + which expects exactly this shape — keep them in sync. + + ``mentioned_documents``: optional list of ``{id, title, document_type}`` + dicts. When non-empty (and a ``mentioned-documents`` part is not already + in some other input shape), a single ``{"type": "mentioned-documents", + "documents": [...]}`` part is appended. Mirrors the FE injection at + ``page.tsx:281-286`` (``persistUserTurn``). 
+ """ + parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}] + for url in user_image_data_urls or (): + if isinstance(url, str) and url: + parts.append({"type": "image", "image": url}) + if mentioned_documents: + normalized: list[dict[str, Any]] = [] + for doc in mentioned_documents: + if not isinstance(doc, dict): + continue + doc_id = doc.get("id") + title = doc.get("title") + document_type = doc.get("document_type") + if doc_id is None or title is None or document_type is None: + continue + normalized.append( + { + "id": doc_id, + "title": str(title), + "document_type": str(document_type), + } + ) + if normalized: + parts.append({"type": "mentioned-documents", "documents": normalized}) + return parts + + +async def persist_user_turn( + *, + chat_id: int, + user_id: str | None, + turn_id: str, + user_query: str, + user_image_data_urls: list[str] | None = None, + mentioned_documents: list[dict[str, Any]] | None = None, +) -> int | None: + """Persist the user-side row for a chat turn and return its ``id``. + + Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on the + ``(thread_id, turn_id, role)`` partial unique index from migration 141 + (``WHERE turn_id IS NOT NULL``). On conflict the insert silently no-ops + at the DB level — no Python ``IntegrityError`` is constructed, which + eliminates the debugger pause that ``justMyCode=false`` + async greenlet + interactions used to produce, and keeps production logs clean. + + Returns the ``id`` of the row that exists for this turn after the call: + the freshly inserted ``id`` on the happy path, or the existing ``id`` + when a previous writer (legacy FE ``appendMessage`` racing the SSE + stream, redelivered request, etc.) already wrote it. Returns ``None`` + only on genuine DB failure; the caller should yield a streaming error + and abort the turn so we never produce a title/assistant row that + isn't anchored to a persisted user message. 
+ + Other constraint violations (FK, NOT NULL, etc.) still raise + ``IntegrityError`` — only the ``(thread_id, turn_id, role)`` collision + is silenced. + """ + if not turn_id: + # Defensive: turn_id is always populated by the streaming path + # before this helper is called. If it isn't, we cannot be + # idempotent against the unique index — refuse to write rather + # than create a row the unique index can't dedupe. + logger.error( + "persist_user_turn called without a turn_id (chat_id=%s); skipping", + chat_id, + ) + return None + + t0 = time.perf_counter() + outcome = "failed" + resolved_id: int | None = None + try: + async with shielded_async_session() as ws: + # Re-attach the thread row so we can also bump updated_at + # in the same write — keeps the sidebar ordering accurate + # when a user fires off a turn but never reaches the + # legacy appendMessage. + thread = await ws.get(NewChatThread, chat_id) + author_uuid: UUID | None = None + if user_id: + try: + author_uuid = UUID(user_id) + except (TypeError, ValueError): + logger.warning( + "persist_user_turn: invalid user_id=%r, persisting as anonymous", + user_id, + ) + + content_payload = _build_user_content( + user_query, user_image_data_urls, mentioned_documents + ) + insert_stmt = ( + pg_insert(NewChatMessage) + .values( + thread_id=chat_id, + role=NewChatMessageRole.USER, + content=content_payload, + author_id=author_uuid, + turn_id=turn_id, + ) + .on_conflict_do_nothing( + index_elements=["thread_id", "turn_id", "role"], + index_where=sa_text("turn_id IS NOT NULL"), + ) + .returning(NewChatMessage.id) + ) + inserted_id = (await ws.execute(insert_stmt)).scalar() + + if inserted_id is None: + # Conflict on partial unique index — another writer + # (legacy FE appendMessage, redelivered request, etc.) + # already persisted this row. Look it up and reuse. 
+ lookup = await ws.execute( + select(NewChatMessage.id).where( + NewChatMessage.thread_id == chat_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.USER, + ) + ) + existing_id = lookup.scalars().first() + if existing_id is None: + # Conflict reported but no row found — extremely + # unlikely (concurrent DELETE). Surface as failure. + logger.warning( + "persist_user_turn: conflict but no matching row " + "(chat_id=%s, turn_id=%s)", + chat_id, + turn_id, + ) + outcome = "integrity_no_match" + return None + resolved_id = int(existing_id) + outcome = "race_recovered" + else: + resolved_id = int(inserted_id) + outcome = "inserted" + # Bump thread.updated_at only on a real insert — when + # we recovered an existing row the prior writer + # already touched the thread. + if thread is not None: + thread.updated_at = datetime.now(UTC) + + await ws.commit() + return resolved_id + except Exception: + logger.exception( + "persist_user_turn failed (chat_id=%s, turn_id=%s)", + chat_id, + turn_id, + ) + return None + finally: + _perf_log.info( + "[persist_user_turn] outcome=%s chat_id=%s turn_id=%s " + "message_id=%s query_len=%d images=%d mentioned_docs=%d " + "in %.3fs", + outcome, + chat_id, + turn_id, + resolved_id, + len(user_query or ""), + len(user_image_data_urls or ()), + len(mentioned_documents or ()), + time.perf_counter() - t0, + ) + + +async def persist_assistant_shell( + *, + chat_id: int, + user_id: str | None, + turn_id: str, +) -> int | None: + """Pre-write an empty assistant row for the turn and return its id. + + Inserts a placeholder ``new_chat_messages`` row (empty text content) so + the streaming layer has a stable ``message_id`` to correlate against + for the rest of the turn. ``finalize_assistant_turn`` overwrites the + ``content`` field at end-of-stream with the rich ``ContentPart[]`` + snapshot produced by ``AssistantContentBuilder``. 
+ + Returns the row id on success, ``None`` on a genuine DB failure (caller + should abort the turn rather than stream into a void). + + Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial unique + index from migration 141: if a row already exists (resume retry, racing + legacy frontend, redelivered request, etc.) we look it up by + ``(thread_id, turn_id, role)`` and return its existing id. The streaming + layer is then free to UPDATE that row at finalize time. + """ + if not turn_id: + logger.error( + "persist_assistant_shell called without a turn_id (chat_id=%s); skipping", + chat_id, + ) + return None + + t0 = time.perf_counter() + outcome = "failed" + resolved_id: int | None = None + try: + async with shielded_async_session() as ws: + insert_stmt = ( + pg_insert(NewChatMessage) + .values( + thread_id=chat_id, + role=NewChatMessageRole.ASSISTANT, + content=_EMPTY_SHELL_CONTENT, + author_id=None, + turn_id=turn_id, + ) + .on_conflict_do_nothing( + index_elements=["thread_id", "turn_id", "role"], + index_where=sa_text("turn_id IS NOT NULL"), + ) + .returning(NewChatMessage.id) + ) + inserted_id = (await ws.execute(insert_stmt)).scalar() + + if inserted_id is None: + # Conflict — another writer (legacy FE appendMessage, + # resume retry, redelivered request) wrote the + # (thread_id, turn_id, ASSISTANT) row first. Look it up + # so the streaming layer can UPDATE the same row at + # finalize time. 
+ lookup = await ws.execute( + select(NewChatMessage.id).where( + NewChatMessage.thread_id == chat_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.ASSISTANT, + ) + ) + existing_id = lookup.scalars().first() + if existing_id is None: + logger.warning( + "persist_assistant_shell: conflict but no matching " + "(thread_id, turn_id, role) row found " + "(chat_id=%s, turn_id=%s)", + chat_id, + turn_id, + ) + outcome = "integrity_no_match" + return None + resolved_id = int(existing_id) + outcome = "race_recovered" + else: + resolved_id = int(inserted_id) + outcome = "inserted" + + await ws.commit() + return resolved_id + except Exception: + logger.exception( + "persist_assistant_shell failed (chat_id=%s, turn_id=%s)", + chat_id, + turn_id, + ) + return None + finally: + _perf_log.info( + "[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s " + "message_id=%s in %.3fs", + outcome, + chat_id, + turn_id, + resolved_id, + time.perf_counter() - t0, + ) + + +async def finalize_assistant_turn( + *, + message_id: int, + chat_id: int, + search_space_id: int, + user_id: str | None, + turn_id: str, + content: list[dict[str, Any]], + accumulator: TurnTokenAccumulator | None, +) -> None: + """Finalize the assistant row and write its token_usage. + + Two writes in a single shielded session: + + 1. ``UPDATE new_chat_messages SET content = :c, updated_at = now() + WHERE id = :id`` — overwrites the placeholder ``persist_assistant_shell`` + wrote with the full ``ContentPart[]`` snapshot produced server-side. + 2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT (message_id) + WHERE message_id IS NOT NULL DO NOTHING`` — uses the partial unique + index ``uq_token_usage_message_id`` from migration 142 to make the + insert idempotent against ``append_message``'s recovery branch + (which uses the same ON CONFLICT clause). 
+ + Substitutes the status-marker payload when ``content`` is empty + (pure tool-call turn that aborted before any output, or interrupt + before any event arrived). The status marker is preferable to a + blank text bubble because token accounting still runs and an ops + dashboard can flag the row. + + Best-effort — never raises. The streaming ``finally`` calls this + from within ``anyio.CancelScope(shield=True)``; any raised exception + here would mask the real error that triggered the cleanup. + """ + if not turn_id: + logger.error( + "finalize_assistant_turn called without turn_id " + "(chat_id=%s, message_id=%s); skipping", + chat_id, + message_id, + ) + return + if not message_id: + logger.error( + "finalize_assistant_turn called without message_id " + "(chat_id=%s, turn_id=%s); skipping", + chat_id, + turn_id, + ) + return + + payload: list[dict[str, Any]] + is_status_marker = False + if content: + payload = content + else: + payload = _STATUS_NO_RESPONSE + is_status_marker = True + + t0 = time.perf_counter() + outcome = "failed" + token_usage_attempted = bool( + accumulator is not None and accumulator.calls and user_id + ) + try: + async with shielded_async_session() as ws: + assistant_row = await ws.get(NewChatMessage, message_id) + if assistant_row is None: + logger.warning( + "finalize_assistant_turn: row not found " + "(chat_id=%s, message_id=%s, turn_id=%s); skipping", + chat_id, + message_id, + turn_id, + ) + outcome = "row_missing" + return + + assistant_row.content = payload + assistant_row.updated_at = datetime.now(UTC) + + # Token usage. ``record_token_usage`` (used elsewhere) does + # SELECT-then-INSERT in two statements which races with + # ``append_message``. Switch to a single INSERT ... ON + # CONFLICT DO NOTHING keyed on the migration-142 partial + # unique index so the loser silently drops its write at + # the DB level — exactly one row per ``message_id``, + # regardless of which session committed first. 
+ if accumulator is not None and accumulator.calls and user_id: + try: + user_uuid = UUID(user_id) + except (TypeError, ValueError): + logger.warning( + "finalize_assistant_turn: invalid user_id=%r, " + "skipping token_usage row", + user_id, + ) + else: + insert_stmt = ( + pg_insert(TokenUsage) + .values( + usage_type="chat", + prompt_tokens=accumulator.total_prompt_tokens, + completion_tokens=accumulator.total_completion_tokens, + total_tokens=accumulator.grand_total, + cost_micros=accumulator.total_cost_micros, + model_breakdown=accumulator.per_message_summary(), + call_details={"calls": accumulator.serialized_calls()}, + thread_id=chat_id, + message_id=message_id, + search_space_id=search_space_id, + user_id=user_uuid, + ) + .on_conflict_do_nothing( + index_elements=["message_id"], + index_where=sa_text("message_id IS NOT NULL"), + ) + ) + await ws.execute(insert_stmt) + + await ws.commit() + outcome = "ok" + except Exception: + logger.exception( + "finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)", + chat_id, + message_id, + turn_id, + ) + finally: + _perf_log.info( + "[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s " + "turn_id=%s parts=%d status_marker=%s " + "token_usage_attempted=%s in %.3fs", + outcome, + chat_id, + message_id, + turn_id, + len(payload), + is_status_marker, + token_usage_attempted, + time.perf_counter() - t0, + ) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index f7ddd8909..487602c3b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -25,7 +25,6 @@ from uuid import UUID import anyio from langchain_core.messages import HumanMessage -from sqlalchemy import func from sqlalchemy.future import select from sqlalchemy.orm import selectinload @@ -314,6 +313,19 @@ class StreamResult: verification_succeeded: bool = False commit_gate_passed: bool = True 
commit_gate_reason: str = "" + # Pre-allocated assistant ``new_chat_messages.id`` for this turn, + # captured by ``persist_assistant_shell`` right after the user row is + # persisted. ``None`` for the legacy / anonymous code paths that don't + # opt in to server-side ``ContentPart[]`` projection. + assistant_message_id: int | None = None + # In-memory mirror of the FE's assistant-ui ``ContentPartsState``, + # populated by the lifecycle methods called from ``_stream_agent_events`` + # at each ``streaming_service.format_*`` yield site. Snapshot in the + # streaming ``finally`` to produce the rich JSONB persisted by + # ``finalize_assistant_turn``. ``repr=False`` keeps the + # log-on-error path (``StreamResult`` is logged in some error + # branches) from dumping a potentially-large parts list. + content_builder: Any | None = field(default=None, repr=False) def _safe_float(value: Any, default: float = 0.0) -> float: @@ -721,6 +733,7 @@ async def _stream_agent_events( fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, fallback_commit_thread_id: int | None = None, runtime_context: Any = None, + content_builder: Any | None = None, ) -> AsyncGenerator[str, None]: """Shared async generator that streams and formats astream_events from the agent. @@ -737,6 +750,15 @@ async def _stream_agent_events( initial_step_id: If set, the helper inherits an already-active thinking step. initial_step_title: Title of the inherited thinking step. initial_step_items: Items of the inherited thinking step. + content_builder: Optional ``AssistantContentBuilder``. When set, every + ``streaming_service.format_*`` yield site also drives the matching + builder lifecycle method (``on_text_*``, ``on_reasoning_*``, + ``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the + in-memory ``ContentPart[]`` projection stays in lockstep with what + the FE renders live. 
Pure in-memory accumulation — no DB I/O — + consumed by the streaming ``finally`` to produce the rich JSONB + persisted via ``finalize_assistant_turn``. ``None`` (the default) + is used by the anonymous / legacy code paths and is a no-op. Yields: SSE-formatted strings for each event. @@ -801,12 +823,46 @@ async def _stream_agent_events( current_lc_tool_call_id: dict[str, str | None] = {"value": None} def _emit_tool_output(call_id: str, output: Any) -> str: + # Drive the builder before formatting the SSE so the in-memory + # ContentPart[] mirror sees the result attached to the same + # card the FE will render. Builder method is a no-op when + # ``content_builder`` is None (anonymous / legacy paths). + if content_builder is not None: + content_builder.on_tool_output_available( + call_id, output, current_lc_tool_call_id["value"] + ) return streaming_service.format_tool_output_available( call_id, output, langchain_tool_call_id=current_lc_tool_call_id["value"], ) + def _emit_thinking_step( + *, + step_id: str, + title: str, + status: str = "in_progress", + items: list[str] | None = None, + ) -> str: + """Format a thinking-step SSE event and notify the builder. + + Single helper used at every ``format_thinking_step`` yield site + in this generator. Drives ``AssistantContentBuilder.on_thinking_step`` + first so the FE-mirror state lands the update before the SSE + carrying the same data leaves the wire — order matches the FE + pipeline (``processSharedStreamEvent`` updates state, then + flushes). Builder call is a no-op when ``content_builder`` is + None (anonymous / legacy paths). 
+ """ + if content_builder is not None: + content_builder.on_thinking_step(step_id, title, status, items) + return streaming_service.format_thinking_step( + step_id=step_id, + title=title, + status=status, + items=items, + ) + def next_thinking_step_id() -> str: nonlocal thinking_step_counter thinking_step_counter += 1 @@ -816,7 +872,7 @@ async def _stream_agent_events( nonlocal last_active_step_id if last_active_step_id and last_active_step_id not in completed_step_ids: completed_step_ids.add(last_active_step_id) - event = streaming_service.format_thinking_step( + event = _emit_thinking_step( step_id=last_active_step_id, title=last_active_step_title, status="completed", @@ -861,6 +917,8 @@ async def _stream_agent_events( if parity_v2 and reasoning_delta: if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) current_text_id = None if current_reasoning_id is None: completion_event = complete_current_step() @@ -873,13 +931,21 @@ async def _stream_agent_events( just_finished_tool = False current_reasoning_id = streaming_service.generate_reasoning_id() yield streaming_service.format_reasoning_start(current_reasoning_id) + if content_builder is not None: + content_builder.on_reasoning_start(current_reasoning_id) yield streaming_service.format_reasoning_delta( current_reasoning_id, reasoning_delta ) + if content_builder is not None: + content_builder.on_reasoning_delta( + current_reasoning_id, reasoning_delta + ) if text_delta: if current_reasoning_id is not None: yield streaming_service.format_reasoning_end(current_reasoning_id) + if content_builder is not None: + content_builder.on_reasoning_end(current_reasoning_id) current_reasoning_id = None if current_text_id is None: completion_event = complete_current_step() @@ -892,8 +958,12 @@ async def _stream_agent_events( just_finished_tool = False current_text_id = streaming_service.generate_text_id() yield 
streaming_service.format_text_start(current_text_id) + if content_builder is not None: + content_builder.on_text_start(current_text_id) yield streaming_service.format_text_delta(current_text_id, text_delta) accumulated_text += text_delta + if content_builder is not None: + content_builder.on_text_delta(current_text_id, text_delta) # Live tool-call argument streaming. Runs AFTER text/reasoning # processing so chunks containing both stay in their natural @@ -925,11 +995,17 @@ async def _stream_agent_events( # within the same stream window. if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) current_text_id = None if current_reasoning_id is not None: yield streaming_service.format_reasoning_end( current_reasoning_id ) + if content_builder is not None: + content_builder.on_reasoning_end( + current_reasoning_id + ) current_reasoning_id = None index_to_meta[idx] = { @@ -942,6 +1018,8 @@ async def _stream_agent_events( name, langchain_tool_call_id=lc_id, ) + if content_builder is not None: + content_builder.on_tool_input_start(ui_id, name, lc_id) # Emit args delta for any chunk at a registered # index (including idless continuations). 
Once an @@ -957,6 +1035,10 @@ async def _stream_agent_events( yield streaming_service.format_tool_input_delta( meta["ui_id"], args_chunk ) + if content_builder is not None: + content_builder.on_tool_input_delta( + meta["ui_id"], args_chunk + ) else: pending_tool_call_chunks.append(tcc) @@ -974,6 +1056,8 @@ async def _stream_agent_events( if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) current_text_id = None if last_active_step_title != "Synthesizing response": @@ -994,7 +1078,7 @@ async def _stream_agent_events( ) last_active_step_title = "Listing files" last_active_step_items = [ls_path] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Listing files", status="in_progress", @@ -1009,7 +1093,7 @@ async def _stream_agent_events( display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Reading file" last_active_step_items = [display_fp] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Reading file", status="in_progress", @@ -1024,7 +1108,7 @@ async def _stream_agent_events( display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Writing file" last_active_step_items = [display_fp] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Writing file", status="in_progress", @@ -1039,7 +1123,7 @@ async def _stream_agent_events( display_fp = fp if len(fp) <= 80 else "…" + fp[-77:] last_active_step_title = "Editing file" last_active_step_items = [display_fp] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Editing file", status="in_progress", @@ -1056,7 +1140,7 @@ async def _stream_agent_events( ) last_active_step_title = "Searching files" last_active_step_items = [f"{pat} in {base_path}"] - 
yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Searching files", status="in_progress", @@ -1076,7 +1160,7 @@ async def _stream_agent_events( last_active_step_items = [ f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "") ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Searching content", status="in_progress", @@ -1091,7 +1175,7 @@ async def _stream_agent_events( display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:] last_active_step_title = "Deleting file" last_active_step_items = [display_path] if display_path else [] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Deleting file", status="in_progress", @@ -1108,7 +1192,7 @@ async def _stream_agent_events( ) last_active_step_title = "Deleting folder" last_active_step_items = [display_path] if display_path else [] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Deleting folder", status="in_progress", @@ -1125,7 +1209,7 @@ async def _stream_agent_events( ) last_active_step_title = "Creating folder" last_active_step_items = [display_path] if display_path else [] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Creating folder", status="in_progress", @@ -1148,7 +1232,7 @@ async def _stream_agent_events( last_active_step_items = ( [f"{display_src} → {display_dst}"] if src or dst else [] ) - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Moving file", status="in_progress", @@ -1165,7 +1249,7 @@ async def _stream_agent_events( if todo_count else [] ) - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Planning tasks", status="in_progress", @@ -1180,7 +1264,7 @@ async def _stream_agent_events( 
display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "") last_active_step_title = "Saving document" last_active_step_items = [display_title] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Saving document", status="in_progress", @@ -1196,7 +1280,7 @@ async def _stream_agent_events( last_active_step_items = [ f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}" ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Generating image", status="in_progress", @@ -1212,7 +1296,7 @@ async def _stream_agent_events( last_active_step_items = [ f"URL: {url[:80]}{'...' if len(url) > 80 else ''}" ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Scraping webpage", status="in_progress", @@ -1235,7 +1319,7 @@ async def _stream_agent_events( f"Content: {content_len:,} characters", "Preparing audio generation...", ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Generating podcast", status="in_progress", @@ -1256,7 +1340,7 @@ async def _stream_agent_events( f"Topic: {report_topic}", "Analyzing source content...", ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title=step_title, status="in_progress", @@ -1271,7 +1355,7 @@ async def _stream_agent_events( display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "") last_active_step_title = "Running command" last_active_step_items = [f"$ {display_cmd}"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, title="Running command", status="in_progress", @@ -1288,7 +1372,7 @@ async def _stream_agent_events( tool_name.replace("_", " ").strip().capitalize() or tool_name ) last_active_step_items = [] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=tool_step_id, 
title=last_active_step_title, status="in_progress", @@ -1349,6 +1433,10 @@ async def _stream_agent_events( tool_name, langchain_tool_call_id=langchain_tool_call_id, ) + if content_builder is not None: + content_builder.on_tool_input_start( + tool_call_id, tool_name, langchain_tool_call_id + ) if run_id: ui_tool_call_id_by_run[run_id] = tool_call_id @@ -1371,6 +1459,13 @@ async def _stream_agent_events( _safe_input, langchain_tool_call_id=langchain_tool_call_id, ) + if content_builder is not None: + content_builder.on_tool_input_available( + tool_call_id, + tool_name, + _safe_input, + langchain_tool_call_id, + ) elif event_type == "on_tool_end": active_tool_depth = max(0, active_tool_depth - 1) @@ -1443,70 +1538,70 @@ async def _stream_agent_events( current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] if tool_name == "read_file": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Reading file", status="completed", items=last_active_step_items, ) elif tool_name == "write_file": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Writing file", status="completed", items=last_active_step_items, ) elif tool_name == "edit_file": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Editing file", status="completed", items=last_active_step_items, ) elif tool_name == "glob": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Searching files", status="completed", items=last_active_step_items, ) elif tool_name == "grep": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Searching content", status="completed", items=last_active_step_items, ) elif tool_name == "rm": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, 
title="Deleting file", status="completed", items=last_active_step_items, ) elif tool_name == "rmdir": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Deleting folder", status="completed", items=last_active_step_items, ) elif tool_name == "mkdir": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Creating folder", status="completed", items=last_active_step_items, ) elif tool_name == "move_file": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Moving file", status="completed", items=last_active_step_items, ) elif tool_name == "write_todos": - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Planning tasks", status="completed", @@ -1523,7 +1618,7 @@ async def _stream_agent_events( *last_active_step_items, result_str[:80] if is_error else "Saved to knowledge base", ] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Saving document", status="completed", @@ -1542,7 +1637,7 @@ async def _stream_agent_events( else "Generation failed" ) completed_items = [*last_active_step_items, f"Error: {error_msg}"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Generating image", status="completed", @@ -1566,7 +1661,7 @@ async def _stream_agent_events( ] else: completed_items = [*last_active_step_items, "Content extracted"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Scraping webpage", status="completed", @@ -1612,7 +1707,7 @@ async def _stream_agent_events( ] else: completed_items = last_active_step_items - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Generating podcast", status="completed", @@ 
-1647,7 +1742,7 @@ async def _stream_agent_events( ] else: completed_items = last_active_step_items - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Generating video presentation", status="completed", @@ -1695,7 +1790,7 @@ async def _stream_agent_events( else: completed_items = last_active_step_items - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title=step_title, status="completed", @@ -1721,7 +1816,7 @@ async def _stream_agent_events( ] else: completed_items = [*last_active_step_items, "Finished"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Running command", status="completed", @@ -1761,7 +1856,7 @@ async def _stream_agent_events( completed_items.append(f"(+{len(file_names) - 4} more)") else: completed_items = ["No files found"] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title="Listing files", status="completed", @@ -1773,7 +1868,7 @@ async def _stream_agent_events( fallback_title = ( tool_name.replace("_", " ").strip().capitalize() or tool_name ) - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=original_step_id, title=fallback_title, status="completed", @@ -2113,7 +2208,7 @@ async def _stream_agent_events( # Phase transitions: replace everything after topic last_active_step_items = [*topic_items, message] - yield streaming_service.format_thinking_step( + yield _emit_thinking_step( step_id=last_active_step_id, title=last_active_step_title, status="in_progress", @@ -2155,10 +2250,14 @@ async def _stream_agent_events( elif event_type in ("on_chain_end", "on_agent_end"): if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) current_text_id = None if current_text_id is not None: 
yield streaming_service.format_text_end(current_text_id) + if content_builder is not None: + content_builder.on_text_end(current_text_id) completion_event = complete_current_step() if completion_event: @@ -2243,8 +2342,14 @@ async def _stream_agent_events( ) gate_text_id = streaming_service.generate_text_id() yield streaming_service.format_text_start(gate_text_id) + if content_builder is not None: + content_builder.on_text_start(gate_text_id) yield streaming_service.format_text_delta(gate_text_id, gate_notice) + if content_builder is not None: + content_builder.on_text_delta(gate_text_id, gate_notice) yield streaming_service.format_text_end(gate_text_id) + if content_builder is not None: + content_builder.on_text_end(gate_text_id) yield streaming_service.format_terminal_info(gate_notice, "error") accumulated_text = gate_notice else: @@ -2270,6 +2375,7 @@ async def stream_new_chat( llm_config_id: int = -1, mentioned_document_ids: list[int] | None = None, mentioned_surfsense_doc_ids: list[int] | None = None, + mentioned_documents: list[dict[str, Any]] | None = None, checkpoint_id: str | None = None, needs_history_bootstrap: bool = False, thread_visibility: ChatVisibility | None = None, @@ -2949,6 +3055,96 @@ async def stream_new_chat( ) yield streaming_service.format_data("turn-status", {"status": "busy"}) + # Persist the user-side row for this turn before any expensive + # work runs. Closes the "ghost-thread" abuse vector + # (authenticated client hits POST /new_chat then never calls + # /messages — empty new_chat_messages, free LLM completion). + # Idempotent against the unique index in migration 141 so the + # legacy frontend appendMessage call is a no-op on the second + # writer. Hard failure aborts the turn so we never produce a + # title or assistant row that isn't anchored to a persisted + # user message. 
+ from app.tasks.chat.content_builder import AssistantContentBuilder + from app.tasks.chat.persistence import ( + persist_assistant_shell, + persist_user_turn, + ) + + user_message_id = await persist_user_turn( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + user_query=user_query, + user_image_data_urls=user_image_data_urls, + mentioned_documents=mentioned_documents, + ) + if user_message_id is None: + yield _emit_stream_error( + message=( + "We couldn't save your message. Please try again in a moment." + ), + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() + return + + # Emit canonical user message id BEFORE any LLM streaming so the + # FE can rename its optimistic ``msg-user-XXX`` placeholder to + # ``msg-{user_message_id}`` and unlock features gated on a real + # DB id (comments, edit-from-this-message). See B4 in + # ``sse-based_message_id_handshake`` plan. + yield streaming_service.format_data( + "user-message-id", + {"message_id": user_message_id, "turn_id": stream_result.turn_id}, + ) + + # Pre-write the assistant row for this turn so we have a stable + # ``message_id`` to anchor mid-stream metadata (token_usage, + # future agent_action_log.message_id correlation) and a + # write-once UPDATE target at finalize time. Idempotent against + # the (thread_id, turn_id, ASSISTANT) partial unique index from + # migration 141 — if the legacy frontend appendMessage races + # this, we recover the existing row's id. + assistant_message_id = await persist_assistant_shell( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + ) + if assistant_message_id is None: + # Genuine DB failure — abort the turn rather than stream + # into a void. 
The user row is already persisted so the + # legacy "ghost-thread" gate isn't reopened. + yield _emit_stream_error( + message=( + "We couldn't initialize the assistant message. Please try again." + ), + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() + return + + # Emit canonical assistant message id BEFORE any LLM streaming + # so the FE can rename its optimistic ``msg-assistant-XXX`` + # placeholder to ``msg-{assistant_message_id}`` and bind + # ``tokenUsageStore`` / ``pendingInterrupt`` to the real id + # immediately. See B4 in ``sse-based_message_id_handshake`` + # plan. + yield streaming_service.format_data( + "assistant-message-id", + {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, + ) + + stream_result.assistant_message_id = assistant_message_id + stream_result.content_builder = AssistantContentBuilder() + # Initial thinking step - analyzing the request if mentioned_surfsense_docs: initial_title = "Analyzing referenced content" @@ -2981,6 +3177,15 @@ async def stream_new_chat( initial_items = [f"{action_verb}: {' '.join(processing_parts)}"] initial_step_id = "thinking-1" + # Drive the builder for this initial thinking step too — the + # ``_emit_thinking_step`` helper lives inside ``_stream_agent_events`` + # so it isn't in scope here, but the FE folds this step into + # the same singleton ``data-thinking-steps`` part as everything + # the agent stream emits later. Mirror that fold server-side. 
+ if stream_result.content_builder is not None: + stream_result.content_builder.on_thinking_step( + initial_step_id, initial_title, "in_progress", initial_items + ) yield streaming_service.format_thinking_step( step_id=initial_step_id, title=initial_title, @@ -2997,16 +3202,34 @@ async def stream_new_chat( # Check if this is the first assistant response so we can generate # a title in parallel with the agent stream (better UX than waiting # until after the full response). - assistant_count_result = await session.execute( - select(func.count(NewChatMessage.id)).filter( + # Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because + # this is now a hot path executed on every turn, and COUNT scales + # with thread length (server-side persistence can grow rows + # quickly under power users). + # + # IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already + # inserted THIS turn's assistant row. We must therefore exclude + # it from the probe — otherwise the gate fires on every turn + # except the very first, and title generation never runs for new + # threads. Excluding by primary key (``id != assistant_message_id``) + # is bulletproof regardless of ``turn_id`` shape (legacy NULLs, + # resume turns, etc.). + first_assistant_probe = await session.execute( + select(NewChatMessage.id) + .filter( NewChatMessage.thread_id == chat_id, NewChatMessage.role == "assistant", + NewChatMessage.id != assistant_message_id, ) + .limit(1) ) - is_first_response = (assistant_count_result.scalar() or 0) == 0 + is_first_response = first_assistant_probe.scalars().first() is None title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None - if is_first_response: + # Gate title generation on a persisted user message so a stream + # that fails before persistence (we abort above) can never leave + # behind a thread with a generated title and no anchoring rows. 
+ if is_first_response and user_message_id is not None: async def _generate_title() -> tuple[str | None, dict | None]: """Generate a short title via litellm.acompletion. @@ -3138,6 +3361,7 @@ async def stream_new_chat( ), fallback_commit_thread_id=chat_id, runtime_context=runtime_context, + content_builder=stream_result.content_builder, ): if not _first_event_logged: _perf_log.info( @@ -3493,6 +3717,81 @@ async def stream_new_chat( with contextlib.suppress(Exception): await session.close() + # Server-side assistant-message + token_usage finalization. + # Runs after the main session has been closed (uses its own + # shielded session) so we don't fight the same DB connection. + # Idempotent against the legacy frontend appendMessage: + # * the assistant row was already INSERTed by + # ``persist_assistant_shell`` above, so this just UPDATEs + # it with the rich ContentPart[] from the builder. + # * token_usage uses INSERT ... ON CONFLICT DO NOTHING + # against migration 142's partial unique index, so a + # racing append_message recovery branch can never + # double-write. + # ``mark_interrupted`` closes any open text/reasoning blocks + # and flips running tool-calls (no result) to state=aborted + # so the persisted JSONB reflects a coherent end-state even + # on client disconnect. + # Never raises (best-effort, logs only). + if ( + stream_result + and stream_result.turn_id + and stream_result.assistant_message_id + ): + from app.tasks.chat.persistence import finalize_assistant_turn + + builder_stats: dict[str, int] | None = None + if stream_result.content_builder is not None: + stream_result.content_builder.mark_interrupted() + # Snapshot stats BEFORE deepcopy in ``snapshot()`` so + # the perf log records the actual finalised payload + # (post-mark_interrupted), not the live-mutating + # builder state. 
+ builder_stats = stream_result.content_builder.stats() + content_payload = stream_result.content_builder.snapshot() + else: + # Defensive fallback — we always set the builder + # alongside ``assistant_message_id`` above, so this + # branch only fires if a future refactor ever + # decouples them. Persist whatever accumulated + # text we captured so the row at least renders. + content_payload = [ + { + "type": "text", + "text": stream_result.accumulated_text or "", + } + ] + + if builder_stats is not None: + _perf_log.info( + "[stream_new_chat] finalize_payload chat_id=%s " + "message_id=%s parts=%d bytes=%d text=%d " + "reasoning=%d tool_calls=%d " + "tool_calls_completed=%d tool_calls_aborted=%d " + "thinking_step_parts=%d step_separators=%d", + chat_id, + stream_result.assistant_message_id, + builder_stats["parts"], + builder_stats["bytes"], + builder_stats["text"], + builder_stats["reasoning"], + builder_stats["tool_calls"], + builder_stats["tool_calls_completed"], + builder_stats["tool_calls_aborted"], + builder_stats["thinking_step_parts"], + builder_stats["step_separators"], + ) + + await finalize_assistant_turn( + message_id=stream_result.assistant_message_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + turn_id=stream_result.turn_id, + content=content_payload, + accumulator=accumulator, + ) + # Persist any sandbox-produced files to local storage so they # remain downloadable after the Daytona sandbox auto-deletes. if stream_result and stream_result.sandbox_files: @@ -3937,6 +4236,50 @@ async def stream_resume_chat( ) yield streaming_service.format_data("turn-status", {"status": "busy"}) + # Pre-write a fresh assistant row for this resume turn. The + # original (interrupted) ``stream_new_chat`` invocation already + # persisted its own assistant row anchored to a different + # ``turn_id``; resume allocates a new ``turn_id`` (above) so we + # need a separate row keyed on the same ``(thread_id, turn_id, + # ASSISTANT)`` invariant. 
Idempotent against migration 141's + # partial unique index — recovers existing id on retry. + from app.tasks.chat.content_builder import AssistantContentBuilder + from app.tasks.chat.persistence import persist_assistant_shell + + assistant_message_id = await persist_assistant_shell( + chat_id=chat_id, + user_id=user_id, + turn_id=stream_result.turn_id, + ) + if assistant_message_id is None: + yield _emit_stream_error( + message=( + "We couldn't initialize the assistant message. Please try again." + ), + error_kind="server_error", + error_code="MESSAGE_PERSIST_FAILED", + ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) + yield streaming_service.format_finish_step() + yield streaming_service.format_finish() + yield streaming_service.format_done() + return + + # Emit canonical assistant message id BEFORE any LLM streaming + # so the FE can rename ``pendingInterrupt.assistantMsgId`` to + # ``msg-{assistant_message_id}`` immediately. Resume does NOT + # emit ``data-user-message-id`` because the user row is from + # the original interrupted turn (different ``turn_id``) and is + # never re-persisted here. See B5 in the + # ``sse-based_message_id_handshake`` plan. + yield streaming_service.format_data( + "assistant-message-id", + {"message_id": assistant_message_id, "turn_id": stream_result.turn_id}, + ) + + stream_result.assistant_message_id = assistant_message_id + stream_result.content_builder = AssistantContentBuilder() + # Resume path doesn't carry new ``mentioned_document_ids`` — # those are seeded in the original turn. 
We still pass a # context so future middleware extensions (Phase 2) can rely on @@ -3968,6 +4311,7 @@ async def stream_resume_chat( ), fallback_commit_thread_id=chat_id, runtime_context=runtime_context, + content_builder=stream_result.content_builder, ): if not _first_event_logged: _perf_log.info( @@ -4219,6 +4563,64 @@ async def stream_resume_chat( with contextlib.suppress(Exception): await session.close() + # Server-side assistant-message + token_usage finalization for + # the resume flow. The original user message was persisted by + # the original (interrupted) ``stream_new_chat`` invocation; + # the resume's own ``persist_assistant_shell`` write lives at + # the new ``turn_id`` above. This finalize updates that row + # with the rich ContentPart[] from the builder and writes + # token_usage idempotently via migration 142's partial + # unique index. Best-effort, never raises. + if ( + stream_result + and stream_result.turn_id + and stream_result.assistant_message_id + ): + from app.tasks.chat.persistence import finalize_assistant_turn + + builder_stats: dict[str, int] | None = None + if stream_result.content_builder is not None: + stream_result.content_builder.mark_interrupted() + builder_stats = stream_result.content_builder.stats() + content_payload = stream_result.content_builder.snapshot() + else: + content_payload = [ + { + "type": "text", + "text": stream_result.accumulated_text or "", + } + ] + + if builder_stats is not None: + _perf_log.info( + "[stream_resume] finalize_payload chat_id=%s " + "message_id=%s parts=%d bytes=%d text=%d " + "reasoning=%d tool_calls=%d " + "tool_calls_completed=%d tool_calls_aborted=%d " + "thinking_step_parts=%d step_separators=%d", + chat_id, + stream_result.assistant_message_id, + builder_stats["parts"], + builder_stats["bytes"], + builder_stats["text"], + builder_stats["reasoning"], + builder_stats["tool_calls"], + builder_stats["tool_calls_completed"], + builder_stats["tool_calls_aborted"], + 
builder_stats["thinking_step_parts"], + builder_stats["step_separators"], + ) + + await finalize_assistant_turn( + message_id=stream_result.assistant_message_id, + chat_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + turn_id=stream_result.turn_id, + content=content_payload, + accumulator=accumulator, + ) + agent = llm = connector_service = None stream_result = None session = None diff --git a/surfsense_backend/tests/integration/chat/__init__.py b/surfsense_backend/tests/integration/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/integration/chat/test_append_message_recovery.py b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py new file mode 100644 index 000000000..a5182a978 --- /dev/null +++ b/surfsense_backend/tests/integration/chat/test_append_message_recovery.py @@ -0,0 +1,573 @@ +"""Integration tests for the cross-writer integration between the +streaming chat task and the legacy ``POST /threads/{id}/messages`` +(``append_message``) round-trip. + +Two scenarios anchor the contract introduced by the server-side +persistence rework: + +(a) **Tool-heavy turn streamed to completion.** + + Drives :class:`AssistantContentBuilder` with synthetic SSE events + that mirror what ``_stream_agent_events`` emits for a turn that + interleaves text, reasoning, a tool call (start/delta/available/ + output), and a final text block. Then runs + :func:`finalize_assistant_turn` and asserts: + + * ``new_chat_messages.content`` JSONB matches the + ``ContentPart[]`` shape the FE history loader expects, with full + ``args``/``argsText``/``result``/``langchainToolCallId`` for the + tool call. + * Exactly one ``token_usage`` row exists keyed on the assistant + ``message_id``. 
+ +(b) **Stale FE ``appendMessage`` after server finalize.** + + Verifies the recovery branch of the ``append_message`` route now + returns the SERVER's authoritative ``ContentPart[]`` (not the FE's + stale payload) when the partial unique index from migration 141 + blocks the FE's INSERT, and that the ``ON CONFLICT DO NOTHING`` + clause from migration 142 stops the route from writing a duplicate + ``token_usage`` row. +""" + +from __future__ import annotations + +import json +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + ChatVisibility, + NewChatMessage, + NewChatMessageRole, + NewChatThread, + SearchSpace, + TokenUsage, + User, +) +from app.routes import new_chat_routes +from app.services.token_tracking_service import TurnTokenAccumulator +from app.tasks.chat import persistence as persistence_module +from app.tasks.chat.content_builder import AssistantContentBuilder +from app.tasks.chat.persistence import ( + finalize_assistant_turn, + persist_assistant_shell, +) + +pytestmark = pytest.mark.integration + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def db_thread( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +) -> NewChatThread: + thread = NewChatThread( + title="Test Chat", + search_space_id=db_search_space.id, + created_by_id=db_user.id, + visibility=ChatVisibility.PRIVATE, + ) + db_session.add(thread) + await db_session.flush() + return thread + + +@pytest.fixture +def patched_shielded_session(monkeypatch, db_session: AsyncSession): + """Route persistence helpers to the test's savepoint-bound session. 
+ + Mirrors the helper from ``test_persistence.py`` so the helpers' + internal ``ws.commit()`` / ``ws.rollback()`` resolve to SAVEPOINT + operations on the test transaction instead of touching real + autocommit boundaries. + """ + + @asynccontextmanager + async def _fake_shielded_session(): + yield db_session + + monkeypatch.setattr( + persistence_module, + "shielded_async_session", + _fake_shielded_session, + ) + return db_session + + +@pytest.fixture +def bypass_permission_checks(monkeypatch): + """Replace RBAC + thread access checks with no-ops. + + The append_message route under test calls ``check_permission`` and + ``check_thread_access``; those rely on a SearchSpaceMembership row + that the existing integration fixtures don't create. The contract + we want to verify here is the ``IntegrityError`` -> recovery branch, + not the RBAC plumbing — so stub them. + """ + + async def _allow(*_args, **_kwargs): + return True + + monkeypatch.setattr(new_chat_routes, "check_permission", _allow) + monkeypatch.setattr(new_chat_routes, "check_thread_access", _allow) + return None + + +class _FakeRequest: + """Minimal Request stand-in used by ``append_message``. + + The route only calls ``await request.json()`` — keep the surface + area tight so this doesn't accidentally hide future signature + changes that we *would* want to break the test. + """ + + def __init__(self, body: dict): + self._body = body + + async def json(self) -> dict: + return self._body + + +def _build_tool_heavy_content() -> list[dict]: + """Drive ``AssistantContentBuilder`` through a tool-heavy turn. + + Produces the same ``ContentPart[]`` shape the streaming layer would + persist if ``_stream_agent_events`` ran a turn with: opening + reasoning -> text -> tool call (input start/delta/available/output) + -> closing text. Centralised here so the (a) and (b) scenarios use + the same authoritative payload. 
+ """ + builder = AssistantContentBuilder() + + builder.on_reasoning_start("r1") + builder.on_reasoning_delta("r1", "Let me look up ") + builder.on_reasoning_delta("r1", "the file listing.") + builder.on_reasoning_end("r1") + + builder.on_text_start("t1") + builder.on_text_delta("t1", "Sure, listing files in ") + builder.on_text_delta("t1", "/.") + builder.on_text_end("t1") + + builder.on_tool_input_start( + "tool_call_ui_1", + tool_name="ls", + langchain_tool_call_id="lc_call_xyz", + ) + builder.on_tool_input_delta("tool_call_ui_1", '{"path"') + builder.on_tool_input_delta("tool_call_ui_1", ': "/"}') + builder.on_tool_input_available( + "tool_call_ui_1", + tool_name="ls", + args={"path": "/"}, + langchain_tool_call_id="lc_call_xyz", + ) + builder.on_tool_output_available( + "tool_call_ui_1", + output={"files": ["a.txt", "b.txt"]}, + langchain_tool_call_id="lc_call_xyz", + ) + + builder.on_text_start("t2") + builder.on_text_delta("t2", "Found 2 files: a.txt and b.txt.") + builder.on_text_end("t2") + + return builder.snapshot() + + +def _accumulator_with_one_call() -> TurnTokenAccumulator: + acc = TurnTokenAccumulator() + acc.add( + model="gpt-4o-mini", + prompt_tokens=200, + completion_tokens=80, + total_tokens=280, + cost_micros=22222, + ) + return acc + + +# --------------------------------------------------------------------------- +# (a) Tool-heavy stream finalize +# --------------------------------------------------------------------------- + + +class TestToolHeavyTurnFinalize: + async def test_full_tool_call_persisted_and_one_token_usage_row( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + """End-to-end seam: builder snapshot -> finalize -> DB row. 
+ + Matches the production flow's *content* invariant: whatever + ``AssistantContentBuilder.snapshot()`` produces is what the + streaming layer hands to ``finalize_assistant_turn``, so this + test catches any drift between the JSONB shape the builder + emits and the one the FE history loader expects. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:tool_heavy" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + snapshot = _build_tool_heavy_content() + # Sanity-check the snapshot before we hand it to the DB so a + # builder regression surfaces here, not deep inside an opaque + # JSONB diff. + assert any(p.get("type") == "reasoning" for p in snapshot) + text_parts = [p for p in snapshot if p.get("type") == "text"] + assert len(text_parts) == 2 + tool_parts = [p for p in snapshot if p.get("type") == "tool-call"] + assert len(tool_parts) == 1 + tool_part = tool_parts[0] + assert tool_part["toolCallId"] == "tool_call_ui_1" + assert tool_part["toolName"] == "ls" + assert tool_part["args"] == {"path": "/"} + # ``argsText`` ends up as the pretty-printed final args (the + # ``tool-input-available`` event replaces the streamed deltas + # with ``json.dumps(args, indent=2)`` to match the FE's + # post-stream rendering). + assert tool_part["argsText"] == '{\n "path": "/"\n}' + assert tool_part["result"] == {"files": ["a.txt", "b.txt"]} + # ``langchainToolCallId`` is the agent-side correlation id used + # by the regenerate path; a missing one breaks + # edit-from-tool-call later. 
+ assert tool_part["langchainToolCallId"] == "lc_call_xyz" + + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=snapshot, + accumulator=_accumulator_with_one_call(), + ) + + # ``content`` must round-trip byte-for-byte through the JSONB + # column. SQLAlchemy doesn't auto-refresh the row that survived + # the savepoint commit, so refresh explicitly. + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + + # The history loader reads ``content`` straight into the FE's + # parts array, so a strict equality comparison is the right + # invariant here. + assert row.content == snapshot + # Tool-call parts must JSON-serialise cleanly — nothing in + # ``args`` / ``argsText`` / ``result`` should accidentally be a + # non-JSON-safe value (datetime, set, custom class). + assert json.dumps(row.content) + + usage_count = ( + await db_session.execute( + select(func.count()) + .select_from(TokenUsage) + .where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage_count == 1 + + usage = ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage.usage_type == "chat" + assert usage.prompt_tokens == 200 + assert usage.completion_tokens == 80 + assert usage.total_tokens == 280 + assert usage.cost_micros == 22222 + assert usage.thread_id == thread_id + assert usage.search_space_id == search_space_id + + +# --------------------------------------------------------------------------- +# (b) FE appendMessage after server finalize +# --------------------------------------------------------------------------- + + +class TestAppendMessageRecoveryAfterFinalize: + async def test_returns_server_content_and_does_not_duplicate_token_usage( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + bypass_permission_checks, + ): + """FE's stale 
``appendMessage`` after server finalize. + + The frontend used to be the authoritative writer for assistant + ``content``. Now the server is. When the legacy FE round-trip + fires *after* the server has already finalized: + + * the route's INSERT trips the (thread_id, turn_id, role) + partial unique index from migration 141, + * the recovery branch fetches the existing row and returns + *its* ``content`` — discarding the FE payload — so the + history loader reads the rich server payload (full tool + args, argsText, langchainToolCallId, etc.) on next page + reload, + * the route's optional ``token_usage`` insert is keyed on the + partial unique index from migration 142 so it silently + no-ops if the server already wrote one. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:fe_late_append" + + # Step 1: server stream completes. Server-built rich content is + # finalized. + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + server_content = _build_tool_heavy_content() + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=server_content, + accumulator=_accumulator_with_one_call(), + ) + + # Step 2: simulate the legacy FE ``appendMessage`` round-trip + # arriving with stale, lossy content (missing tool args, etc.) + # plus a ``token_usage`` body. 
+ fe_stale_content = [ + {"type": "text", "text": "Found 2 files: a.txt and b.txt."}, + ] + fe_request_body = { + "role": "assistant", + "content": fe_stale_content, + "turn_id": turn_id, + "token_usage": { + "prompt_tokens": 999, + "completion_tokens": 999, + "total_tokens": 1998, + "cost_micros": 88888, + "usage": {"any": "thing"}, + "call_details": {"calls": []}, + }, + } + request = _FakeRequest(fe_request_body) + + # ``db_user`` is bound to ``db_session``. The route's + # IntegrityError branch calls ``session.rollback()``, which + # expires every ORM row attached to the session including this + # user — historically causing ``user.id`` to lazy-load + # out-of-greenlet and crash the request with ``MissingGreenlet`` + # (observed in production logs at /api/v1/threads/531/messages + # 2026-05-04). The route now captures ``user.id`` to a primitive + # UUID at the top of the handler, so the rollback can't reach + # it. Pass the *attached* user here on purpose — that's the + # production scenario, and this test is the regression guard + # against that bug returning. + response = await new_chat_routes.append_message( + thread_id=thread_id, + request=request, + session=db_session, + user=db_user, + ) + + # Response must echo the SERVER's rich payload, not the FE's + # stale snapshot. This is the user-visible part of the + # contract: history reload + ThreadHistoryAdapter.append both + # read from the same authoritative source. + assert response.id == msg_id + assert response.role == NewChatMessageRole.ASSISTANT + assert response.turn_id == turn_id + assert response.content == server_content + assert response.content != fe_stale_content + + # The on-disk row must agree with the response. + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + assert row.content == server_content + + # ``token_usage``: exactly one row, with the *server's* values + # (the FE's much larger token counts must not have overwritten + # them). 
+ usage_count = ( + await db_session.execute( + select(func.count()) + .select_from(TokenUsage) + .where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage_count == 1 + + usage = ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage.cost_micros == 22222 # Server's value, not 88888 + assert usage.total_tokens == 280 # Server's value, not 1998 + + async def test_legacy_fe_first_appendmessage_then_server_no_dupe( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + bypass_permission_checks, + ): + """Inverse race: legacy FE writes first, server finalize second. + + Some clients still post ``appendMessage`` before the streaming + ``finally`` runs. The contract is symmetric: whichever writer + loses the (thread_id, turn_id, role) race silently lets the + winner keep its content. In particular the *server's* + finalize must NOT raise — it must look up the existing row and + UPDATE its content with the server-built payload (which is + always richer/more authoritative than whatever the FE + snapshot held). + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:fe_first" + + # Step 1: legacy FE appendMessage lands first. No prior shell + # row exists; the route does the INSERT itself. + fe_request_body = { + "role": "assistant", + "content": [{"type": "text", "text": "early FE write"}], + "turn_id": turn_id, + } + fe_response = await new_chat_routes.append_message( + thread_id=thread_id, + request=_FakeRequest(fe_request_body), + session=db_session, + user=db_user, + ) + assert fe_response.role == NewChatMessageRole.ASSISTANT + + # Step 2: server stream's persist_assistant_shell now races + # behind. It must adopt the existing row id, not raise. 
+ adopted_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert adopted_id == fe_response.id + + # Step 3: server finalize then overwrites the FE's stub with + # the rich content (which is the correct, more authoritative + # payload). + server_content = _build_tool_heavy_content() + await finalize_assistant_turn( + message_id=adopted_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=server_content, + accumulator=_accumulator_with_one_call(), + ) + + # Final state: one row, server content, one token_usage row. + msg_count = ( + await db_session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.ASSISTANT, + ) + ) + ).scalar_one() + assert msg_count == 1 + + row = await db_session.get(NewChatMessage, adopted_id) + await db_session.refresh(row) + assert row.content == server_content + + usage_count = ( + await db_session.execute( + select(func.count()) + .select_from(TokenUsage) + .where(TokenUsage.message_id == adopted_id) + ) + ).scalar_one() + assert usage_count == 1 + + async def test_appendmessage_without_turn_id_legacy_400( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + bypass_permission_checks, + ): + """Defensive: a bare appendMessage with no turn_id and no + existing row is just a normal INSERT — must succeed. But if a + row with the same role already exists in this thread *without* + turn_id collisions, the route should fall through to the + legacy 400 path on a foreign-key / unrelated IntegrityError + (we don't ship that bug today, but pin the behaviour so a + future schema change can't silently regress it). + """ + thread_id = db_thread.id + + # Bare appendMessage with no turn_id — should just succeed + # without invoking the recovery branch. 
+ ok_response = await new_chat_routes.append_message( + thread_id=thread_id, + request=_FakeRequest( + { + "role": "user", + "content": [{"type": "text", "text": "hi"}], + } + ), + session=db_session, + user=db_user, + ) + assert ok_response.role == NewChatMessageRole.USER + assert ok_response.turn_id is None + + # Sanity: the route did NOT silently swallow the missing + # turn_id by routing through the unique-index recovery branch + # — it took the happy path. + msg_count = ( + await db_session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.role == NewChatMessageRole.USER, + ) + ) + ).scalar_one() + assert msg_count == 1 diff --git a/surfsense_backend/tests/integration/chat/test_message_id_sse.py b/surfsense_backend/tests/integration/chat/test_message_id_sse.py new file mode 100644 index 000000000..8fc935eaa --- /dev/null +++ b/surfsense_backend/tests/integration/chat/test_message_id_sse.py @@ -0,0 +1,332 @@ +"""Integration tests for the SSE-based message ID handshake. + +The streaming generators (``stream_new_chat`` / ``stream_resume_chat``) +emit two new events after their respective persistence helpers resolve +the canonical ``new_chat_messages.id``: + +* ``data-user-message-id`` — emitted only by ``stream_new_chat``, + AFTER ``persist_user_turn`` and BEFORE any LLM streaming. +* ``data-assistant-message-id`` — emitted by both + ``stream_new_chat`` and ``stream_resume_chat``, AFTER + ``persist_assistant_shell`` and BEFORE any LLM streaming. + +The frontend renames its optimistic ``msg-user-XXX`` / +``msg-assistant-XXX`` placeholder ids to ``msg-{db_id}`` upon receiving +these events. This test suite anchors three contracts: + +1. ``format_data`` produces SSE bytes in the precise shape + ``data: {"type":"data-","data":{...}}\\n\\n`` that the FE's + ``readSSEStream`` consumer parses (matches ``surfsense_web/lib/chat/streaming-state.ts``). +2. 
The ``message_id`` carried in the SSE payload exactly equals the + primary key the persistence helper inserted into + ``new_chat_messages`` — so the FE rename produces ``msg-{real_pk}``, + which in turn unlocks DB-id-gated UI (comments, edit-from-message). +3. The same ``message_id`` is used for the ``token_usage.message_id`` + foreign key, so ``finalize_assistant_turn``'s row binds correctly. + +Direct end-to-end testing of ``stream_new_chat`` requires a fully +mocked agent + LLM stack (out-of-scope here); those flows are covered +by the harness-driven integration tests under +``tests/integration/agents/new_chat/`` plus the assertion in +``test_persistence.py`` that the helpers themselves return ``int`` +ids. The contracts above close the remaining gap between the persist +helpers and the bytes that ship to the FE. +""" + +from __future__ import annotations + +import json +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + ChatVisibility, + NewChatMessage, + NewChatMessageRole, + NewChatThread, + SearchSpace, + User, +) +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat import persistence as persistence_module +from app.tasks.chat.persistence import ( + persist_assistant_shell, + persist_user_turn, +) + +pytestmark = pytest.mark.integration + + +# --------------------------------------------------------------------------- +# Fixtures (mirror test_persistence.py) +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def db_thread( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +) -> NewChatThread: + thread = NewChatThread( + title="Test Chat", + search_space_id=db_search_space.id, + created_by_id=db_user.id, + visibility=ChatVisibility.PRIVATE, + ) + db_session.add(thread) + await db_session.flush() + 
return thread + + +@pytest.fixture +def patched_shielded_session(monkeypatch, db_session: AsyncSession): + """Route persistence helpers to the test's savepoint-bound session.""" + + @asynccontextmanager + async def _fake_shielded_session(): + yield db_session + + monkeypatch.setattr( + persistence_module, + "shielded_async_session", + _fake_shielded_session, + ) + return db_session + + +# --------------------------------------------------------------------------- +# (1) SSE byte-shape contract +# --------------------------------------------------------------------------- + + +def _parse_sse_data_line(blob: str) -> dict: + """Unwrap a single ``data: \\n\\n`` SSE frame. + + Raises if there's more than one frame or the prefix is wrong — keeps + the parser strict so a regression in ``format_data`` produces a + test failure here, not in a downstream consumer. + """ + assert blob.endswith("\n\n"), f"missing terminator: {blob!r}" + line = blob.removesuffix("\n\n") + assert line.startswith("data: "), f"missing data prefix: {line!r}" + return json.loads(line.removeprefix("data: ")) + + +class TestSSEByteShape: + def test_data_user_message_id_byte_shape(self): + """``format_data("user-message-id", {...})`` must produce the + exact wire format the FE's + ``readSSEStream`` -> ``data-user-message-id`` case parses. 
+ """ + svc = VercelStreamingService() + blob = svc.format_data( + "user-message-id", + {"message_id": 1843, "turn_id": "533:1762900000000"}, + ) + envelope = _parse_sse_data_line(blob) + assert envelope == { + "type": "data-user-message-id", + "data": {"message_id": 1843, "turn_id": "533:1762900000000"}, + } + + def test_data_assistant_message_id_byte_shape(self): + svc = VercelStreamingService() + blob = svc.format_data( + "assistant-message-id", + {"message_id": 1844, "turn_id": "533:1762900000000"}, + ) + envelope = _parse_sse_data_line(blob) + assert envelope == { + "type": "data-assistant-message-id", + "data": {"message_id": 1844, "turn_id": "533:1762900000000"}, + } + + +# --------------------------------------------------------------------------- +# (2) Helper-id <-> DB-pk coherence +# --------------------------------------------------------------------------- + + +class TestHandshakeIdMatchesDB: + """The SSE handshake's correctness hinges on the integer in + ``data-{user,assistant}-message-id`` being the EXACT primary key + the persistence helper inserted. If they ever diverge, the FE + rename produces ``msg-{wrong_id}``, comments break (regex match + fails), and downstream features (edit, regenerate) target the + wrong row. Anchor it here. + """ + + async def test_user_message_id_matches_new_chat_messages_pk( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:9000" + + # The streaming generator passes this same value into + # ``streaming_service.format_data("user-message-id", {...})``. + msg_id_from_helper = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + ) + assert isinstance(msg_id_from_helper, int) + + # Look up the row the helper inserted via + # ``(thread_id, turn_id, role)`` — the same composite the FE + # uses to identify a turn — and confirm the PK matches. 
+ row = ( + await db_session.execute( + select(NewChatMessage).where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.USER, + ) + ) + ).scalar_one() + assert row.id == msg_id_from_helper + + # The byte-stream the FE actually receives — confirms the + # round-trip from the helper return value to the SSE payload. + svc = VercelStreamingService() + envelope = _parse_sse_data_line( + svc.format_data( + "user-message-id", + {"message_id": msg_id_from_helper, "turn_id": turn_id}, + ) + ) + assert envelope["data"]["message_id"] == row.id + + async def test_assistant_message_id_matches_new_chat_messages_pk( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:9100" + + msg_id_from_helper = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert isinstance(msg_id_from_helper, int) + + row = ( + await db_session.execute( + select(NewChatMessage).where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.ASSISTANT, + ) + ) + ).scalar_one() + assert row.id == msg_id_from_helper + + svc = VercelStreamingService() + envelope = _parse_sse_data_line( + svc.format_data( + "assistant-message-id", + {"message_id": msg_id_from_helper, "turn_id": turn_id}, + ) + ) + assert envelope["data"]["message_id"] == row.id + + async def test_handshake_ids_for_full_turn_are_distinct_and_paired( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """Sanity: a full new-chat turn's two SSE events carry two + DIFFERENT ids (user row PK ≠ assistant row PK), both anchored + to the SAME ``turn_id`` in the DB. This pairing is what + ``finalize_assistant_turn`` and the regenerate / edit flows + rely on. 
+ """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:9200" + + user_msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hi", + ) + assistant_msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert user_msg_id is not None and assistant_msg_id is not None + assert user_msg_id != assistant_msg_id + + rows = ( + ( + await db_session.execute( + select(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + ) + .order_by(NewChatMessage.id) + ) + ) + .scalars() + .all() + ) + assert len(rows) == 2 + ids_by_role = {r.role: r.id for r in rows} + assert ids_by_role[NewChatMessageRole.USER] == user_msg_id + assert ids_by_role[NewChatMessageRole.ASSISTANT] == assistant_msg_id + + +# --------------------------------------------------------------------------- +# (3) Parse helpers used by the FE — sanity-check our payload shape +# --------------------------------------------------------------------------- + + +class TestPayloadShapeMatchesFEReader: + """The FE's ``readStreamedMessageId`` (in + ``surfsense_web/lib/chat/stream-side-effects.ts``) requires: + + * ``message_id`` is a ``number`` (rejects null / string / NaN). + * ``turn_id`` is an optional non-empty string (else it's coerced + to ``null``). + + These tests exercise the BE side of that contract by inspecting + ``format_data`` output shapes that the FE consumes verbatim. 
+ """ + + def test_message_id_is_serialised_as_a_json_number(self): + svc = VercelStreamingService() + envelope = _parse_sse_data_line( + svc.format_data("user-message-id", {"message_id": 42, "turn_id": "t"}) + ) + assert isinstance(envelope["data"]["message_id"], int) + assert envelope["data"]["message_id"] == 42 + + def test_turn_id_round_trips_as_string(self): + svc = VercelStreamingService() + # The actual format used in production: f"{chat_id}:{int(time.time()*1000)}" + production_turn_id = "533:1762900000000" + envelope = _parse_sse_data_line( + svc.format_data( + "assistant-message-id", + {"message_id": 1, "turn_id": production_turn_id}, + ) + ) + assert envelope["data"]["turn_id"] == production_turn_id diff --git a/surfsense_backend/tests/integration/chat/test_persistence.py b/surfsense_backend/tests/integration/chat/test_persistence.py new file mode 100644 index 000000000..66a04772e --- /dev/null +++ b/surfsense_backend/tests/integration/chat/test_persistence.py @@ -0,0 +1,747 @@ +"""Integration tests for ``app.tasks.chat.persistence``. + +Verifies the DB-side guarantees the streaming chat task relies on: + +* ``persist_assistant_shell`` is idempotent against the + ``(thread_id, turn_id, ASSISTANT)`` partial unique index from + migration 141. Two calls with the same ``turn_id`` return the SAME + ``message_id`` and never create a duplicate ``new_chat_messages`` row. +* ``finalize_assistant_turn`` writes a status-marker payload when given + empty content, never raises, and is safe to call twice on the same + ``message_id`` — the partial unique index from migration 142 + (``uq_token_usage_message_id``) prevents the second insert from + producing a duplicate ``token_usage`` row. +* The same ``ON CONFLICT DO NOTHING`` invariant covers the cross-writer + race where ``finalize_assistant_turn`` and the ``append_message`` + recovery branch both target the same ``message_id``. 
+ +All tests run inside the conftest's outer-transaction-with-savepoint +fixture so commits inside the helpers (which open their own +``shielded_async_session``) are released as savepoints and rolled back +at test end. We monkey-patch ``shielded_async_session`` to yield the +same pooled test session so the integration transaction stays +in-scope. +""" + +from __future__ import annotations + +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio +from sqlalchemy import func, select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + ChatVisibility, + NewChatMessage, + NewChatMessageRole, + NewChatThread, + SearchSpace, + TokenUsage, + User, +) +from app.services.token_tracking_service import TurnTokenAccumulator +from app.tasks.chat import persistence as persistence_module +from app.tasks.chat.persistence import ( + finalize_assistant_turn, + persist_assistant_shell, + persist_user_turn, +) + +pytestmark = pytest.mark.integration + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def db_thread( + db_session: AsyncSession, db_user: User, db_search_space: SearchSpace +) -> NewChatThread: + thread = NewChatThread( + title="Test Chat", + search_space_id=db_search_space.id, + created_by_id=db_user.id, + visibility=ChatVisibility.PRIVATE, + ) + db_session.add(thread) + await db_session.flush() + return thread + + +@pytest.fixture +def patched_shielded_session(monkeypatch, db_session: AsyncSession): + """Route persistence helpers to the test's savepoint-bound session. + + The persistence helpers use ``async with shielded_async_session() as + ws`` and call ``ws.commit()`` internally. 
Inside the conftest's + ``join_transaction_mode="create_savepoint"`` setup, those commits + release a SAVEPOINT instead of committing the outer transaction — + so the test session can see helper-staged rows immediately and the + outer rollback at end of test wipes them. + """ + + @asynccontextmanager + async def _fake_shielded_session(): + yield db_session + # Do NOT close — the outer fixture owns the session lifecycle. + + monkeypatch.setattr( + persistence_module, + "shielded_async_session", + _fake_shielded_session, + ) + return db_session + + +def _accumulator_with_one_call() -> TurnTokenAccumulator: + acc = TurnTokenAccumulator() + acc.add( + model="gpt-4o-mini", + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + cost_micros=12345, + ) + return acc + + +async def _count_assistant_rows( + session: AsyncSession, thread_id: int, turn_id: str +) -> int: + result = await session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.ASSISTANT, + ) + ) + return int(result.scalar_one()) + + +async def _count_token_usage_rows(session: AsyncSession, message_id: int) -> int: + result = await session.execute( + select(func.count()) + .select_from(TokenUsage) + .where(TokenUsage.message_id == message_id) + ) + return int(result.scalar_one()) + + +# --------------------------------------------------------------------------- +# persist_assistant_shell +# --------------------------------------------------------------------------- + + +class TestPersistAssistantShell: + async def test_first_call_inserts_empty_shell_and_returns_id( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + # Capture primitive ids before any persistence helper runs: + # the helpers commit/rollback the shared test session, which + # can detach ORM rows mid-test. 
+ thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:1000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None and isinstance(msg_id, int) + + row = await db_session.get(NewChatMessage, msg_id) + assert row is not None + assert row.thread_id == thread_id + assert row.role == NewChatMessageRole.ASSISTANT + assert row.turn_id == turn_id + # Empty shell payload — finalize_assistant_turn overwrites later. + assert row.content == [{"type": "text", "text": ""}] + + async def test_second_call_with_same_turn_id_returns_same_id( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + # Capture primitive ids before any persistence helper runs: + # the helpers commit/rollback the shared test session, which + # can detach ORM rows mid-test. + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:2000" + + first_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + second_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + + assert first_id is not None + assert first_id == second_id + # Exactly one row in the DB for this turn. 
+ assert await _count_assistant_rows(db_session, thread_id, turn_id) == 1 + + async def test_missing_turn_id_returns_none( + self, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id="", + ) + assert msg_id is None + + async def test_after_persist_user_turn_resolves_assistant_id( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:3000" + + # The streaming layer always calls persist_user_turn first, so + # smoke-test the canonical sequence. + user_msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + ) + assert isinstance(user_msg_id, int) + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + # User row + assistant shell row = 2 rows for this turn. + result = await db_session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + ) + ) + assert result.scalar_one() == 2 + + async def test_double_call_with_same_turn_id_uses_on_conflict( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """Verifies the ON CONFLICT DO NOTHING path on the assistant + shell does not raise ``IntegrityError`` even when the second + writer races the first within a tight loop. ``test_second_call_with_same_turn_id_returns_same_id`` + already covers the same-id semantics; this test additionally + asserts neither call raises so the debugger's + ``raise-on-IntegrityError`` setting won't pause the streaming + path under contention. 
+ """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:3500" + + # Both calls go through ``pg_insert(...).on_conflict_do_nothing``; + # the second one returns RETURNING=∅ and falls into the SELECT + # branch. Neither path raises. + first_id = await persist_assistant_shell( + chat_id=thread_id, user_id=user_id_str, turn_id=turn_id + ) + second_id = await persist_assistant_shell( + chat_id=thread_id, user_id=user_id_str, turn_id=turn_id + ) + assert first_id is not None + assert first_id == second_id + + +# --------------------------------------------------------------------------- +# persist_user_turn +# --------------------------------------------------------------------------- + + +class TestPersistUserTurn: + async def test_returns_message_id_on_first_insert( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:8000" + + msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + ) + assert isinstance(msg_id, int) and msg_id > 0 + + row = await db_session.get(NewChatMessage, msg_id) + assert row is not None + assert row.thread_id == thread_id + assert row.role == NewChatMessageRole.USER + assert row.turn_id == turn_id + assert row.content == [{"type": "text", "text": "hello"}] + + async def test_returns_existing_id_on_conflict( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:8100" + + first_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + ) + # Second call simulates a legacy FE ``appendMessage`` racing the + # SSE stream: ON CONFLICT DO NOTHING short-circuits at the DB + # level, the helper recovers the existing id via SELECT, and + # crucially does NOT raise ``IntegrityError`` 
(the debugger + # would otherwise pause on it). + second_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="ignored on conflict", + ) + assert first_id is not None + assert first_id == second_id + + # Exactly one user row for this turn. + count = await db_session.execute( + select(func.count()) + .select_from(NewChatMessage) + .where( + NewChatMessage.thread_id == thread_id, + NewChatMessage.turn_id == turn_id, + NewChatMessage.role == NewChatMessageRole.USER, + ) + ) + assert count.scalar_one() == 1 + + async def test_embeds_mentioned_documents_part( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """The full ``{id, title, document_type}`` triple forwarded by + the FE must round-trip into a single ``mentioned-documents`` + ContentPart on the persisted user message — the history loader + renders the chips on reload from this part directly. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id = f"{thread_id}:8200" + + mentioned = [ + {"id": 11, "title": "Alpha", "document_type": "GENERAL"}, + {"id": 22, "title": "Beta", "document_type": "GENERAL"}, + ] + msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + user_query="hello", + mentioned_documents=mentioned, + ) + assert isinstance(msg_id, int) + + row = await db_session.get(NewChatMessage, msg_id) + assert row is not None + # Content is a 2-part list: text + mentioned-documents. 
+ assert isinstance(row.content, list) + assert row.content[0] == {"type": "text", "text": "hello"} + assert row.content[1] == { + "type": "mentioned-documents", + "documents": [ + {"id": 11, "title": "Alpha", "document_type": "GENERAL"}, + {"id": 22, "title": "Beta", "document_type": "GENERAL"}, + ], + } + + async def test_skips_mentioned_documents_when_empty_or_invalid( + self, + db_session, + db_user, + db_thread, + patched_shielded_session, + ): + """Empty list and entries missing required fields are dropped; + a ``mentioned-documents`` part is only emitted when at least + one normalised entry survived. + """ + thread_id = db_thread.id + user_id_str = str(db_user.id) + turn_id_empty = f"{thread_id}:8300" + turn_id_invalid = f"{thread_id}:8301" + + msg_id_empty = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id_empty, + user_query="hi", + mentioned_documents=[], + ) + assert isinstance(msg_id_empty, int) + row_empty = await db_session.get(NewChatMessage, msg_id_empty) + assert row_empty is not None + assert row_empty.content == [{"type": "text", "text": "hi"}] + + # Each entry missing one required field — all skipped. 
+ msg_id_invalid = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id_invalid, + user_query="hi", + mentioned_documents=[ + {"title": "no id", "document_type": "GENERAL"}, # missing id + {"id": 99, "document_type": "GENERAL"}, # missing title + {"id": 100, "title": "no type"}, # missing document_type + ], + ) + assert isinstance(msg_id_invalid, int) + row_invalid = await db_session.get(NewChatMessage, msg_id_invalid) + assert row_invalid is not None + assert row_invalid.content == [{"type": "text", "text": "hi"}] + + async def test_missing_turn_id_returns_none( + self, + db_user, + db_thread, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + + msg_id = await persist_user_turn( + chat_id=thread_id, + user_id=user_id_str, + turn_id="", + user_query="hello", + ) + assert msg_id is None + + +# --------------------------------------------------------------------------- +# finalize_assistant_turn +# --------------------------------------------------------------------------- + + +class TestFinalizeAssistantTurn: + async def test_writes_content_and_token_usage( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_uuid = db_user.id + user_id_str = str(user_id_uuid) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:4000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + rich_content = [ + {"type": "text", "text": "Hello world"}, + { + "type": "tool-call", + "toolCallId": "call_x", + "toolName": "ls", + "args": {"path": "/"}, + "argsText": '{\n "path": "/"\n}', + "result": {"files": []}, + "langchainToolCallId": "lc_x", + }, + ] + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=rich_content, + 
accumulator=_accumulator_with_one_call(), + ) + + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + assert row.content == rich_content + + # Exactly one token_usage row keyed on this message_id. + usage_rows = ( + ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ) + .scalars() + .all() + ) + assert len(usage_rows) == 1 + usage = usage_rows[0] + assert usage.usage_type == "chat" + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 + assert usage.cost_micros == 12345 + assert usage.thread_id == thread_id + assert usage.search_space_id == search_space_id + + async def test_empty_content_writes_status_marker( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:5000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + # Pure tool-call turn that aborted before any output, or + # interrupt before any event arrived — empty list. 
+ await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=[], + accumulator=None, + ) + + row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + assert row.content == [{"type": "status", "text": "(no text response)"}] + + async def test_double_call_safe_via_on_conflict( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:6000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + first_acc = _accumulator_with_one_call() + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=[{"type": "text", "text": "first finalize"}], + accumulator=first_acc, + ) + + # Simulate a follow-up finalize (e.g., resume retry within the + # shielded finally block firing twice). Different content, but + # ON CONFLICT DO NOTHING on token_usage means the cost from the + # first finalize stays authoritative. + second_acc = TurnTokenAccumulator() + second_acc.add( + model="gpt-4o-mini", + prompt_tokens=999, + completion_tokens=999, + total_tokens=1998, + cost_micros=99999, + ) + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=[{"type": "text", "text": "second finalize"}], + accumulator=second_acc, + ) + + # Content was overwritten by the second UPDATE. 
+ row = await db_session.get(NewChatMessage, msg_id) + await db_session.refresh(row) + assert row.content == [{"type": "text", "text": "second finalize"}] + + # But token_usage stayed at exactly one row, preserving the + # first finalize's authoritative cost. + assert await _count_token_usage_rows(db_session, msg_id) == 1 + usage = ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage.cost_micros == 12345 # First finalize's value + + async def test_append_message_style_insert_after_finalize_no_dupe( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + """Cross-writer race: ``append_message`` arrives after ``finalize_assistant_turn``. + + Both target the same ``message_id``; the partial unique index + ``uq_token_usage_message_id`` (migration 142) makes the second + insert a no-op via ``ON CONFLICT DO NOTHING``. + """ + from sqlalchemy import text as sa_text + + thread_id = db_thread.id + user_uuid = db_user.id + user_id_str = str(user_uuid) + search_space_id = db_search_space.id + turn_id = f"{thread_id}:7000" + + msg_id = await persist_assistant_shell( + chat_id=thread_id, + user_id=user_id_str, + turn_id=turn_id, + ) + assert msg_id is not None + + await finalize_assistant_turn( + message_id=msg_id, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id=turn_id, + content=[{"type": "text", "text": "from server"}], + accumulator=_accumulator_with_one_call(), + ) + + # Now simulate the FE's append_message branch firing AFTER — + # the same INSERT ... ON CONFLICT DO NOTHING shape used by the + # route handler, keyed on the migration-142 partial unique + # index. 
+ late_insert = ( + pg_insert(TokenUsage) + .values( + usage_type="chat", + prompt_tokens=42, + completion_tokens=42, + total_tokens=84, + cost_micros=1, + model_breakdown=None, + call_details=None, + thread_id=thread_id, + message_id=msg_id, + search_space_id=search_space_id, + user_id=user_uuid, + ) + .on_conflict_do_nothing( + index_elements=["message_id"], + index_where=sa_text("message_id IS NOT NULL"), + ) + ) + await db_session.execute(late_insert) + await db_session.flush() + + # Still exactly one row, with the original (server) cost value. + assert await _count_token_usage_rows(db_session, msg_id) == 1 + usage = ( + await db_session.execute( + select(TokenUsage).where(TokenUsage.message_id == msg_id) + ) + ).scalar_one() + assert usage.cost_micros == 12345 + + async def test_helper_never_raises_on_missing_message_id( + self, + db_session, + db_user, + db_thread, + db_search_space, + patched_shielded_session, + ): + thread_id = db_thread.id + user_id_str = str(db_user.id) + search_space_id = db_search_space.id + + # message_id that doesn't exist — finalize must log+return, + # never raise (called from shielded finally). + await finalize_assistant_turn( + message_id=999_999_999, + chat_id=thread_id, + search_space_id=search_space_id, + user_id=user_id_str, + turn_id="anything", + content=[{"type": "text", "text": "x"}], + accumulator=_accumulator_with_one_call(), + ) + # If we got here without an exception, the test passes. + # Sanity: no token_usage row created (FK to message would have + # been rejected anyway, but ON CONFLICT path may swallow + # FK errors as well; check directly). 
+ assert await _count_token_usage_rows(db_session, 999_999_999) == 0 diff --git a/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py new file mode 100644 index 000000000..c317eba20 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_content_builder.py @@ -0,0 +1,526 @@ +"""Unit tests for ``AssistantContentBuilder``. + +Pins the in-memory ``ContentPart[]`` projection so the JSONB the server +persists matches what the frontend renders live (see +``surfsense_web/lib/chat/streaming-state.ts``). Every test asserts both +the structural shape of ``snapshot()`` and that the snapshot is +``json.dumps``-safe (the streaming finally block writes it directly to +``new_chat_messages.content`` without an explicit serialization round +trip). +""" + +from __future__ import annotations + +import json + +import pytest + +from app.tasks.chat.content_builder import AssistantContentBuilder + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _assert_jsonb_safe(parts: list[dict]) -> None: + """Sanity check: any snapshot must round-trip through ``json.dumps``.""" + serialized = json.dumps(parts) + assert json.loads(serialized) == parts + + +# --------------------------------------------------------------------------- +# Text turns +# --------------------------------------------------------------------------- + + +class TestTextOnly: + def test_single_text_block_collapses_consecutive_deltas(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "Hello") + b.on_text_delta("text-1", " ") + b.on_text_delta("text-1", "world") + b.on_text_end("text-1") + + snap = b.snapshot() + assert snap == [{"type": "text", "text": "Hello world"}] + assert not b.is_empty() + _assert_jsonb_safe(snap) + + def 
test_empty_text_start_end_pair_leaves_no_part(self): + # Mirrors the FE: a text-start without any deltas should + # not materialise an empty ``{"type":"text","text":""}`` part. + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_end("text-1") + + assert b.snapshot() == [] + assert b.is_empty() + + def test_text_after_text_end_starts_fresh_part(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "first") + b.on_text_end("text-1") + + b.on_text_start("text-2") + b.on_text_delta("text-2", "second") + b.on_text_end("text-2") + + snap = b.snapshot() + assert snap == [ + {"type": "text", "text": "first"}, + {"type": "text", "text": "second"}, + ] + + +class TestReasoningThenText: + def test_reasoning_followed_by_text_yields_two_parts_in_order(self): + b = AssistantContentBuilder() + b.on_reasoning_start("r-1") + b.on_reasoning_delta("r-1", "Considering options...") + b.on_reasoning_end("r-1") + + b.on_text_start("text-1") + b.on_text_delta("text-1", "The answer is 42.") + b.on_text_end("text-1") + + snap = b.snapshot() + assert snap == [ + {"type": "reasoning", "text": "Considering options..."}, + {"type": "text", "text": "The answer is 42."}, + ] + _assert_jsonb_safe(snap) + + def test_text_delta_after_reasoning_implicitly_closes_reasoning(self): + # Mirrors FE ``appendText``: a text delta arriving while a + # reasoning part is "active" still produces a fresh text + # part, never appends into the reasoning block. + b = AssistantContentBuilder() + b.on_reasoning_start("r-1") + b.on_reasoning_delta("r-1", "thinking") + # No explicit reasoning_end — text delta should close it. 
+ b.on_text_delta("text-1", "answer") + + snap = b.snapshot() + assert snap == [ + {"type": "reasoning", "text": "thinking"}, + {"type": "text", "text": "answer"}, + ] + + +# --------------------------------------------------------------------------- +# Tool calls +# --------------------------------------------------------------------------- + + +class TestToolHeavyTurn: + def test_full_tool_lifecycle_produces_complete_tool_call_part(self): + b = AssistantContentBuilder() + # Some narration before the tool fires. + b.on_text_start("text-1") + b.on_text_delta("text-1", "Searching...") + b.on_text_end("text-1") + + b.on_tool_input_start( + ui_id="call_run123", + tool_name="web_search", + langchain_tool_call_id="lc_tool_abc", + ) + b.on_tool_input_delta("call_run123", '{"query":') + b.on_tool_input_delta("call_run123", '"surfsense"}') + b.on_tool_input_available( + ui_id="call_run123", + tool_name="web_search", + args={"query": "surfsense"}, + langchain_tool_call_id="lc_tool_abc", + ) + b.on_tool_output_available( + ui_id="call_run123", + output={"status": "completed", "citations": {}}, + langchain_tool_call_id="lc_tool_abc", + ) + + snap = b.snapshot() + assert snap[0] == {"type": "text", "text": "Searching..."} + tool_part = snap[1] + assert tool_part["type"] == "tool-call" + assert tool_part["toolCallId"] == "call_run123" + assert tool_part["toolName"] == "web_search" + assert tool_part["args"] == {"query": "surfsense"} + # ``argsText`` is the pretty-printed final JSON, not the raw + # streaming buffer (FE ``stream-pipeline.ts:128``). 
+ assert tool_part["argsText"] == json.dumps( + {"query": "surfsense"}, indent=2, ensure_ascii=False + ) + assert tool_part["langchainToolCallId"] == "lc_tool_abc" + assert tool_part["result"] == {"status": "completed", "citations": {}} + _assert_jsonb_safe(snap) + + def test_tool_input_available_without_prior_start_creates_card(self): + # Legacy / parity_v2-OFF path: tool-input-available may be + # emitted without a prior tool-input-start (no streamed + # tool_call_chunks). The card should still be created. + b = AssistantContentBuilder() + b.on_tool_input_available( + ui_id="call_run42", + tool_name="grep", + args={"pattern": "TODO"}, + langchain_tool_call_id="lc_x", + ) + b.on_tool_output_available( + ui_id="call_run42", + output={"matches": 3}, + langchain_tool_call_id="lc_x", + ) + + snap = b.snapshot() + assert len(snap) == 1 + part = snap[0] + assert part["type"] == "tool-call" + assert part["toolCallId"] == "call_run42" + assert part["args"] == {"pattern": "TODO"} + assert part["langchainToolCallId"] == "lc_x" + assert part["result"] == {"matches": 3} + + def test_tool_input_start_idempotent_for_same_ui_id(self): + # parity_v2: tool-input-start can fire from BOTH the chunk + # registration path AND the canonical ``on_tool_start`` path. + # The second call must not create a duplicate part. + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "ls", "lc_x") + b.on_tool_input_start("call_x", "ls", "lc_x") + snap = b.snapshot() + assert len(snap) == 1 + + def test_tool_input_delta_without_prior_start_is_silently_dropped(self): + b = AssistantContentBuilder() + b.on_tool_input_delta("call_unknown", '{"orphan": "delta"}') + assert b.snapshot() == [] + + def test_langchain_tool_call_id_backfills_only_when_absent(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "ls", "lc_first") + # Late event must NOT clobber an already-set lc id. 
+ b.on_tool_input_start("call_x", "ls", "lc_late") + snap = b.snapshot() + assert snap[0]["langchainToolCallId"] == "lc_first" + + def test_args_text_streaming_buffer_reflects_concatenation(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "save_doc", "lc_y") + b.on_tool_input_delta("call_x", '{"title":') + b.on_tool_input_delta("call_x", '"Hi"}') + # Snapshot mid-stream should see the partial buffer (the FE + # tolerates invalid JSON and renders it as-is). + mid = b.snapshot() + assert mid[0]["argsText"] == '{"title":"Hi"}' + # Then tool-input-available replaces with pretty-printed. + b.on_tool_input_available( + "call_x", + "save_doc", + {"title": "Hi"}, + "lc_y", + ) + final = b.snapshot() + assert final[0]["argsText"] == json.dumps( + {"title": "Hi"}, indent=2, ensure_ascii=False + ) + + +# --------------------------------------------------------------------------- +# Thinking steps & separators +# --------------------------------------------------------------------------- + + +class TestThinkingSteps: + def test_first_thinking_step_unshifts_singleton_to_index_zero(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "Hello") + b.on_text_end("text-1") + + b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item-a"]) + + snap = b.snapshot() + # Singleton goes to index 0 (FE ``updateThinkingSteps`` unshift). 
+ assert snap[0]["type"] == "data-thinking-steps" + assert snap[0]["data"]["steps"] == [ + { + "id": "step-1", + "title": "Analyzing", + "status": "in_progress", + "items": ["item-a"], + } + ] + assert snap[1] == {"type": "text", "text": "Hello"} + + def test_subsequent_thinking_steps_mutate_the_singleton_in_place(self): + b = AssistantContentBuilder() + b.on_thinking_step("step-1", "Analyzing", "in_progress", []) + b.on_thinking_step("step-2", "Searching", "in_progress", ["q"]) + b.on_thinking_step("step-1", "Analyzing", "completed", ["done"]) + + snap = b.snapshot() + assert len([p for p in snap if p["type"] == "data-thinking-steps"]) == 1 + steps = snap[0]["data"]["steps"] + assert len(steps) == 2 + assert steps[0]["id"] == "step-1" + assert steps[0]["status"] == "completed" + assert steps[0]["items"] == ["done"] + assert steps[1]["id"] == "step-2" + + def test_thinking_step_with_text_continues_appending_to_text(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "first") + + # Thinking step inserts at index 0, bumps text idx from 0 to 1. 
+ b.on_thinking_step("step-1", "Working", "in_progress", []) + b.on_text_delta("text-1", " second") + + snap = b.snapshot() + text_parts = [p for p in snap if p["type"] == "text"] + assert text_parts == [{"type": "text", "text": "first second"}] + + def test_thinking_step_without_id_is_dropped(self): + b = AssistantContentBuilder() + b.on_thinking_step("", "noop", "in_progress", None) + assert b.snapshot() == [] + assert b.is_empty() + + +class TestStepSeparators: + def test_separator_no_op_before_any_content(self): + b = AssistantContentBuilder() + b.on_step_separator() + assert b.snapshot() == [] + + def test_separator_after_text_appends_with_step_index_zero(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "first") + b.on_text_end("text-1") + + b.on_step_separator() + + snap = b.snapshot() + assert snap[-1] == { + "type": "data-step-separator", + "data": {"stepIndex": 0}, + } + + def test_consecutive_separators_collapse_to_one(self): + b = AssistantContentBuilder() + b.on_text_delta("text-1", "x") + b.on_step_separator() + b.on_step_separator() # No-op: previous part is already a separator. 
+ snap = b.snapshot() + assert sum(1 for p in snap if p["type"] == "data-step-separator") == 1 + + def test_step_index_increments_across_separators(self): + b = AssistantContentBuilder() + b.on_text_delta("text-1", "a") + b.on_step_separator() + b.on_text_delta("text-2", "b") + b.on_step_separator() + snap = b.snapshot() + seps = [p for p in snap if p["type"] == "data-step-separator"] + assert [s["data"]["stepIndex"] for s in seps] == [0, 1] + + +# --------------------------------------------------------------------------- +# Interruption handling +# --------------------------------------------------------------------------- + + +class TestMarkInterrupted: + def test_running_tool_calls_get_state_aborted(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_a", "ls", "lc_a") + b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a") + # No tool-output-available — simulates client disconnect mid-tool. + + b.mark_interrupted() + + snap = b.snapshot() + assert snap[0]["state"] == "aborted" + assert "result" not in snap[0] + + def test_completed_tool_calls_are_not_marked_aborted(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_a", "ls", "lc_a") + b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a") + b.on_tool_output_available("call_a", {"files": []}, "lc_a") + + b.mark_interrupted() + + snap = b.snapshot() + assert "state" not in snap[0] + assert snap[0]["result"] == {"files": []} + + def test_open_text_block_keeps_accumulated_content(self): + b = AssistantContentBuilder() + b.on_text_start("text-1") + b.on_text_delta("text-1", "partial") + # No on_text_end — disconnect mid-stream. 
+ + b.mark_interrupted() + + snap = b.snapshot() + assert snap == [{"type": "text", "text": "partial"}] + + +# --------------------------------------------------------------------------- +# is_empty / snapshot semantics +# --------------------------------------------------------------------------- + + +class TestIsEmpty: + def test_fresh_builder_is_empty(self): + assert AssistantContentBuilder().is_empty() + + def test_text_part_breaks_emptiness(self): + b = AssistantContentBuilder() + b.on_text_delta("text-1", "x") + assert not b.is_empty() + + def test_tool_call_breaks_emptiness(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "ls", None) + assert not b.is_empty() + + def test_thinking_step_alone_does_not_break_emptiness(self): + # Mirrors the "status marker fallback" semantic: a turn that + # only emitted a thinking step before being interrupted should + # still be treated as empty for finalize_assistant_turn's + # status-marker substitution. + b = AssistantContentBuilder() + b.on_thinking_step("step-1", "Working", "in_progress", []) + assert b.is_empty() + + def test_step_separator_alone_does_not_break_emptiness(self): + b = AssistantContentBuilder() + # Force a separator (it would normally no-op without content, + # but we simulate the underlying state to verify is_empty is + # not fooled by a stray separator). + b.parts.append({"type": "data-step-separator", "data": {"stepIndex": 0}}) + assert b.is_empty() + + +class TestSnapshotSemantics: + def test_snapshot_is_deep_copied_so_mutations_do_not_leak(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_x", "ls", "lc_x") + b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x") + snap = b.snapshot() + # Mutate the returned snapshot — original should be untouched. 
+ snap[0]["args"]["mutated"] = True + snap[0]["state"] = "tampered" + + again = b.snapshot() + assert "mutated" not in again[0]["args"] + assert "state" not in again[0] + + def test_snapshot_round_trips_through_json(self): + b = AssistantContentBuilder() + b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item"]) + b.on_text_delta("text-1", "answer") + b.on_tool_input_start("call_x", "ls", "lc_x") + b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x") + b.on_tool_output_available("call_x", {"files": ["a.txt"]}, "lc_x") + b.on_step_separator() + snap = b.snapshot() + + encoded = json.dumps(snap) + assert json.loads(encoded) == snap + + +class TestStats: + """``stats()`` is the perf-log handle for [PERF] [stream_*] + finalize_payload lines. Pin the schema so an ops dashboard can + rely on these keys being present and meaningful. + """ + + def test_fresh_builder_reports_all_zeros(self): + b = AssistantContentBuilder() + s = b.stats() + assert s == { + "parts": 0, + "bytes": 2, # ``[]`` is two bytes + "text": 0, + "reasoning": 0, + "tool_calls": 0, + "tool_calls_completed": 0, + "tool_calls_aborted": 0, + "thinking_step_parts": 0, + "step_separators": 0, + } + + def test_counts_each_part_type_independently(self): + b = AssistantContentBuilder() + b.on_text_start("t1") + b.on_text_delta("t1", "hi") + b.on_text_end("t1") + b.on_reasoning_start("r1") + b.on_reasoning_delta("r1", "thinking") + b.on_reasoning_end("r1") + b.on_thinking_step("step-1", "Analyzing", "completed", ["item"]) + b.on_step_separator() + b.on_tool_input_start("call_done", "ls", "lc_done") + b.on_tool_input_available("call_done", "ls", {}, "lc_done") + b.on_tool_output_available("call_done", {"ok": True}, "lc_done") + b.on_tool_input_start("call_running", "rm", "lc_running") + b.on_tool_input_available("call_running", "rm", {}, "lc_running") + + s = b.stats() + assert s["text"] == 1 + assert s["reasoning"] == 1 + assert s["tool_calls"] == 2 + assert s["tool_calls_completed"] == 1 + 
assert s["tool_calls_aborted"] == 0 + assert s["thinking_step_parts"] == 1 + assert s["step_separators"] == 1 + assert s["parts"] == sum( + [ + s["text"], + s["reasoning"], + s["tool_calls"], + s["thinking_step_parts"], + s["step_separators"], + ] + ) + assert s["bytes"] > 0 + + def test_mark_interrupted_flips_running_calls_to_aborted_in_stats(self): + b = AssistantContentBuilder() + b.on_tool_input_start("call_done", "ls", "lc_done") + b.on_tool_input_available("call_done", "ls", {}, "lc_done") + b.on_tool_output_available("call_done", {"ok": True}, "lc_done") + b.on_tool_input_start("call_running", "rm", "lc_running") + b.on_tool_input_available("call_running", "rm", {}, "lc_running") + + # Pre-interrupt: one completed, one still running (no result). + pre = b.stats() + assert pre["tool_calls_completed"] == 1 + assert pre["tool_calls_aborted"] == 0 + + b.mark_interrupted() + post = b.stats() + assert post["tool_calls_completed"] == 1 + assert post["tool_calls_aborted"] == 1 + assert post["tool_calls"] == 2 + + def test_bytes_reflects_jsonb_payload_size(self): + # Each text-delta adds bytes monotonically — useful for catching + # an unbounded delta buffer regression in the perf signal. + b = AssistantContentBuilder() + b.on_text_start("t1") + b.on_text_delta("t1", "x" * 10) + small = b.stats()["bytes"] + b.on_text_delta("t1", "x" * 1000) + large = b.stats()["bytes"] + assert large > small + 900 diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 64e4d5157..208204ca9 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -457,6 +457,9 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows() source = page_path.read_text(encoding="utf-8") # Each flow tracks accepted boundary and passes it into shared terminal handling. 
+ # The acceptance boundary is still meaningful post-refactor: it gates + # local-state cleanup (onPreAcceptFailure path) and lets the shared + # terminal handler distinguish pre-accept aborts from in-stream errors. assert "let newAccepted = false;" in source assert "let resumeAccepted = false;" in source assert "let regenerateAccepted = false;" in source @@ -464,12 +467,23 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows() assert "accepted: resumeAccepted," in source assert "accepted: regenerateAccepted," in source - # Pre-accept abort in resume/regenerate exits without persistence. - assert "if (!resumeAccepted) return;" in source - assert "if (!regenerateAccepted) return;" in source + # NOTE: The FE-side persistence guards previously asserted here + # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;", + # "if (newAccepted && !userPersisted) {") have been intentionally + # removed by the SSE-based message-id handshake refactor. Persistence + # is now server-authoritative: persist_user_turn / persist_assistant_shell + # run inside stream_new_chat / stream_resume_chat unconditionally and + # the FE consumes data-user-message-id / data-assistant-message-id + # SSE events to learn the canonical primary keys. There is therefore + # no FE call-site to guard, and the shared terminal handler relies + # purely on the `accepted` field above (forwarded to onAbort / + # onAcceptedStreamError) to drive UI cleanup. See + # tests/integration/chat/test_message_id_sse.py for the new + # cross-tier ID coherence guarantees. - # New flow persists only when accepted and not already persisted. - assert "if (newAccepted && !userPersisted) {" in source + # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent + # of the persistence refactor and must still exist on every + # start-stream fetch. 
assert "const fetchWithTurnCancellingRetry = useCallback(" in source assert "computeFallbackTurnCancellingRetryDelay" in source assert 'withMeta.errorCode === "TURN_CANCELLING"' in source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 4c8e4fe93..f2bae4167 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -82,6 +82,7 @@ import { mergeChatTurnIdIntoMessage, mergeEditedInterruptAction, readStreamedChatTurnId, + readStreamedMessageId, } from "@/lib/chat/stream-side-effects"; import { buildContentForPersistence, @@ -256,110 +257,17 @@ export default function NewChatPage() { [tokenUsageStore] ); - const persistUserTurn = useCallback( - async ({ - threadId, - userMsgId, - content, - mentionedDocs, - turnId, - logContext, - }: { - threadId: number | null; - userMsgId: string; - content: unknown; - mentionedDocs?: MentionedDocumentInfo[]; - turnId?: string | null; - logContext: string; - }) => { - if (!threadId) return null; - try { - const normalizedContent = Array.isArray(content) ? ([...content] as unknown[]) : [content]; - const hasMentionedDocumentsPart = normalizedContent.some( - (part) => MentionedDocumentsPartSchema.safeParse(part).success - ); - if (mentionedDocs && mentionedDocs.length > 0 && !hasMentionedDocumentsPart) { - normalizedContent.push({ - type: "mentioned-documents", - documents: mentionedDocs, - }); - } - - const savedUserMessage = await appendMessage(threadId, { - role: "user", - content: normalizedContent as AppendMessage["content"], - turn_id: turnId, - }); - const newUserMsgId = `msg-${savedUserMessage.id}`; - setMessages((prev) => - prev.map((m) => - m.id === userMsgId - ? 
mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id) - : m - ) - ); - if (mentionedDocs && mentionedDocs.length > 0) { - setMessageDocumentsMap((prev) => { - const { [userMsgId]: _, ...rest } = prev; - return { - ...rest, - [newUserMsgId]: mentionedDocs, - }; - }); - } - return newUserMsgId; - } catch (err) { - console.error(`Failed to persist ${logContext} user message:`, err); - return null; - } - }, - [setMessageDocumentsMap] - ); - - const persistAssistantTurn = useCallback( - async ({ - threadId, - assistantMsgId, - content, - tokenUsage, - turnId, - logContext, - onRemapped, - }: { - threadId: number | null; - assistantMsgId: string; - content: unknown; - tokenUsage?: TokenUsageData; - turnId?: string | null; - logContext: string; - onRemapped?: (newMsgId: string) => void; - }) => { - if (!threadId) return null; - try { - const savedMessage = await appendMessage(threadId, { - role: "assistant", - content: content as AppendMessage["content"], - token_usage: tokenUsage, - turn_id: turnId, - }); - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) - : m - ) - ); - onRemapped?.(newMsgId); - return newMsgId; - } catch (err) { - console.error(`Failed to persist ${logContext} assistant message:`, err); - return null; - } - }, - [tokenUsageStore] - ); + // NOTE: ``persistUserTurn`` / ``persistAssistantTurn`` callbacks + // were removed in the SSE-based message ID handshake refactor. + // ``stream_new_chat`` and ``stream_resume_chat`` now persist both + // the user and assistant rows server-side via + // ``persist_user_turn`` / ``persist_assistant_shell`` and emit + // ``data-user-message-id`` / ``data-assistant-message-id`` SSE + // events; the consumers below rename the optimistic ids in real + // time. 
``persistAssistantErrorMessage`` (above) is intentionally + // kept — it is the pre-stream-error fallback fired when the + // server NEVER accepted the request, and the BE has nothing to + // persist in that case. // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); @@ -891,8 +799,13 @@ export default function NewChatPage() { setPendingUserImageUrls((prev) => prev.filter((u) => !urlsSnapshot.includes(u))); } - // Add user message to state - const userMsgId = `msg-user-${Date.now()}`; + // Add user message to state. Mutable because the SSE + // ``data-user-message-id`` handler (below) renames this + // optimistic id to the canonical ``msg-{db_id}`` once the + // backend's ``persist_user_turn`` resolves the row, and + // the in-stream flush / interrupt closures need to see + // the post-rename value via this live ``let`` binding. + let userMsgId = `msg-user-${Date.now()}`; // Always include author metadata so the UI layer can decide visibility const authorMetadata = currentUser @@ -958,22 +871,16 @@ export default function NewChatPage() { })); } - const persistContent: unknown[] = [...userDisplayContent]; - - if (allMentionedDocs.length > 0) { - persistContent.push({ - type: "mentioned-documents", - documents: allMentionedDocs, - }); - } - // Start streaming response setIsRunning(true); const controller = new AbortController(); abortControllerRef.current = controller; - // Prepare assistant message - const assistantMsgId = `msg-assistant-${Date.now()}`; + // Prepare assistant message. Mutable for the same reason + // as ``userMsgId`` above — the ``data-assistant-message-id`` + // SSE handler reassigns this once + // ``persist_assistant_shell`` returns its canonical id. 
+ let assistantMsgId = `msg-assistant-${Date.now()}`; const currentThinkingSteps = new Map(); const contentPartsState: ContentPartsState = { contentParts: [], @@ -983,11 +890,7 @@ export default function NewChatPage() { }; const { contentParts } = contentPartsState; let wasInterrupted = false; - let tokenUsageData: TokenUsageData | null = null; let newAccepted = false; - let userPersisted = false; - // Captured from ``data-turn-info`` at stream start. - let streamedChatTurnId: string | null = null; let streamBatcher: FrameBatchedUpdater | null = null; try { @@ -1047,6 +950,18 @@ export default function NewChatPage() { mentioned_surfsense_doc_ids: hasSurfsenseDocIds ? mentionedDocumentIds.surfsense_doc_ids : undefined, + // Full mention metadata so the BE can embed a + // ``mentioned-documents`` ContentPart on the + // persisted user message (replaces the old FE-side + // injection in ``persistUserTurn``). + mentioned_documents: + allMentionedDocs.length > 0 + ? allMentionedDocs.map((d) => ({ + id: d.id, + title: d.title, + document_type: d.document_type, + })) + : undefined, disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, ...(userImages.length > 0 ? { user_images: userImages } : {}), }), @@ -1089,7 +1004,6 @@ export default function NewChatPage() { scheduleFlush, forceFlush, onTokenUsage: (data) => { - tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, onTurnStatus: (data) => { @@ -1189,7 +1103,6 @@ export default function NewChatPage() { case "data-turn-info": { const turnId = readStreamedChatTurnId(parsed.data); - streamedChatTurnId = turnId; if (turnId) { setMessages((prev) => applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) @@ -1197,46 +1110,96 @@ export default function NewChatPage() { } break; } + + case "data-user-message-id": { + // Server-authoritative user message id resolved by + // ``persist_user_turn`` (or recovered via ON CONFLICT). 
+ // Rename the optimistic ``msg-user-XXX`` placeholder to + // the canonical ``msg-{db_id}`` so DB-id-gated UI + // (comments, edit-from-this-message) unlocks immediately, + // migrate the local mentioned-documents map, and reassign + // the closure variable so all downstream + // ``m.id === userMsgId`` checks see the new value. + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newUserMsgId = `msg-${parsedMsg.messageId}`; + const oldUserMsgId = userMsgId; + setMessages((prev) => + prev.map((m) => + m.id === oldUserMsgId + ? mergeChatTurnIdIntoMessage( + { ...m, id: newUserMsgId }, + parsedMsg.turnId + ) + : m + ) + ); + if (allMentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => { + if (!(oldUserMsgId in prev)) { + return { ...prev, [newUserMsgId]: allMentionedDocs }; + } + const { [oldUserMsgId]: _removed, ...rest } = prev; + return { ...rest, [newUserMsgId]: allMentionedDocs }; + }); + } + userMsgId = newUserMsgId; + if (isNewThread) { + // First user-side row landed in ``new_chat_messages``; + // refresh the sidebar so the freshly-bumped + // ``thread.updated_at`` reorders this thread. + queryClient.invalidateQueries({ + queryKey: ["threads", String(searchSpaceId)], + }); + } + break; + } + + case "data-assistant-message-id": { + // Server-authoritative assistant message id resolved + // by ``persist_assistant_shell``. Rename the optimistic + // id, migrate ``tokenUsageStore`` so any pending + // ``data-token-usage`` payload binds to the new id, + // remap any in-flight ``pendingInterrupt`` reference, + // and reassign the closure variable so the in-stream + // flush callback (line ~1074) keeps writing to the + // renamed message. 
+ const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newAssistantMsgId = `msg-${parsedMsg.messageId}`; + const oldAssistantMsgId = assistantMsgId; + tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId); + setMessages((prev) => + prev.map((m) => + m.id === oldAssistantMsgId + ? mergeChatTurnIdIntoMessage( + { ...m, id: newAssistantMsgId }, + parsedMsg.turnId + ) + : m + ) + ); + setPendingInterrupt((prev) => + prev && prev.assistantMsgId === oldAssistantMsgId + ? { ...prev, assistantMsgId: newAssistantMsgId } + : prev + ); + assistantMsgId = newAssistantMsgId; + break; + } } }); batcher.flush(); - // Skip persistence for interrupted messages -- handleResume will persist the final version - const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); + // Server-authoritative persistence: ``stream_new_chat`` + // already wrote the user row in ``persist_user_turn`` + // (the FE renamed the optimistic id mid-stream via + // ``data-user-message-id``) and finalises the assistant + // row in ``finalize_assistant_turn`` from a shielded + // ``finally`` block. Nothing left for the FE to persist + // here — track the response and unblock the UI. if (contentParts.length > 0 && !wasInterrupted) { - if (!userPersisted) { - const persistedUserMsgId = await persistUserTurn({ - threadId: currentThreadId, - userMsgId, - content: persistContent, - mentionedDocs: allMentionedDocs, - turnId: streamedChatTurnId, - logContext: "new chat", - }); - userPersisted = Boolean(persistedUserMsgId); - if (userPersisted && isNewThread) { - queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); - } - } - - await persistAssistantTurn({ - threadId: currentThreadId, - assistantMsgId, - content: finalContent, - tokenUsage: tokenUsageData ?? 
undefined, - turnId: streamedChatTurnId, - logContext: "new chat", - onRemapped: (newMsgId) => { - setPendingInterrupt((prev) => - prev && prev.assistantMsgId === assistantMsgId - ? { ...prev, assistantMsgId: newMsgId } - : prev - ); - }, - }); - - // Track successful response trackChatResponseReceived(searchSpaceId, currentThreadId); } } catch (error) { @@ -1247,51 +1210,21 @@ export default function NewChatPage() { threadId: currentThreadId, assistantMsgId, accepted: newAccepted, - onAbort: async () => { - if (newAccepted && !userPersisted) { - const persistedUserMsgId = await persistUserTurn({ - threadId: currentThreadId, - userMsgId, - content: persistContent, - mentionedDocs: allMentionedDocs, - turnId: streamedChatTurnId, - logContext: "new chat (aborted)", - }); - userPersisted = Boolean(persistedUserMsgId); - if (userPersisted && isNewThread) { - queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); - } - } - - const hasContent = hasPersistableContent(contentParts, toolsWithUI); - if (hasContent && currentThreadId) { - const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - await persistAssistantTurn({ - threadId: currentThreadId, - assistantMsgId, - content: partialContent, - turnId: streamedChatTurnId, - logContext: "partial new chat", - }); - } - }, - onAcceptedStreamError: async () => { - if (!userPersisted) { - const persistedUserMsgId = await persistUserTurn({ - threadId: currentThreadId, - userMsgId, - content: persistContent, - mentionedDocs: allMentionedDocs, - turnId: streamedChatTurnId, - logContext: "new chat (stream error)", - }); - userPersisted = Boolean(persistedUserMsgId); - if (userPersisted && isNewThread) { - queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); - } - } - }, + // Server-side ``finalize_assistant_turn`` runs from a + // shielded ``anyio.CancelScope(shield=True)`` finally + // block, so partial content (incl. 
abort-mid-stream) + // is already persisted by the BE for the assistant + // row, and ``persist_user_turn`` ran before any LLM + // call. The FE's only remaining responsibility on + // abort / accepted-stream-error is to surface the + // error toast (handled by ``handleStreamTerminalError`` + // itself). onPreAcceptFailure: async () => { + // Pre-accept failure means the BE never accepted the + // request — no server-side persistence ran. Roll + // back the optimistic UI insertions we made before + // the fetch so the user message and any local + // mentioned-docs metadata don't linger. setMessages((prev) => prev.filter((m) => m.id !== userMsgId)); setMessageDocumentsMap((prev) => { if (!(userMsgId in prev)) return prev; @@ -1325,8 +1258,6 @@ export default function NewChatPage() { fetchWithTurnCancellingRetry, handleStreamTerminalError, handleChatFailure, - persistAssistantTurn, - persistUserTurn, ] ); @@ -1339,7 +1270,12 @@ export default function NewChatPage() { }> ) => { if (!pendingInterrupt) return; - const { threadId: resumeThreadId, assistantMsgId } = pendingInterrupt; + const { threadId: resumeThreadId } = pendingInterrupt; + // Destructured separately as ``let`` so the SSE + // ``data-assistant-message-id`` handler (resume always + // allocates a fresh server-side row) can rename it to + // the canonical ``msg-{db_id}`` mid-stream. + let assistantMsgId = pendingInterrupt.assistantMsgId; setPendingInterrupt(null); setIsRunning(true); @@ -1362,10 +1298,7 @@ export default function NewChatPage() { toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; - let tokenUsageData: TokenUsageData | null = null; let resumeAccepted = false; - // Captured from ``data-turn-info`` at stream start. 
- let streamedChatTurnId: string | null = null; let streamBatcher: FrameBatchedUpdater | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); @@ -1466,7 +1399,6 @@ export default function NewChatPage() { scheduleFlush, forceFlush, onTokenUsage: (data) => { - tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, onTurnStatus: (data) => { @@ -1514,7 +1446,6 @@ export default function NewChatPage() { case "data-turn-info": { const turnId = readStreamedChatTurnId(parsed.data); - streamedChatTurnId = turnId; if (turnId) { setMessages((prev) => applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) @@ -1522,22 +1453,44 @@ export default function NewChatPage() { } break; } + + case "data-assistant-message-id": { + // Resume always allocates a fresh ``new_chat_messages`` + // row anchored to a new ``turn_id`` (the original + // interrupted turn's row stays as-is), so this is a + // real id swap. Rename the optimistic placeholder to + // ``msg-{db_id}`` and reassign closure state. Resume + // does NOT emit ``data-user-message-id`` — the user + // row belongs to the original interrupted turn. + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newAssistantMsgId = `msg-${parsedMsg.messageId}`; + const oldAssistantMsgId = assistantMsgId; + tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId); + setMessages((prev) => + prev.map((m) => + m.id === oldAssistantMsgId + ? mergeChatTurnIdIntoMessage( + { ...m, id: newAssistantMsgId }, + parsedMsg.turnId + ) + : m + ) + ); + assistantMsgId = newAssistantMsgId; + break; + } } }); batcher.flush(); - const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); - if (contentParts.length > 0) { - await persistAssistantTurn({ - threadId: resumeThreadId, - assistantMsgId, - content: finalContent, - tokenUsage: tokenUsageData ?? 
undefined, - turnId: streamedChatTurnId, - logContext: "resumed chat", - }); - } + // Server-authoritative persistence: ``stream_resume_chat`` + // finalises the assistant row in + // ``finalize_assistant_turn`` from a shielded + // ``finally`` block (covers both happy-path and + // abort-mid-stream). FE has no remaining persistence + // work here. } catch (error) { streamBatcher?.dispose(); await handleStreamTerminalError({ @@ -1546,19 +1499,6 @@ export default function NewChatPage() { threadId: resumeThreadId, assistantMsgId, accepted: resumeAccepted, - onAbort: async () => { - if (!resumeAccepted) return; - const hasContent = hasPersistableContent(contentParts, toolsWithUI); - if (!hasContent) return; - const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - await persistAssistantTurn({ - threadId: resumeThreadId, - assistantMsgId, - content: partialContent, - turnId: streamedChatTurnId, - logContext: "partial resumed chat", - }); - }, }); } finally { setIsRunning(false); @@ -1574,7 +1514,6 @@ export default function NewChatPage() { tokenUsageStore, fetchWithTurnCancellingRetry, handleStreamTerminalError, - persistAssistantTurn, ] ); @@ -1715,9 +1654,12 @@ export default function NewChatPage() { const controller = new AbortController(); abortControllerRef.current = controller; - // Add placeholder user message if we have a new query (edit mode) - const userMsgId = `msg-user-${Date.now()}`; - const assistantMsgId = `msg-assistant-${Date.now()}`; + // Add placeholder user message if we have a new query (edit mode). + // Mutable for the same reason as in ``onNew`` — both ids are + // renamed mid-stream by the new ``data-user-message-id`` / + // ``data-assistant-message-id`` SSE handlers below. 
+ let userMsgId = `msg-user-${Date.now()}`; + let assistantMsgId = `msg-assistant-${Date.now()}`; const currentThinkingSteps = new Map(); const contentPartsState: ContentPartsState = { @@ -1727,13 +1669,7 @@ export default function NewChatPage() { toolCallIndices: new Map(), }; const { contentParts } = contentPartsState; - let tokenUsageData: TokenUsageData | null = null; let regenerateAccepted = false; - let userPersisted = false; - // Captured from ``data-turn-info`` at stream start; stamped - // onto persisted messages so future edits can locate the - // right LangGraph checkpoint. - let streamedChatTurnId: string | null = null; let streamBatcher: FrameBatchedUpdater | null = null; // Add placeholder messages to UI @@ -1747,9 +1683,6 @@ export default function NewChatPage() { createdAt: new Date(), metadata: isEdit ? undefined : originalUserMessageMetadata, }; - const userContentToPersist = isEdit - ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) - : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; const sourceMentionedDocs = sourceUserMessageId && messageDocumentsMap[sourceUserMessageId] ? messageDocumentsMap[sourceUserMessageId] @@ -1765,6 +1698,18 @@ export default function NewChatPage() { filesystem_mode: selection.filesystem_mode, client_platform: selection.client_platform, local_filesystem_mounts: selection.local_filesystem_mounts, + // Full mention metadata for the regenerate-specific + // source list. Only meaningful for edit (the BE only + // re-persists a user row when ``user_query`` is set); + // reload reuses the original turn's mentioned_documents. + mentioned_documents: + sourceMentionedDocs.length > 0 + ? sourceMentionedDocs.map((d) => ({ + id: d.id, + title: d.title, + document_type: d.document_type, + })) + : undefined, }; if (isEdit) { requestBody.user_images = editExtras?.userImages ?? 
[]; @@ -1852,7 +1797,6 @@ export default function NewChatPage() { scheduleFlush, forceFlush, onTokenUsage: (data) => { - tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, onTurnStatus: (data) => { @@ -1897,7 +1841,6 @@ export default function NewChatPage() { case "data-turn-info": { const turnId = readStreamedChatTurnId(parsed.data); - streamedChatTurnId = turnId; if (turnId) { setMessages((prev) => applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) @@ -1906,6 +1849,57 @@ export default function NewChatPage() { break; } + case "data-user-message-id": { + // Same role as in ``onNew`` but the regenerate-specific + // mention metadata (``sourceMentionedDocs``) is the + // list to migrate onto the canonical id key. + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newUserMsgId = `msg-${parsedMsg.messageId}`; + const oldUserMsgId = userMsgId; + setMessages((prev) => + prev.map((m) => + m.id === oldUserMsgId + ? mergeChatTurnIdIntoMessage( + { ...m, id: newUserMsgId }, + parsedMsg.turnId + ) + : m + ) + ); + if (sourceMentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => { + if (!(oldUserMsgId in prev)) { + return { ...prev, [newUserMsgId]: sourceMentionedDocs }; + } + const { [oldUserMsgId]: _removed, ...rest } = prev; + return { ...rest, [newUserMsgId]: sourceMentionedDocs }; + }); + } + userMsgId = newUserMsgId; + break; + } + + case "data-assistant-message-id": { + const parsedMsg = readStreamedMessageId(parsed.data); + if (!parsedMsg) break; + const newAssistantMsgId = `msg-${parsedMsg.messageId}`; + const oldAssistantMsgId = assistantMsgId; + tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId); + setMessages((prev) => + prev.map((m) => + m.id === oldAssistantMsgId + ? 
mergeChatTurnIdIntoMessage( + { ...m, id: newAssistantMsgId }, + parsedMsg.turnId + ) + : m + ) + ); + assistantMsgId = newAssistantMsgId; + break; + } + case "data-revert-results": { const summary = parsed.data; // failureCount must include every "not undone" bucket @@ -1946,28 +1940,14 @@ export default function NewChatPage() { batcher.flush(); - // Persist messages after streaming completes - const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); + // Server-authoritative persistence: ``stream_new_chat`` + // (regenerate flow) wrote the user row in + // ``persist_user_turn`` and finalises the assistant row + // in ``finalize_assistant_turn`` from a shielded + // ``finally`` block (covers both happy-path and + // abort-mid-stream). FE only needs to track the + // successful response here. if (contentParts.length > 0) { - const persistedUserMsgId = await persistUserTurn({ - threadId, - userMsgId, - content: userContentToPersist, - mentionedDocs: sourceMentionedDocs, - turnId: streamedChatTurnId, - logContext: "regenerated", - }); - userPersisted = Boolean(persistedUserMsgId); - - await persistAssistantTurn({ - threadId, - assistantMsgId, - content: finalContent, - tokenUsage: tokenUsageData ?? 
undefined, - turnId: streamedChatTurnId, - logContext: "regenerated", - }); - trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { @@ -1978,44 +1958,6 @@ export default function NewChatPage() { threadId, assistantMsgId, accepted: regenerateAccepted, - onAbort: async () => { - if (!regenerateAccepted) return; - if (!userPersisted) { - const persistedUserMsgId = await persistUserTurn({ - threadId, - userMsgId, - content: userContentToPersist, - mentionedDocs: sourceMentionedDocs, - turnId: streamedChatTurnId, - logContext: "regenerated (aborted)", - }); - userPersisted = Boolean(persistedUserMsgId); - } - const hasContent = hasPersistableContent(contentParts, toolsWithUI); - if (!hasContent) return; - const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - await persistAssistantTurn({ - threadId, - assistantMsgId, - content: partialContent, - tokenUsage: tokenUsageData ?? undefined, - turnId: streamedChatTurnId, - logContext: "partial regenerated chat", - }); - }, - onAcceptedStreamError: async () => { - if (!userPersisted) { - const persistedUserMsgId = await persistUserTurn({ - threadId, - userMsgId, - content: userContentToPersist, - mentionedDocs: sourceMentionedDocs, - turnId: streamedChatTurnId, - logContext: "regenerated (stream error)", - }); - userPersisted = Boolean(persistedUserMsgId); - } - }, }); } finally { setIsRunning(false); @@ -2034,8 +1976,6 @@ export default function NewChatPage() { tokenUsageStore, fetchWithTurnCancellingRetry, handleStreamTerminalError, - persistAssistantTurn, - persistUserTurn, ] ); diff --git a/surfsense_web/lib/chat/stream-side-effects.ts b/surfsense_web/lib/chat/stream-side-effects.ts index 5483ff14b..136afce44 100644 --- a/surfsense_web/lib/chat/stream-side-effects.ts +++ b/surfsense_web/lib/chat/stream-side-effects.ts @@ -114,6 +114,29 @@ export function readStreamedChatTurnId(data: unknown): string | null { return typeof value === "string" && value.length > 0 ? 
value : null; } +/** + * Parse the payload of `data-user-message-id` / `data-assistant-message-id` + * SSE events emitted by `stream_new_chat` and `stream_resume_chat` after + * `persist_user_turn` / `persist_assistant_shell` resolve a canonical + * `new_chat_messages.id`. Mirrors {@link readStreamedChatTurnId}. + * + * Returns `null` when the payload is malformed (missing or non-numeric + * `message_id`); callers should treat this as "ignore the event" so a + * malformed BE payload never overwrites the optimistic id with a bogus + * value. + */ +export function readStreamedMessageId( + data: unknown +): { messageId: number; turnId: string | null } | null { + if (typeof data !== "object" || data === null) return null; + const obj = data as { message_id?: unknown; turn_id?: unknown }; + if (typeof obj.message_id !== "number" || !Number.isFinite(obj.message_id)) { + return null; + } + const turnId = typeof obj.turn_id === "string" && obj.turn_id.length > 0 ? obj.turn_id : null; + return { messageId: obj.message_id, turnId }; +} + export function applyTurnIdToAssistantMessageList( messages: ThreadMessageLike[], assistantMsgId: string, diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 6df56f0ce..27047ecfe 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -487,6 +487,37 @@ export type SSEEvent = type: "data-turn-info"; data: { chat_turn_id: string }; } + | { + /** + * Emitted by ``stream_new_chat`` AFTER ``data-turn-info`` / + * ``data-turn-status`` and BEFORE any LLM streaming events, + * once ``persist_user_turn`` has resolved the canonical + * ``new_chat_messages.id`` for the user-side row of the + * current turn. The frontend renames its optimistic + * ``msg-user-XXX`` placeholder id to ``msg-{message_id}`` + * so DB-id-gated UI (comments, edit-from-this-message) + * unlocks immediately. 
Not emitted by ``stream_resume_chat`` + * (resume reuses the original turn's user message). + */ + type: "data-user-message-id"; + data: { message_id: number; turn_id: string }; + } + | { + /** + * Emitted by ``stream_new_chat`` AND ``stream_resume_chat`` + * AFTER ``data-turn-info`` / ``data-turn-status`` and BEFORE + * any LLM streaming events, once ``persist_assistant_shell`` + * has resolved the canonical ``new_chat_messages.id`` for + * the assistant-side row of the current turn. The frontend + * renames its optimistic ``msg-assistant-XXX`` placeholder + * id, migrates the local ``tokenUsageStore`` and + * ``pendingInterrupt`` references, and binds the running + * mutable ``assistantMsgId`` closure variable to the + * canonical id for the rest of the stream. + */ + type: "data-assistant-message-id"; + data: { message_id: number; turn_id: string }; + } | { /** * Best-effort revert pass that ran BEFORE this regeneration. diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index 7fec60a23..2fb283f87 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -144,6 +144,17 @@ export async function getThreadMessages(threadId: number): Promise