Merge pull request #1341 from MODSetter/dev

feat: moved chat persistance to Server Side
This commit is contained in:
Rohan Verma 2026-05-04 03:10:47 -07:00 committed by GitHub
commit 743eff42cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 4515 additions and 390 deletions

31
.vscode/launch.json vendored
View file

@ -26,7 +26,16 @@
"pythonArgs": [
"run",
"python"
]
],
// Mute LangGraph/Pydantic checkpoint serializer warnings
// (UserWarnings emitted from pydantic/main.py when the
// runtime snapshots a SurfSenseContextSchema into a field
// typed `None`) so the debugger's "Raised Exceptions"
// breakpoint doesn't pause on a known-harmless event.
// Production logs are unaffected.
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
},
{
"name": "Backend: FastAPI (No Reload)",
@ -40,7 +49,10 @@
"pythonArgs": [
"run",
"python"
]
],
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
},
{
"name": "Backend: FastAPI (main.py)",
@ -54,7 +66,10 @@
"pythonArgs": [
"run",
"python"
]
],
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
},
{
"name": "Frontend: Next.js",
@ -104,7 +119,10 @@
"pythonArgs": [
"run",
"python"
]
],
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
},
{
"name": "Celery: Beat Scheduler",
@ -124,7 +142,10 @@
"pythonArgs": [
"run",
"python"
]
],
"env": {
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
}
}
],
"compounds": [

View file

@ -0,0 +1,66 @@
"""141_unique_chat_message_turn_role
Revision ID: 141
Revises: 140
Create Date: 2026-05-04
Add a partial unique index on ``new_chat_messages(thread_id, turn_id, role)``
where ``turn_id IS NOT NULL``.
Why
---
The streaming chat path (`stream_new_chat` / `stream_resume_chat`) is being
moved to write its own ``new_chat_messages`` rows server-side instead of
relying on the frontend's later ``POST /threads/{id}/messages`` call. This
closes the "ghost-thread" abuse vector where authenticated callers got free
LLM completions while ``new_chat_messages`` stayed empty.
For server-side and legacy frontend writes to coexist we need an idempotency
key. The natural triple is ``(thread_id, turn_id, role)``: the server issues
exactly one ``turn_id`` per turn, and a turn produces at most one user
message and one assistant message. Whichever side wins the race writes the
row; the loser hits ``IntegrityError`` and recovers gracefully.
Partial ``WHERE turn_id IS NOT NULL`` so:
* Legacy rows that predate the ``turn_id`` column (migration 136) keep
co-existing without de-dup.
* Clone / snapshot inserts in
``app/services/public_chat_service.py`` that build ``NewChatMessage``
without ``turn_id`` are unaffected (multiple snapshot copies of the same
user/assistant pair are intentional).
This index coexists with the existing single-column ``ix_new_chat_messages_turn_id``
from migration 136 no collision.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "141"
down_revision: str | None = "140"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
INDEX_NAME = "uq_new_chat_messages_thread_turn_role"
TABLE_NAME = "new_chat_messages"
def upgrade() -> None:
op.create_index(
INDEX_NAME,
TABLE_NAME,
["thread_id", "turn_id", "role"],
unique=True,
postgresql_where=sa.text("turn_id IS NOT NULL"),
)
def downgrade() -> None:
op.drop_index(INDEX_NAME, table_name=TABLE_NAME)

View file

@ -0,0 +1,134 @@
"""142_token_usage_message_id_unique
Revision ID: 142
Revises: 141
Create Date: 2026-05-04
Add a partial unique index on ``token_usage(message_id)`` where
``message_id IS NOT NULL``.
Why
---
Two writers can race on the same assistant turn's ``token_usage`` row:
* ``finalize_assistant_turn`` (server-side, called from the streaming
finally block in ``stream_new_chat`` / ``stream_resume_chat``)
* ``append_message``'s recovery branch in
``app/routes/new_chat_routes.py`` (legacy frontend round-trip)
Both currently use ``SELECT ... THEN INSERT`` in separate sessions, so a
micro-second-aligned race could observe "no row" on each side and double
INSERT, producing duplicate ``token_usage`` rows for the same
``message_id``.
A partial unique index on ``message_id`` (``WHERE message_id IS NOT NULL``)
turns both writes into ``INSERT ... ON CONFLICT (message_id) DO NOTHING``
no-ops for the loser, hard-eliminating the race at the DB level. Partial
because non-chat usage rows (indexing, image generation, podcasts) keep
``message_id`` NULL they're per-event, no de-dup needed.
Pre-flight
----------
Today's schema only has a non-unique index on ``message_id`` so a
duplicate population could already exist from any past race. We:
* Detect duplicate ``message_id`` groups (``HAVING COUNT(*) > 1``).
* If the group count is at or below ``DUPLICATE_ABORT_THRESHOLD`` (50)
we dedupe by deleting all but the smallest ``id`` per group.
* If the count exceeds the threshold we abort with a descriptive
error rather than silently mutate prod data operator must
investigate before retrying.
Concurrency
-----------
``CREATE INDEX CONCURRENTLY`` is required on this hot table to avoid
stalling production writes during deploy (a regular ``CREATE INDEX``
holds an ACCESS EXCLUSIVE lock for the duration of the build, which
would block ``token_usage`` INSERTs for every active streaming chat).
The trade-off is a slower migration (CONCURRENTLY scans the table
twice) and the ``CREATE`` statement cannot run inside alembic's default
transaction wrapper ``autocommit_block()`` handles that.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "142"
down_revision: str | None = "141"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
INDEX_NAME = "uq_token_usage_message_id"
TABLE_NAME = "token_usage"
# Refuse to silently mutate prod data if the duplicate population is
# unexpectedly large — operator should investigate the upstream cause
# before retrying. 50 is comfortably above any plausible duplicate
# count from the existing race window (the race is microseconds wide).
DUPLICATE_ABORT_THRESHOLD = 50
def upgrade() -> None:
conn = op.get_bind()
dup_groups = conn.execute(
sa.text(
"SELECT message_id, COUNT(*) AS n "
"FROM token_usage "
"WHERE message_id IS NOT NULL "
"GROUP BY message_id "
"HAVING COUNT(*) > 1"
)
).fetchall()
if len(dup_groups) > DUPLICATE_ABORT_THRESHOLD:
raise RuntimeError(
f"token_usage has {len(dup_groups)} duplicate message_id groups "
f"(threshold={DUPLICATE_ABORT_THRESHOLD}). "
"Resolve the duplicates manually before re-running this migration."
)
if dup_groups:
# Delete all but the smallest-id row per duplicate group. The
# smallest id is by definition the earliest insert, so we keep
# the row most likely to reflect the actual stream's first
# successful write.
conn.execute(
sa.text(
"""
DELETE FROM token_usage
WHERE id IN (
SELECT id FROM (
SELECT
id,
row_number() OVER (
PARTITION BY message_id ORDER BY id ASC
) AS rn
FROM token_usage
WHERE message_id IS NOT NULL
) ranked
WHERE rn > 1
)
"""
)
)
# CREATE INDEX CONCURRENTLY cannot run inside a transaction. Drop
# alembic's auto-transaction for this op only.
with op.get_context().autocommit_block():
op.execute(
f"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS {INDEX_NAME} "
f"ON {TABLE_NAME} (message_id) "
"WHERE message_id IS NOT NULL"
)
def downgrade() -> None:
with op.get_context().autocommit_block():
op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {INDEX_NAME}")

View file

@ -675,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin):
__tablename__ = "new_chat_messages"
# Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL.
# Mirrors alembic migration 141. Lets the streaming agent and the
# legacy frontend appendMessage call coexist idempotently — the second
# writer trips the unique and recovers without creating a duplicate row.
# Partial so legacy NULL turn_id rows and clone/snapshot inserts in
# app/services/public_chat_service.py (which omit turn_id) are unaffected.
__table_args__ = (
Index(
"uq_new_chat_messages_thread_turn_role",
"thread_id",
"turn_id",
"role",
unique=True,
postgresql_where=text("turn_id IS NOT NULL"),
),
)
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
# Content stored as JSONB to support rich content (text, tool calls, etc.)
content = Column(JSONB, nullable=False)
@ -728,6 +745,22 @@ class TokenUsage(BaseModel, TimestampMixin):
__tablename__ = "token_usage"
# Partial unique index on (message_id) where message_id IS NOT NULL.
# Mirrors alembic migration 142. Lets the streaming agent's
# ``finalize_assistant_turn`` and the legacy frontend ``append_message``
# recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without
# racing on a SELECT-then-INSERT window. Partial so non-chat usage rows
# (indexing, image generation, podcasts) — which keep ``message_id`` NULL
# because there is no per-message anchor — are unaffected.
__table_args__ = (
Index(
"uq_token_usage_message_id",
"message_id",
unique=True,
postgresql_where=text("message_id IS NOT NULL"),
),
)
prompt_tokens = Column(Integer, nullable=False, default=0)
completion_tokens = Column(Integer, nullable=False, default=0)
total_tokens = Column(Integer, nullable=False, default=0)

View file

@ -17,7 +17,8 @@ from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy import func, or_, text as sa_text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
@ -44,6 +45,7 @@ from app.db import (
NewChatThread,
Permission,
SearchSpace,
TokenUsage,
User,
get_async_session,
shielded_async_session,
@ -69,9 +71,9 @@ from app.schemas.new_chat import (
TokenUsageSummary,
TurnStatusResponse,
)
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.perf import get_perf_logger
from app.utils.rbac import check_permission
from app.utils.user_message_multimodal import (
split_langchain_human_content,
@ -79,6 +81,7 @@ from app.utils.user_message_multimodal import (
)
_logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
_background_tasks: set[asyncio.Task] = set()
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
@ -1287,6 +1290,24 @@ async def append_message(
user: User = Depends(current_active_user),
):
"""
.. deprecated:: 2026-05
Replaced by the **SSE-based message ID handshake**. The streaming
generator (`stream_new_chat` / `stream_resume_chat`) now persists
both the user and assistant rows server-side via
``persist_user_turn`` / ``persist_assistant_shell`` and emits
``data-user-message-id`` / ``data-assistant-message-id`` SSE events
so the frontend can rename its optimistic IDs in real time. The
new FE bundle no longer calls this route.
This handler is retained as a **silent no-op for legacy / cached
FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING``
pattern means a stale bundle hitting this route after the SSE
handshake already wrote the row simply returns the existing row
(200 OK) without raising or duplicating data. After a 2-week soak
(target: ``[persist_user_turn] outcome=race_recovered`` rate ~0)
this entire route and the FE ``appendMessage`` function is
earmarked for removal.
Append a message to a thread.
This is used by ThreadHistoryAdapter.append() to persist messages.
@ -1297,6 +1318,22 @@ async def append_message(
Requires CHATS_UPDATE permission.
"""
try:
# Capture ``user.id`` as a primitive UUID up front. The
# ``current_active_user`` dependency hands us a ``User`` ORM
# row bound to ``session``; if the outer ``except
# IntegrityError`` block below ever fires (an unexpected
# constraint like a foreign key violation — the common
# ``(thread_id, turn_id, role)`` race is now handled silently
# by ``ON CONFLICT DO NOTHING`` so it never raises) it calls
# ``session.rollback()``, which expires every attached ORM
# row including this user. Any later ``user.id`` access would
# then trigger a lazy PK reload — which on async SQLAlchemy
# fails with ``MissingGreenlet`` because the reload happens
# outside the awaitable greenlet boundary. Reading ``id``
# once here pins the value as a plain UUID so all downstream
# uses (TokenUsage insert, response build) are immune.
user_uuid = user.id
# Parse raw body - extract only role and content, ignoring extra fields
raw_body = await request.json()
role = raw_body.get("role")
@ -1351,42 +1388,166 @@ async def append_message(
else None
)
db_message = NewChatMessage(
thread_id=thread_id,
role=message_role,
content=content,
author_id=user.id,
turn_id=turn_id_value,
)
session.add(db_message)
# Update thread's updated_at timestamp
# Update thread's updated_at timestamp (always — both insert
# and recovery paths represent thread activity).
thread.updated_at = datetime.now(UTC)
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
# Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING``
# keyed on the ``(thread_id, turn_id, role)`` partial unique
# index from migration 141 (``WHERE turn_id IS NOT NULL``).
#
# Why ON CONFLICT instead of ``session.add() + flush() + except
# IntegrityError``:
# 1. The conflict between this legacy FE ``appendMessage``
# round-trip and the server-side
# ``finalize_assistant_turn`` writer is a NORMAL,
# *expected* race — every assistant turn fires it. Using
# catch-and-recover means asyncpg raises
# ``UniqueViolationError`` -> SQLAlchemy wraps it as
# ``IntegrityError`` -> our handler catches and recovers.
# Functionally fine, but every ``raise`` event lights up
# VS Code's debugger (debugpy's ``justMyCode=false`` mode
# loses track of the catch frame across SQLAlchemy's
# async greenlet boundary, so even ``Raised Exceptions``
# being unchecked doesn't reliably suppress the pause).
# ON CONFLICT pushes the conflict resolution into Postgres
# where no Python exception is constructed at all.
# 2. No ``session.rollback()`` -> no expiring of attached
# ORM rows -> no risk of ``MissingGreenlet`` from
# lazy-loading expired user/thread state later in the
# handler.
# 3. Cleaner production logs (no SQLAlchemy ``IntegrityError``
# tracebacks emitted by uvicorn's logger between the
# ``raise`` and our ``except``).
#
# When ``turn_id_value`` is ``None`` the partial index doesn't
# apply and the INSERT proceeds normally. Other constraint
# violations (FK, NOT NULL, etc.) still raise ``IntegrityError``
# and are caught by the outer ``except IntegrityError`` block
# to preserve the legacy 400 behavior.
#
# Note on ``content``: when we recover the existing row, we
# intentionally discard the FE's ``content`` payload from
# ``raw_body`` and return the row's existing ``content``. The
# streaming task is now the *authoritative writer* for
# assistant ``ContentPart[]`` shape (mid-stream
# ``AssistantContentBuilder`` -> ``finalize_assistant_turn``)
# so the FE's later ``appendMessage`` is just a stale snapshot
# of the same data — keeping the server-built rich content
# (with full tool-call args / argsText / langchainToolCallId)
# is correct, not lossy.
insert_stmt = (
pg_insert(NewChatMessage)
.values(
thread_id=thread_id,
role=message_role,
content=content,
author_id=user_uuid,
turn_id=turn_id_value,
)
.on_conflict_do_nothing(
index_elements=["thread_id", "turn_id", "role"],
index_where=sa_text("turn_id IS NOT NULL"),
)
.returning(NewChatMessage.id)
)
inserted_id = (await session.execute(insert_stmt)).scalar()
if inserted_id is None:
# Conflict on partial unique index — server-side stream
# already wrote this row. Look it up and reuse it.
if turn_id_value is None:
# Defensive: ON CONFLICT only fires for ``turn_id IS
# NOT NULL`` rows, so this branch should be
# unreachable. Preserve the legacy 400 just in case
# Postgres ever surprises us.
raise HTTPException(
status_code=400,
detail="Database constraint violation. Please check your input data.",
) from None
lookup = await session.execute(
select(NewChatMessage).filter(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id_value,
NewChatMessage.role == message_role,
)
)
existing_message = lookup.scalars().first()
if existing_message is None:
# Conflict reported but the row vanished between
# INSERT and SELECT — extremely unlikely (would
# require a concurrent DELETE within the same
# transaction visibility), but preserve safe
# behavior.
raise HTTPException(
status_code=400,
detail="Database constraint violation. Please check your input data.",
) from None
db_message = existing_message
# Perf signal: counts how often the legacy FE round-trip
# races the server-side ``finalize_assistant_turn``. A
# rising rate after the rework is OK (it's exactly the
# ghost-thread fix's recovery path firing); a sudden drop
# to zero would mean the FE isn't posting appendMessage
# at all (different bug).
_perf_log.info(
"[append_message] outcome=recovered_via_unique_index "
"thread_id=%s turn_id=%s role=%s message_id=%s",
thread_id,
turn_id_value,
message_role.value,
db_message.id,
)
else:
# INSERT succeeded — load the full ORM row so the
# response can include server-side-defaulted columns
# (``created_at``, etc.) and the relationship surface
# stays consistent with the recovery path.
inserted_row = await session.get(NewChatMessage, inserted_id)
if inserted_row is None:
# Should be impossible: we just inserted it in this
# same transaction. Fail loud if it happens.
raise HTTPException(
status_code=500,
detail="Inserted message could not be loaded.",
) from None
db_message = inserted_row
# Persist token usage if provided (for assistant messages).
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
# forwarded by the FE through the appendMessage round-trip so
# the historical TokenUsage row matches the credit debit applied
# at finalize time.
#
# De-dup: ``finalize_assistant_turn`` may also race to write a
# token_usage row for this same ``message_id`` (cross-session,
# cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed
# on the ``uq_token_usage_message_id`` partial unique index
# (migration 142). The loser silently drops its insert; exactly
# one row results regardless of which writer commits first.
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
await record_token_usage(
session,
usage_type="chat",
search_space_id=thread.search_space_id,
user_id=user.id,
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,
message_id=db_message.id,
insert_stmt = (
pg_insert(TokenUsage)
.values(
usage_type="chat",
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,
message_id=db_message.id,
search_space_id=thread.search_space_id,
user_id=user_uuid,
)
.on_conflict_do_nothing(
index_elements=["message_id"],
index_where=sa_text("message_id IS NOT NULL"),
)
)
await session.execute(insert_stmt)
await session.commit()
@ -1406,6 +1567,9 @@ async def append_message(
except HTTPException:
raise
except IntegrityError:
# Any IntegrityError that escaped the inline handler above
# comes from a *different* constraint (foreign key, etc.) —
# preserve the legacy 400 path.
await session.rollback()
raise HTTPException(
status_code=400,
@ -1599,6 +1763,12 @@ async def handle_new_chat(
else None
)
mentioned_documents_payload = (
[doc.model_dump() for doc in request.mentioned_documents]
if request.mentioned_documents
else None
)
return StreamingResponse(
stream_new_chat(
user_query=request.user_query,
@ -1608,6 +1778,7 @@ async def handle_new_chat(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
mentioned_documents=mentioned_documents_payload,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
@ -2078,6 +2249,11 @@ async def regenerate_response(
"data": revert_results,
}
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
mentioned_documents_payload = (
[doc.model_dump() for doc in request.mentioned_documents]
if request.mentioned_documents
else None
)
try:
async for chunk in stream_new_chat(
user_query=str(user_query_to_use),
@ -2087,6 +2263,7 @@ async def regenerate_response(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
mentioned_documents=mentioned_documents_payload,
checkpoint_id=target_checkpoint_id,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,

View file

@ -200,6 +200,21 @@ class NewChatUserImagePart(BaseModel):
return to_data_url(self.media_type, self.data)
class MentionedDocumentInfo(BaseModel):
"""Display metadata for a single ``@``-mentioned document.
The full triple ``{id, title, document_type}`` is forwarded by the
frontend mention chip so the server can embed it in the persisted
user message ``ContentPart[]`` (single ``mentioned-documents`` part).
The history loader then renders the chips on reload without an extra
fetch mirrors the pre-refactor frontend ``persistUserTurn`` shape.
"""
id: int
title: str = Field(..., min_length=1, max_length=500)
document_type: str = Field(..., min_length=1, max_length=100)
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
@ -213,6 +228,17 @@ class NewChatRequest(BaseModel):
mentioned_surfsense_doc_ids: list[int] | None = (
None # Optional SurfSense documentation IDs mentioned with @ in the chat
)
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata (id, title, document_type) for every "
"@-mentioned document. Persisted as a ``mentioned-documents`` "
"ContentPart on the user message so reload renders chips "
"without an extra fetch. Optional and additive — when None "
"the user message is persisted without a mentioned-documents "
"part."
),
)
disabled_tools: list[str] | None = (
None # Optional list of tool names the user has disabled from the UI
)
@ -264,6 +290,16 @@ class RegenerateRequest(BaseModel):
)
mentioned_document_ids: list[int] | None = None
mentioned_surfsense_doc_ids: list[int] | None = None
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata (id, title, document_type) for every "
"@-mentioned document on the edited user turn. Only used "
"when ``user_query`` is non-None (edit). Persisted as a "
"``mentioned-documents`` ContentPart on the new user "
"message. None means no chip metadata."
),
)
disabled_tools: list[str] | None = None
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
@ -334,6 +370,16 @@ class ResumeRequest(BaseModel):
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata forwarded for symmetry with /new_chat and "
"/regenerate. Resume reuses the original interrupted user "
"turn so the server does not write a new user message. "
"Currently unused but accepted to keep request bodies "
"uniform across the three streaming entrypoints."
),
)
class CancelActiveTurnResponse(BaseModel):

View file

@ -0,0 +1,515 @@
"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection.
Background
----------
The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields
SSE events that the frontend folds into a ``ContentPartsState`` (see
``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in
``stream-pipeline.ts``). When a turn ends, the frontend calls
``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]``
JSONB to ``POST /threads/{id}/messages``, which is what was historically
written to ``new_chat_messages.content``.
After the ghost-thread fix moved persistence server-side, the assistant
row is written by ``finalize_assistant_turn`` in the streaming finally
block. The frontend's later ``appendMessage`` is now a no-op (recovers
via the ``(thread_id, turn_id, role)`` partial unique index added in
migration 141), which means the *server* is now responsible for
producing the rich ``ContentPart[]`` shape the FE expects on history
reload text + reasoning + tool-call cards (with ``args``, ``argsText``,
``result``, ``langchainToolCallId``) + thinking-step buckets +
step-separators.
This module is the in-memory accumulator that mirrors the FE state for
exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*``
/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` /
``mark_interrupted`` at the same call sites it yields the matching
``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list
stays in lockstep with what the FE's pipeline would have produced live.
``snapshot()`` is then taken once in the ``finally`` block and persisted
in a single UPDATE.
Pure synchronous state no DB I/O, no async, no flush callbacks. The
streaming code is responsible for driving lifecycle methods; this class
is a thin projection helper.
"""
from __future__ import annotations
import copy
import json
import logging
from typing import Any
logger = logging.getLogger(__name__)
# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``:
# only text/reasoning/tool-call parts count as "meaningful". data-thinking-steps
# and data-step-separator decorate the meaningful parts but never stand alone
# in a successful turn.
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
class AssistantContentBuilder:
"""Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly
matches the FE ``ContentPart`` union::
| { type: "text"; text: string }
| { type: "reasoning"; text: string }
| { type: "tool-call"; toolCallId: str; toolName: str;
args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
state?: "aborted" }
| { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
| { type: "data-step-separator"; data: { stepIndex: int } }
Order matches the wire order of the SSE events that drive the lifecycle
methods, with two FE-mirrored exceptions:
1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the
first time we see a ``data-thinking-step`` SSE event (the FE's
``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent
thinking-step updates mutate that singleton in place.
2. ``data-step-separator`` is appended only when the message already has
meaningful content and the previous part isn't itself a separator
(so the FIRST step of a turn doesn't generate a leading divider).
"""
def __init__(self) -> None:
self.parts: list[dict[str, Any]] = []
# Index of the active text/reasoning part within ``parts`` while
# streaming is open; -1 means "no active part" and the next delta
# opens a fresh one. Mirrors ``ContentPartsState.currentTextPartIndex``.
self._current_text_idx: int = -1
self._current_reasoning_idx: int = -1
# ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
# synthetic ``call_<run_id>`` (legacy) or the LangChain
# ``tool_call.id`` (parity_v2) — same key the streaming layer
# threads through every ``tool-input-*`` / ``tool-output-*`` event.
self._tool_call_idx_by_ui_id: dict[str, int] = {}
# Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
# so we can reproduce the FE's ``appendToolInputDelta`` behaviour
# before ``tool-input-available`` overwrites it with the
# pretty-printed final JSON.
self._args_text_by_ui_id: dict[str, str] = {}
# ------------------------------------------------------------------
# Text
# ------------------------------------------------------------------
def on_text_start(self, text_id: str) -> None:
"""Begin a fresh text block.
Symmetric to FE ``appendText``: opening text closes any active
reasoning so the renderer treats them as separate parts. The
actual text part isn't materialised here — it's lazily created
on the first ``on_text_delta`` so an empty start/end pair
leaves no trace. Matches the FE pipeline which has no explicit
``text-start`` handler at all.
"""
if self._current_reasoning_idx >= 0:
self._current_reasoning_idx = -1
def on_text_delta(self, text_id: str, delta: str) -> None:
if not delta:
return
if self._current_reasoning_idx >= 0:
# FE behaviour: a text delta after reasoning implicitly
# closes the reasoning block (see ``appendText`` lines
# 178-180).
self._current_reasoning_idx = -1
if (
self._current_text_idx >= 0
and 0 <= self._current_text_idx < len(self.parts)
and self.parts[self._current_text_idx].get("type") == "text"
):
self.parts[self._current_text_idx]["text"] += delta
return
self.parts.append({"type": "text", "text": delta})
self._current_text_idx = len(self.parts) - 1
def on_text_end(self, text_id: str) -> None:
"""Close the active text block.
Mirrors the wire-level ``text-end`` boundary the streaming layer
emits before tool calls / reasoning / step boundaries. The FE
pipeline implicitly closes via ``currentTextPartIndex = -1``
in ``addToolCall`` / ``appendReasoning`` / ``addStepSeparator``;
our helper does the same explicitly so callers don't have to
maintain that invariant per call site.
"""
self._current_text_idx = -1
# ------------------------------------------------------------------
# Reasoning
# ------------------------------------------------------------------
def on_reasoning_start(self, reasoning_id: str) -> None:
if self._current_text_idx >= 0:
self._current_text_idx = -1
def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None:
if not delta:
return
if self._current_text_idx >= 0:
self._current_text_idx = -1
if (
self._current_reasoning_idx >= 0
and 0 <= self._current_reasoning_idx < len(self.parts)
and self.parts[self._current_reasoning_idx].get("type") == "reasoning"
):
self.parts[self._current_reasoning_idx]["text"] += delta
return
self.parts.append({"type": "reasoning", "text": delta})
self._current_reasoning_idx = len(self.parts) - 1
def on_reasoning_end(self, reasoning_id: str) -> None:
self._current_reasoning_idx = -1
# ------------------------------------------------------------------
# Tool calls
# ------------------------------------------------------------------
def on_tool_input_start(
self,
ui_id: str,
tool_name: str,
langchain_tool_call_id: str | None,
) -> None:
"""Register a tool-call card. Args are filled in by later events."""
if not ui_id:
return
# Skip duplicate registration: parity_v2 may emit
# ``tool-input-start`` from both ``on_chat_model_stream``
# (when tool_call_chunks register a name) and ``on_tool_start``
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
# we mirror that here.
if ui_id in self._tool_call_idx_by_ui_id:
if langchain_tool_call_id:
idx = self._tool_call_idx_by_ui_id[ui_id]
part = self.parts[idx]
if not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
return
part: dict[str, Any] = {
"type": "tool-call",
"toolCallId": ui_id,
"toolName": tool_name,
"args": {},
}
if langchain_tool_call_id:
part["langchainToolCallId"] = langchain_tool_call_id
self.parts.append(part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
self._current_text_idx = -1
self._current_reasoning_idx = -1
def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None:
"""Append a streamed args-delta chunk to the matching card's argsText.
Mirrors FE ``appendToolInputDelta``: no-ops when no card has been
registered yet for the given ``ui_id`` the deltas have nowhere
safe to land.
"""
if not ui_id or not args_chunk:
return
idx = self._tool_call_idx_by_ui_id.get(ui_id)
if idx is None:
return
if not (0 <= idx < len(self.parts)):
return
part = self.parts[idx]
if part.get("type") != "tool-call":
return
new_text = (part.get("argsText") or "") + args_chunk
part["argsText"] = new_text
self._args_text_by_ui_id[ui_id] = new_text
def on_tool_input_available(
self,
ui_id: str,
tool_name: str,
args: dict[str, Any],
langchain_tool_call_id: str | None,
) -> None:
"""Finalize the tool-call card's input.
Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText``
with ``json.dumps(input, indent=2)`` so the post-stream card renders
pretty-printed JSON, sets the full ``args`` dict, and backfills
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
Also creates the card if no prior ``tool-input-start`` registered it
(legacy parity_v2-OFF / late-registration paths).
"""
if not ui_id:
return
try:
final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False)
except (TypeError, ValueError):
# Defensive: ``args`` should already be JSON-safe (the
# streaming layer sanitizes it before emitting), but if a
# caller hands us a non-serializable value we still want
# to record the call without breaking the snapshot.
final_args_text = str(args)
idx = self._tool_call_idx_by_ui_id.get(ui_id)
if idx is not None and 0 <= idx < len(self.parts):
part = self.parts[idx]
if part.get("type") == "tool-call":
part["args"] = args or {}
part["argsText"] = final_args_text
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
return
# No prior tool-input-start: register the card now.
new_part: dict[str, Any] = {
"type": "tool-call",
"toolCallId": ui_id,
"toolName": tool_name,
"args": args or {},
"argsText": final_args_text,
}
if langchain_tool_call_id:
new_part["langchainToolCallId"] = langchain_tool_call_id
self.parts.append(new_part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
self._current_text_idx = -1
self._current_reasoning_idx = -1
def on_tool_output_available(
self,
ui_id: str,
output: Any,
langchain_tool_call_id: str | None,
) -> None:
"""Attach the tool's output (``result``) to the matching card.
Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId``
only if not already set (a NULL late-arriving value never blows
away an earlier known good one).
"""
if not ui_id:
return
idx = self._tool_call_idx_by_ui_id.get(ui_id)
if idx is None or not (0 <= idx < len(self.parts)):
return
part = self.parts[idx]
if part.get("type") != "tool-call":
return
part["result"] = output
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
# ------------------------------------------------------------------
# Thinking steps & step separators
# ------------------------------------------------------------------
def on_thinking_step(
self,
step_id: str,
title: str,
status: str,
items: list[str] | None,
) -> None:
"""Update / insert the singleton ``data-thinking-steps`` part.
Mirrors FE ``updateThinkingSteps``: maintain a single
``data-thinking-steps`` part anchored at index 0, replacing or
unshifting on first sight. Each ``on_thinking_step`` call
replaces the entry in the steps list keyed by ``step_id`` (or
appends if new).
"""
if not step_id:
return
new_step = {
"id": step_id,
"title": title or "",
"status": status or "in_progress",
"items": list(items) if items else [],
}
# Find existing data-thinking-steps part.
existing_idx = -1
for i, p in enumerate(self.parts):
if p.get("type") == "data-thinking-steps":
existing_idx = i
break
if existing_idx >= 0:
current_steps = self.parts[existing_idx].get("data", {}).get("steps") or []
replaced = False
for i, step in enumerate(current_steps):
if step.get("id") == step_id:
current_steps[i] = new_step
replaced = True
break
if not replaced:
current_steps.append(new_step)
self.parts[existing_idx] = {
"type": "data-thinking-steps",
"data": {"steps": current_steps},
}
return
# First sight: unshift to position 0 (FE parity).
self.parts.insert(
0,
{
"type": "data-thinking-steps",
"data": {"steps": [new_step]},
},
)
# Bump tracked indices since we inserted at the head.
if self._current_text_idx >= 0:
self._current_text_idx += 1
if self._current_reasoning_idx >= 0:
self._current_reasoning_idx += 1
for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()):
self._tool_call_idx_by_ui_id[ui_id] = idx + 1
def on_step_separator(self) -> None:
"""Append a ``data-step-separator`` between consecutive model steps.
Mirrors FE ``addStepSeparator``: only emit when the message
already has meaningful content AND the previous part isn't
itself a separator. ``stepIndex`` is the running count of
separators already in ``parts``.
"""
has_content = any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
if not has_content:
return
if self.parts and self.parts[-1].get("type") == "data-step-separator":
return
step_index = sum(
1 for p in self.parts if p.get("type") == "data-step-separator"
)
self.parts.append(
{
"type": "data-step-separator",
"data": {"stepIndex": step_index},
}
)
self._current_text_idx = -1
self._current_reasoning_idx = -1
# ------------------------------------------------------------------
# Interruption handling
# ------------------------------------------------------------------
def mark_interrupted(self) -> None:
"""Close any open text/reasoning and flip running tools to aborted.
Called from the streaming ``finally`` block before ``snapshot()`` so
the persisted JSONB reflects a coherent end-state even when the
client disconnected mid-turn or the agent hit a fatal error.
- Active text/reasoning blocks: simply lose their "active"
marker (no synthetic content appended). Whatever was streamed
stays as-is.
- Tool-call parts that never received a ``result`` get
``state="aborted"`` so the FE history loader can render them
as "interrupted" rather than "still running".
"""
self._current_text_idx = -1
self._current_reasoning_idx = -1
for part in self.parts:
if part.get("type") != "tool-call":
continue
if "result" in part:
continue
part["state"] = "aborted"
# ------------------------------------------------------------------
# Snapshot & introspection
# ------------------------------------------------------------------
def snapshot(self) -> list[dict[str, Any]]:
"""Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps.
Deep-copied so callers that finalize from the shielded ``finally``
block can't accidentally mutate the persisted payload while the
SQL UPDATE is in flight (the streaming layer doesn't touch the
builder after this call, but defensive copies are cheap and cheap
is what we want in a finally block).
"""
return copy.deepcopy(self.parts)
def is_empty(self) -> bool:
"""True if no meaningful content was captured.
``data-thinking-steps`` and ``data-step-separator`` decorate
meaningful content but don't count on their own — a turn that
only emitted a thinking step before being interrupted should
still be treated as empty for the status-marker fallback.
"""
return not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
def stats(self) -> dict[str, int]:
"""Return counts of each part-type plus rough byte size.
Used by the streaming layer's perf logger so an ops dashboard
can correlate finalize latency with payload size, and so a
regression that quietly stops emitting tool-call parts (or
starts emitting hundreds) shows up in [PERF] grep rather than
only as a "history reload looks weird" bug report.
``bytes`` is the JSON-serialised payload length what actually
crosses the wire to PostgreSQL's JSONB column. We compute it
with ``ensure_ascii=False`` to match the JSONB encoder's UTF-8
on-disk layout closely enough for back-of-the-envelope sizing.
Reasoning/text/tool-call/thinking-step/step-separator counts are
independent so any one can spike without the others.
Defensive: ``json.dumps`` failure (a non-serializable value
slipped past the streaming layer's sanitization) is reported as
``bytes=-1`` rather than raised perf logging must not be the
thing that breaks the streaming finally block.
"""
text_blocks = 0
reasoning_blocks = 0
tool_calls = 0
tool_calls_completed = 0
tool_calls_aborted = 0
thinking_step_parts = 0
step_separators = 0
for part in self.parts:
kind = part.get("type")
if kind == "text":
text_blocks += 1
elif kind == "reasoning":
reasoning_blocks += 1
elif kind == "tool-call":
tool_calls += 1
if part.get("state") == "aborted":
tool_calls_aborted += 1
elif "result" in part:
tool_calls_completed += 1
elif kind == "data-thinking-steps":
thinking_step_parts += 1
elif kind == "data-step-separator":
step_separators += 1
try:
byte_size = len(json.dumps(self.parts, ensure_ascii=False, default=str))
except (TypeError, ValueError):
byte_size = -1
return {
"parts": len(self.parts),
"bytes": byte_size,
"text": text_blocks,
"reasoning": reasoning_blocks,
"tool_calls": tool_calls,
"tool_calls_completed": tool_calls_completed,
"tool_calls_aborted": tool_calls_aborted,
"thinking_step_parts": thinking_step_parts,
"step_separators": step_separators,
}

View file

@ -0,0 +1,534 @@
"""Server-side message persistence helpers for the streaming chat agent.
Historically the streaming task (``stream_new_chat``/``stream_resume_chat``)
left ``new_chat_messages`` empty and relied on the frontend to round-trip
``POST /threads/{id}/messages`` afterwards. That gave authenticated clients
a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens
without leaving an audit trail. These helpers move both writes (the user
turn that triggered the stream and the assistant turn the stream produced)
into the server itself, idempotent against the partial unique index
``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do*
keep posting via ``appendMessage`` simply hit the unique-index recovery
path on the second writer instead of creating duplicates.
Assistant turn lifecycle
------------------------
The assistant side is split into two helpers so we can capture the row id
*before* the stream produces any output:
* ``persist_assistant_shell`` runs immediately after ``persist_user_turn``
and INSERTs an empty assistant row anchored to ``(thread_id, turn_id,
ASSISTANT)``. Returns the row id so the streaming layer can correlate
later writes (token_usage, AgentActionLog future-correlation) against
a stable PK from the start of the turn.
* ``finalize_assistant_turn`` runs from the streaming ``finally`` block.
It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot
produced server-side by ``AssistantContentBuilder`` and writes the
``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against
the ``uq_token_usage_message_id`` partial unique index from migration
142, hard-eliminating any race against ``append_message``'s recovery
branch.
Defensive contract
------------------
* Every helper runs inside ``shielded_async_session()`` so ``session.close()``
survives starlette's mid-stream cancel scope on client disconnect.
* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON
CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id,
role)`` partial unique index. On conflict the insert silently no-ops at
the DB level no Python ``IntegrityError`` is constructed, which
eliminates spurious debugger pauses and keeps logs clean. On conflict a
follow-up ``SELECT`` resolves the existing row id so the streaming layer
can correlate writes against a stable PK.
* ``finalize_assistant_turn`` is best-effort: it never raises. The
streaming ``finally`` block calls it from within
``anyio.CancelScope(shield=True)`` and any raised exception there
would mask the real error.
"""
from __future__ import annotations
import logging
import time
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import text as sa_text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.future import select
from app.db import (
NewChatMessage,
NewChatMessageRole,
NewChatThread,
TokenUsage,
shielded_async_session,
)
from app.services.token_tracking_service import (
TurnTokenAccumulator,
)
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# Empty initial assistant content. ``finalize_assistant_turn`` overwrites
# this in a single UPDATE at end-of-stream with the full ``ContentPart[]``
# snapshot produced by ``AssistantContentBuilder``. We persist a one-element
# list with an empty text part so a crash between shell-INSERT and finalize
# leaves the row in a FE-renderable shape (blank bubble) instead of
# blowing up the history loader.
_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}]
# Substituted content for genuinely empty turns (no text, no reasoning,
# no tool calls). The streaming layer flips to this when
# ``AssistantContentBuilder.is_empty()`` returns True so the persisted
# row is at least somewhat self-describing instead of an empty text
# bubble. The FE's ``ContentPart`` union doesn't include ``status``
# yet, so the history loader will silently drop this part and render
# a blank bubble (matches today's behaviour for empty turns); a follow-up
# FE PR adds the explicit "no response" rendering.
_STATUS_NO_RESPONSE: list[dict[str, Any]] = [
{"type": "status", "text": "(no text response)"}
]
def _build_user_content(
user_query: str,
user_image_data_urls: list[str] | None,
mentioned_documents: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
"""Build the persisted user-message ``content`` (assistant-ui v2 parts).
Mirrors the shape the existing frontend posts via
``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``):
[{"type": "text", "text": "..."},
{"type": "image", "image": "data:..."},
{"type": "mentioned-documents", "documents": [{"id": int,
"title": str, "document_type": str}, ...]}]
The companion reader is
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
which expects exactly this shape keep them in sync.
``mentioned_documents``: optional list of ``{id, title, document_type}``
dicts. When non-empty (and a ``mentioned-documents`` part is not already
in some other input shape), a single ``{"type": "mentioned-documents",
"documents": [...]}`` part is appended. Mirrors the FE injection at
``page.tsx:281-286`` (``persistUserTurn``).
"""
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
for url in user_image_data_urls or ():
if isinstance(url, str) and url:
parts.append({"type": "image", "image": url})
if mentioned_documents:
normalized: list[dict[str, Any]] = []
for doc in mentioned_documents:
if not isinstance(doc, dict):
continue
doc_id = doc.get("id")
title = doc.get("title")
document_type = doc.get("document_type")
if doc_id is None or title is None or document_type is None:
continue
normalized.append(
{
"id": doc_id,
"title": str(title),
"document_type": str(document_type),
}
)
if normalized:
parts.append({"type": "mentioned-documents", "documents": normalized})
return parts
async def persist_user_turn(
*,
chat_id: int,
user_id: str | None,
turn_id: str,
user_query: str,
user_image_data_urls: list[str] | None = None,
mentioned_documents: list[dict[str, Any]] | None = None,
) -> int | None:
"""Persist the user-side row for a chat turn and return its ``id``.
Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on the
``(thread_id, turn_id, role)`` partial unique index from migration 141
(``WHERE turn_id IS NOT NULL``). On conflict the insert silently no-ops
at the DB level no Python ``IntegrityError`` is constructed, which
eliminates the debugger pause that ``justMyCode=false`` + async greenlet
interactions used to produce, and keeps production logs clean.
Returns the ``id`` of the row that exists for this turn after the call:
the freshly inserted ``id`` on the happy path, or the existing ``id``
when a previous writer (legacy FE ``appendMessage`` racing the SSE
stream, redelivered request, etc.) already wrote it. Returns ``None``
only on genuine DB failure; the caller should yield a streaming error
and abort the turn so we never produce a title/assistant row that
isn't anchored to a persisted user message.
Other constraint violations (FK, NOT NULL, etc.) still raise
``IntegrityError`` only the ``(thread_id, turn_id, role)`` collision
is silenced.
"""
if not turn_id:
# Defensive: turn_id is always populated by the streaming path
# before this helper is called. If it isn't, we cannot be
# idempotent against the unique index — refuse to write rather
# than create a row the unique index can't dedupe.
logger.error(
"persist_user_turn called without a turn_id (chat_id=%s); skipping",
chat_id,
)
return None
t0 = time.perf_counter()
outcome = "failed"
resolved_id: int | None = None
try:
async with shielded_async_session() as ws:
# Re-attach the thread row so we can also bump updated_at
# in the same write — keeps the sidebar ordering accurate
# when a user fires off a turn but never reaches the
# legacy appendMessage.
thread = await ws.get(NewChatThread, chat_id)
author_uuid: UUID | None = None
if user_id:
try:
author_uuid = UUID(user_id)
except (TypeError, ValueError):
logger.warning(
"persist_user_turn: invalid user_id=%r, persisting as anonymous",
user_id,
)
content_payload = _build_user_content(
user_query, user_image_data_urls, mentioned_documents
)
insert_stmt = (
pg_insert(NewChatMessage)
.values(
thread_id=chat_id,
role=NewChatMessageRole.USER,
content=content_payload,
author_id=author_uuid,
turn_id=turn_id,
)
.on_conflict_do_nothing(
index_elements=["thread_id", "turn_id", "role"],
index_where=sa_text("turn_id IS NOT NULL"),
)
.returning(NewChatMessage.id)
)
inserted_id = (await ws.execute(insert_stmt)).scalar()
if inserted_id is None:
# Conflict on partial unique index — another writer
# (legacy FE appendMessage, redelivered request, etc.)
# already persisted this row. Look it up and reuse.
lookup = await ws.execute(
select(NewChatMessage.id).where(
NewChatMessage.thread_id == chat_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.USER,
)
)
existing_id = lookup.scalars().first()
if existing_id is None:
# Conflict reported but no row found — extremely
# unlikely (concurrent DELETE). Surface as failure.
logger.warning(
"persist_user_turn: conflict but no matching row "
"(chat_id=%s, turn_id=%s)",
chat_id,
turn_id,
)
outcome = "integrity_no_match"
return None
resolved_id = int(existing_id)
outcome = "race_recovered"
else:
resolved_id = int(inserted_id)
outcome = "inserted"
# Bump thread.updated_at only on a real insert — when
# we recovered an existing row the prior writer
# already touched the thread.
if thread is not None:
thread.updated_at = datetime.now(UTC)
await ws.commit()
return resolved_id
except Exception:
logger.exception(
"persist_user_turn failed (chat_id=%s, turn_id=%s)",
chat_id,
turn_id,
)
return None
finally:
_perf_log.info(
"[persist_user_turn] outcome=%s chat_id=%s turn_id=%s "
"message_id=%s query_len=%d images=%d mentioned_docs=%d "
"in %.3fs",
outcome,
chat_id,
turn_id,
resolved_id,
len(user_query or ""),
len(user_image_data_urls or ()),
len(mentioned_documents or ()),
time.perf_counter() - t0,
)
async def persist_assistant_shell(
*,
chat_id: int,
user_id: str | None,
turn_id: str,
) -> int | None:
"""Pre-write an empty assistant row for the turn and return its id.
Inserts a placeholder ``new_chat_messages`` row (empty text content) so
the streaming layer has a stable ``message_id`` to correlate against
for the rest of the turn. ``finalize_assistant_turn`` overwrites the
``content`` field at end-of-stream with the rich ``ContentPart[]``
snapshot produced by ``AssistantContentBuilder``.
Returns the row id on success, ``None`` on a genuine DB failure (caller
should abort the turn rather than stream into a void).
Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial unique
index from migration 141: if a row already exists (resume retry, racing
legacy frontend, redelivered request, etc.) we look it up by
``(thread_id, turn_id, role)`` and return its existing id. The streaming
layer is then free to UPDATE that row at finalize time.
"""
if not turn_id:
logger.error(
"persist_assistant_shell called without a turn_id (chat_id=%s); skipping",
chat_id,
)
return None
t0 = time.perf_counter()
outcome = "failed"
resolved_id: int | None = None
try:
async with shielded_async_session() as ws:
insert_stmt = (
pg_insert(NewChatMessage)
.values(
thread_id=chat_id,
role=NewChatMessageRole.ASSISTANT,
content=_EMPTY_SHELL_CONTENT,
author_id=None,
turn_id=turn_id,
)
.on_conflict_do_nothing(
index_elements=["thread_id", "turn_id", "role"],
index_where=sa_text("turn_id IS NOT NULL"),
)
.returning(NewChatMessage.id)
)
inserted_id = (await ws.execute(insert_stmt)).scalar()
if inserted_id is None:
# Conflict — another writer (legacy FE appendMessage,
# resume retry, redelivered request) wrote the
# (thread_id, turn_id, ASSISTANT) row first. Look it up
# so the streaming layer can UPDATE the same row at
# finalize time.
lookup = await ws.execute(
select(NewChatMessage.id).where(
NewChatMessage.thread_id == chat_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
)
)
existing_id = lookup.scalars().first()
if existing_id is None:
logger.warning(
"persist_assistant_shell: conflict but no matching "
"(thread_id, turn_id, role) row found "
"(chat_id=%s, turn_id=%s)",
chat_id,
turn_id,
)
outcome = "integrity_no_match"
return None
resolved_id = int(existing_id)
outcome = "race_recovered"
else:
resolved_id = int(inserted_id)
outcome = "inserted"
await ws.commit()
return resolved_id
except Exception:
logger.exception(
"persist_assistant_shell failed (chat_id=%s, turn_id=%s)",
chat_id,
turn_id,
)
return None
finally:
_perf_log.info(
"[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s "
"message_id=%s in %.3fs",
outcome,
chat_id,
turn_id,
resolved_id,
time.perf_counter() - t0,
)
async def finalize_assistant_turn(
*,
message_id: int,
chat_id: int,
search_space_id: int,
user_id: str | None,
turn_id: str,
content: list[dict[str, Any]],
accumulator: TurnTokenAccumulator | None,
) -> None:
"""Finalize the assistant row and write its token_usage.
Two writes in a single shielded session:
1. ``UPDATE new_chat_messages SET content = :c, updated_at = now()
WHERE id = :id`` overwrites the placeholder ``persist_assistant_shell``
wrote with the full ``ContentPart[]`` snapshot produced server-side.
2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT (message_id)
WHERE message_id IS NOT NULL DO NOTHING`` uses the partial unique
index ``uq_token_usage_message_id`` from migration 142 to make the
insert idempotent against ``append_message``'s recovery branch
(which uses the same ON CONFLICT clause).
Substitutes the status-marker payload when ``content`` is empty
(pure tool-call turn that aborted before any output, or interrupt
before any event arrived). The status marker is preferable to a
blank text bubble because token accounting still runs and an ops
dashboard can flag the row.
Best-effort never raises. The streaming ``finally`` calls this
from within ``anyio.CancelScope(shield=True)``; any raised exception
here would mask the real error that triggered the cleanup.
"""
if not turn_id:
logger.error(
"finalize_assistant_turn called without turn_id "
"(chat_id=%s, message_id=%s); skipping",
chat_id,
message_id,
)
return
if not message_id:
logger.error(
"finalize_assistant_turn called without message_id "
"(chat_id=%s, turn_id=%s); skipping",
chat_id,
turn_id,
)
return
payload: list[dict[str, Any]]
is_status_marker = False
if content:
payload = content
else:
payload = _STATUS_NO_RESPONSE
is_status_marker = True
t0 = time.perf_counter()
outcome = "failed"
token_usage_attempted = bool(
accumulator is not None and accumulator.calls and user_id
)
try:
async with shielded_async_session() as ws:
assistant_row = await ws.get(NewChatMessage, message_id)
if assistant_row is None:
logger.warning(
"finalize_assistant_turn: row not found "
"(chat_id=%s, message_id=%s, turn_id=%s); skipping",
chat_id,
message_id,
turn_id,
)
outcome = "row_missing"
return
assistant_row.content = payload
assistant_row.updated_at = datetime.now(UTC)
# Token usage. ``record_token_usage`` (used elsewhere) does
# SELECT-then-INSERT in two statements which races with
# ``append_message``. Switch to a single INSERT ... ON
# CONFLICT DO NOTHING keyed on the migration-142 partial
# unique index so the loser silently drops its write at
# the DB level — exactly one row per ``message_id``,
# regardless of which session committed first.
if accumulator is not None and accumulator.calls and user_id:
try:
user_uuid = UUID(user_id)
except (TypeError, ValueError):
logger.warning(
"finalize_assistant_turn: invalid user_id=%r, "
"skipping token_usage row",
user_id,
)
else:
insert_stmt = (
pg_insert(TokenUsage)
.values(
usage_type="chat",
prompt_tokens=accumulator.total_prompt_tokens,
completion_tokens=accumulator.total_completion_tokens,
total_tokens=accumulator.grand_total,
cost_micros=accumulator.total_cost_micros,
model_breakdown=accumulator.per_message_summary(),
call_details={"calls": accumulator.serialized_calls()},
thread_id=chat_id,
message_id=message_id,
search_space_id=search_space_id,
user_id=user_uuid,
)
.on_conflict_do_nothing(
index_elements=["message_id"],
index_where=sa_text("message_id IS NOT NULL"),
)
)
await ws.execute(insert_stmt)
await ws.commit()
outcome = "ok"
except Exception:
logger.exception(
"finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)",
chat_id,
message_id,
turn_id,
)
finally:
_perf_log.info(
"[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s "
"turn_id=%s parts=%d status_marker=%s "
"token_usage_attempted=%s in %.3fs",
outcome,
chat_id,
message_id,
turn_id,
len(payload),
is_status_marker,
token_usage_attempted,
time.perf_counter() - t0,
)

View file

@ -25,7 +25,6 @@ from uuid import UUID
import anyio
from langchain_core.messages import HumanMessage
from sqlalchemy import func
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
@ -314,6 +313,19 @@ class StreamResult:
verification_succeeded: bool = False
commit_gate_passed: bool = True
commit_gate_reason: str = ""
# Pre-allocated assistant ``new_chat_messages.id`` for this turn,
# captured by ``persist_assistant_shell`` right after the user row is
# persisted. ``None`` for the legacy / anonymous code paths that don't
# opt in to server-side ``ContentPart[]`` projection.
assistant_message_id: int | None = None
# In-memory mirror of the FE's assistant-ui ``ContentPartsState``,
# populated by the lifecycle methods called from ``_stream_agent_events``
# at each ``streaming_service.format_*`` yield site. Snapshot in the
# streaming ``finally`` to produce the rich JSONB persisted by
# ``finalize_assistant_turn``. ``repr=False`` keeps the
# log-on-error path (``StreamResult`` is logged in some error
# branches) from dumping a potentially-large parts list.
content_builder: Any | None = field(default=None, repr=False)
def _safe_float(value: Any, default: float = 0.0) -> float:
@ -721,6 +733,7 @@ async def _stream_agent_events(
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None,
runtime_context: Any = None,
content_builder: Any | None = None,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.
@ -737,6 +750,15 @@ async def _stream_agent_events(
initial_step_id: If set, the helper inherits an already-active thinking step.
initial_step_title: Title of the inherited thinking step.
initial_step_items: Items of the inherited thinking step.
content_builder: Optional ``AssistantContentBuilder``. When set, every
``streaming_service.format_*`` yield site also drives the matching
builder lifecycle method (``on_text_*``, ``on_reasoning_*``,
``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the
in-memory ``ContentPart[]`` projection stays in lockstep with what
the FE renders live. Pure in-memory accumulation no DB I/O
consumed by the streaming ``finally`` to produce the rich JSONB
persisted via ``finalize_assistant_turn``. ``None`` (the default)
is used by the anonymous / legacy code paths and is a no-op.
Yields:
SSE-formatted strings for each event.
@ -801,12 +823,46 @@ async def _stream_agent_events(
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
def _emit_tool_output(call_id: str, output: Any) -> str:
# Drive the builder before formatting the SSE so the in-memory
# ContentPart[] mirror sees the result attached to the same
# card the FE will render. Builder method is a no-op when
# ``content_builder`` is None (anonymous / legacy paths).
if content_builder is not None:
content_builder.on_tool_output_available(
call_id, output, current_lc_tool_call_id["value"]
)
return streaming_service.format_tool_output_available(
call_id,
output,
langchain_tool_call_id=current_lc_tool_call_id["value"],
)
def _emit_thinking_step(
*,
step_id: str,
title: str,
status: str = "in_progress",
items: list[str] | None = None,
) -> str:
"""Format a thinking-step SSE event and notify the builder.
Single helper used at every ``format_thinking_step`` yield site
in this generator. Drives ``AssistantContentBuilder.on_thinking_step``
first so the FE-mirror state lands the update before the SSE
carrying the same data leaves the wire order matches the FE
pipeline (``processSharedStreamEvent`` updates state, then
flushes). Builder call is a no-op when ``content_builder`` is
None (anonymous / legacy paths).
"""
if content_builder is not None:
content_builder.on_thinking_step(step_id, title, status, items)
return streaming_service.format_thinking_step(
step_id=step_id,
title=title,
status=status,
items=items,
)
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
thinking_step_counter += 1
@ -816,7 +872,7 @@ async def _stream_agent_events(
nonlocal last_active_step_id
if last_active_step_id and last_active_step_id not in completed_step_ids:
completed_step_ids.add(last_active_step_id)
event = streaming_service.format_thinking_step(
event = _emit_thinking_step(
step_id=last_active_step_id,
title=last_active_step_title,
status="completed",
@ -861,6 +917,8 @@ async def _stream_agent_events(
if parity_v2 and reasoning_delta:
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is None:
completion_event = complete_current_step()
@ -873,13 +931,21 @@ async def _stream_agent_events(
just_finished_tool = False
current_reasoning_id = streaming_service.generate_reasoning_id()
yield streaming_service.format_reasoning_start(current_reasoning_id)
if content_builder is not None:
content_builder.on_reasoning_start(current_reasoning_id)
yield streaming_service.format_reasoning_delta(
current_reasoning_id, reasoning_delta
)
if content_builder is not None:
content_builder.on_reasoning_delta(
current_reasoning_id, reasoning_delta
)
if text_delta:
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(current_reasoning_id)
if content_builder is not None:
content_builder.on_reasoning_end(current_reasoning_id)
current_reasoning_id = None
if current_text_id is None:
completion_event = complete_current_step()
@ -892,8 +958,12 @@ async def _stream_agent_events(
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
if content_builder is not None:
content_builder.on_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta
if content_builder is not None:
content_builder.on_text_delta(current_text_id, text_delta)
# Live tool-call argument streaming. Runs AFTER text/reasoning
# processing so chunks containing both stay in their natural
@ -925,11 +995,17 @@ async def _stream_agent_events(
# within the same stream window.
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
current_reasoning_id
)
if content_builder is not None:
content_builder.on_reasoning_end(
current_reasoning_id
)
current_reasoning_id = None
index_to_meta[idx] = {
@ -942,6 +1018,8 @@ async def _stream_agent_events(
name,
langchain_tool_call_id=lc_id,
)
if content_builder is not None:
content_builder.on_tool_input_start(ui_id, name, lc_id)
# Emit args delta for any chunk at a registered
# index (including idless continuations). Once an
@ -957,6 +1035,10 @@ async def _stream_agent_events(
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
if content_builder is not None:
content_builder.on_tool_input_delta(
meta["ui_id"], args_chunk
)
else:
pending_tool_call_chunks.append(tcc)
@ -974,6 +1056,8 @@ async def _stream_agent_events(
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
current_text_id = None
if last_active_step_title != "Synthesizing response":
@ -994,7 +1078,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Listing files"
last_active_step_items = [ls_path]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Listing files",
status="in_progress",
@ -1009,7 +1093,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Reading file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Reading file",
status="in_progress",
@ -1024,7 +1108,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Writing file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Writing file",
status="in_progress",
@ -1039,7 +1123,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Editing file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Editing file",
status="in_progress",
@ -1056,7 +1140,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Searching files"
last_active_step_items = [f"{pat} in {base_path}"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Searching files",
status="in_progress",
@ -1076,7 +1160,7 @@ async def _stream_agent_events(
last_active_step_items = [
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Searching content",
status="in_progress",
@ -1091,7 +1175,7 @@ async def _stream_agent_events(
display_path = rm_path if len(rm_path) <= 80 else "" + rm_path[-77:]
last_active_step_title = "Deleting file"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Deleting file",
status="in_progress",
@ -1108,7 +1192,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Deleting folder"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Deleting folder",
status="in_progress",
@ -1125,7 +1209,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Creating folder"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Creating folder",
status="in_progress",
@ -1148,7 +1232,7 @@ async def _stream_agent_events(
last_active_step_items = (
[f"{display_src}{display_dst}"] if src or dst else []
)
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Moving file",
status="in_progress",
@ -1165,7 +1249,7 @@ async def _stream_agent_events(
if todo_count
else []
)
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Planning tasks",
status="in_progress",
@ -1180,7 +1264,7 @@ async def _stream_agent_events(
display_title = doc_title[:60] + ("" if len(doc_title) > 60 else "")
last_active_step_title = "Saving document"
last_active_step_items = [display_title]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Saving document",
status="in_progress",
@ -1196,7 +1280,7 @@ async def _stream_agent_events(
last_active_step_items = [
f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Generating image",
status="in_progress",
@ -1212,7 +1296,7 @@ async def _stream_agent_events(
last_active_step_items = [
f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Scraping webpage",
status="in_progress",
@ -1235,7 +1319,7 @@ async def _stream_agent_events(
f"Content: {content_len:,} characters",
"Preparing audio generation...",
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Generating podcast",
status="in_progress",
@ -1256,7 +1340,7 @@ async def _stream_agent_events(
f"Topic: {report_topic}",
"Analyzing source content...",
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title=step_title,
status="in_progress",
@ -1271,7 +1355,7 @@ async def _stream_agent_events(
display_cmd = cmd[:80] + ("" if len(cmd) > 80 else "")
last_active_step_title = "Running command"
last_active_step_items = [f"$ {display_cmd}"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Running command",
status="in_progress",
@ -1288,7 +1372,7 @@ async def _stream_agent_events(
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
last_active_step_items = []
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title=last_active_step_title,
status="in_progress",
@ -1349,6 +1433,10 @@ async def _stream_agent_events(
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
if content_builder is not None:
content_builder.on_tool_input_start(
tool_call_id, tool_name, langchain_tool_call_id
)
if run_id:
ui_tool_call_id_by_run[run_id] = tool_call_id
@ -1371,6 +1459,13 @@ async def _stream_agent_events(
_safe_input,
langchain_tool_call_id=langchain_tool_call_id,
)
if content_builder is not None:
content_builder.on_tool_input_available(
tool_call_id,
tool_name,
_safe_input,
langchain_tool_call_id,
)
elif event_type == "on_tool_end":
active_tool_depth = max(0, active_tool_depth - 1)
@ -1443,70 +1538,70 @@ async def _stream_agent_events(
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
if tool_name == "read_file":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Reading file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_file":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Writing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "edit_file":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Editing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "glob":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Searching files",
status="completed",
items=last_active_step_items,
)
elif tool_name == "grep":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Searching content",
status="completed",
items=last_active_step_items,
)
elif tool_name == "rm":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Deleting file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "rmdir":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Deleting folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "mkdir":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Creating folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "move_file":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Moving file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_todos":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Planning tasks",
status="completed",
@ -1523,7 +1618,7 @@ async def _stream_agent_events(
*last_active_step_items,
result_str[:80] if is_error else "Saved to knowledge base",
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Saving document",
status="completed",
@ -1542,7 +1637,7 @@ async def _stream_agent_events(
else "Generation failed"
)
completed_items = [*last_active_step_items, f"Error: {error_msg}"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Generating image",
status="completed",
@ -1566,7 +1661,7 @@ async def _stream_agent_events(
]
else:
completed_items = [*last_active_step_items, "Content extracted"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Scraping webpage",
status="completed",
@ -1612,7 +1707,7 @@ async def _stream_agent_events(
]
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Generating podcast",
status="completed",
@ -1647,7 +1742,7 @@ async def _stream_agent_events(
]
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Generating video presentation",
status="completed",
@ -1695,7 +1790,7 @@ async def _stream_agent_events(
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title=step_title,
status="completed",
@ -1721,7 +1816,7 @@ async def _stream_agent_events(
]
else:
completed_items = [*last_active_step_items, "Finished"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Running command",
status="completed",
@ -1761,7 +1856,7 @@ async def _stream_agent_events(
completed_items.append(f"(+{len(file_names) - 4} more)")
else:
completed_items = ["No files found"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Listing files",
status="completed",
@ -1773,7 +1868,7 @@ async def _stream_agent_events(
fallback_title = (
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title=fallback_title,
status="completed",
@ -2113,7 +2208,7 @@ async def _stream_agent_events(
# Phase transitions: replace everything after topic
last_active_step_items = [*topic_items, message]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=last_active_step_id,
title=last_active_step_title,
status="in_progress",
@ -2155,10 +2250,14 @@ async def _stream_agent_events(
elif event_type in ("on_chain_end", "on_agent_end"):
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
current_text_id = None
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
completion_event = complete_current_step()
if completion_event:
@ -2243,8 +2342,14 @@ async def _stream_agent_events(
)
gate_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(gate_text_id)
if content_builder is not None:
content_builder.on_text_start(gate_text_id)
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
if content_builder is not None:
content_builder.on_text_delta(gate_text_id, gate_notice)
yield streaming_service.format_text_end(gate_text_id)
if content_builder is not None:
content_builder.on_text_end(gate_text_id)
yield streaming_service.format_terminal_info(gate_notice, "error")
accumulated_text = gate_notice
else:
@ -2270,6 +2375,7 @@ async def stream_new_chat(
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
mentioned_surfsense_doc_ids: list[int] | None = None,
mentioned_documents: list[dict[str, Any]] | None = None,
checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False,
thread_visibility: ChatVisibility | None = None,
@ -2949,6 +3055,96 @@ async def stream_new_chat(
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
# Persist the user-side row for this turn before any expensive
# work runs. Closes the "ghost-thread" abuse vector
# (authenticated client hits POST /new_chat then never calls
# /messages — empty new_chat_messages, free LLM completion).
# Idempotent against the unique index in migration 141 so the
# legacy frontend appendMessage call is a no-op on the second
# writer. Hard failure aborts the turn so we never produce a
# title or assistant row that isn't anchored to a persisted
# user message.
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.persistence import (
persist_assistant_shell,
persist_user_turn,
)
user_message_id = await persist_user_turn(
chat_id=chat_id,
user_id=user_id,
turn_id=stream_result.turn_id,
user_query=user_query,
user_image_data_urls=user_image_data_urls,
mentioned_documents=mentioned_documents,
)
if user_message_id is None:
yield _emit_stream_error(
message=(
"We couldn't save your message. Please try again in a moment."
),
error_kind="server_error",
error_code="MESSAGE_PERSIST_FAILED",
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
# Emit canonical user message id BEFORE any LLM streaming so the
# FE can rename its optimistic ``msg-user-XXX`` placeholder to
# ``msg-{user_message_id}`` and unlock features gated on a real
# DB id (comments, edit-from-this-message). See B4 in
# ``sse-based_message_id_handshake`` plan.
yield streaming_service.format_data(
"user-message-id",
{"message_id": user_message_id, "turn_id": stream_result.turn_id},
)
# Pre-write the assistant row for this turn so we have a stable
# ``message_id`` to anchor mid-stream metadata (token_usage,
# future agent_action_log.message_id correlation) and a
# write-once UPDATE target at finalize time. Idempotent against
# the (thread_id, turn_id, ASSISTANT) partial unique index from
# migration 141 — if the legacy frontend appendMessage races
# this, we recover the existing row's id.
assistant_message_id = await persist_assistant_shell(
chat_id=chat_id,
user_id=user_id,
turn_id=stream_result.turn_id,
)
if assistant_message_id is None:
# Genuine DB failure — abort the turn rather than stream
# into a void. The user row is already persisted so the
# legacy "ghost-thread" gate isn't reopened.
yield _emit_stream_error(
message=(
"We couldn't initialize the assistant message. Please try again."
),
error_kind="server_error",
error_code="MESSAGE_PERSIST_FAILED",
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
# Emit canonical assistant message id BEFORE any LLM streaming
# so the FE can rename its optimistic ``msg-assistant-XXX``
# placeholder to ``msg-{assistant_message_id}`` and bind
# ``tokenUsageStore`` / ``pendingInterrupt`` to the real id
# immediately. See B4 in ``sse-based_message_id_handshake``
# plan.
yield streaming_service.format_data(
"assistant-message-id",
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
)
stream_result.assistant_message_id = assistant_message_id
stream_result.content_builder = AssistantContentBuilder()
# Initial thinking step - analyzing the request
if mentioned_surfsense_docs:
initial_title = "Analyzing referenced content"
@ -2981,6 +3177,15 @@ async def stream_new_chat(
initial_items = [f"{action_verb}: {' '.join(processing_parts)}"]
initial_step_id = "thinking-1"
# Drive the builder for this initial thinking step too — the
# ``_emit_thinking_step`` helper lives inside ``_stream_agent_events``
# so it isn't in scope here, but the FE folds this step into
# the same singleton ``data-thinking-steps`` part as everything
# the agent stream emits later. Mirror that fold server-side.
if stream_result.content_builder is not None:
stream_result.content_builder.on_thinking_step(
initial_step_id, initial_title, "in_progress", initial_items
)
yield streaming_service.format_thinking_step(
step_id=initial_step_id,
title=initial_title,
@ -2997,16 +3202,34 @@ async def stream_new_chat(
# Check if this is the first assistant response so we can generate
# a title in parallel with the agent stream (better UX than waiting
# until after the full response).
assistant_count_result = await session.execute(
select(func.count(NewChatMessage.id)).filter(
# Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because
# this is now a hot path executed on every turn, and COUNT scales
# with thread length (server-side persistence can grow rows
# quickly under power users).
#
# IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already
# inserted THIS turn's assistant row. We must therefore exclude
# it from the probe — otherwise the gate fires on every turn
# except the very first, and title generation never runs for new
# threads. Excluding by primary key (``id != assistant_message_id``)
# is bulletproof regardless of ``turn_id`` shape (legacy NULLs,
# resume turns, etc.).
first_assistant_probe = await session.execute(
select(NewChatMessage.id)
.filter(
NewChatMessage.thread_id == chat_id,
NewChatMessage.role == "assistant",
NewChatMessage.id != assistant_message_id,
)
.limit(1)
)
is_first_response = (assistant_count_result.scalar() or 0) == 0
is_first_response = first_assistant_probe.scalars().first() is None
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
if is_first_response:
# Gate title generation on a persisted user message so a stream
# that fails before persistence (we abort above) can never leave
# behind a thread with a generated title and no anchoring rows.
if is_first_response and user_message_id is not None:
async def _generate_title() -> tuple[str | None, dict | None]:
"""Generate a short title via litellm.acompletion.
@ -3138,6 +3361,7 @@ async def stream_new_chat(
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
content_builder=stream_result.content_builder,
):
if not _first_event_logged:
_perf_log.info(
@ -3493,6 +3717,81 @@ async def stream_new_chat(
with contextlib.suppress(Exception):
await session.close()
# Server-side assistant-message + token_usage finalization.
# Runs after the main session has been closed (uses its own
# shielded session) so we don't fight the same DB connection.
# Idempotent against the legacy frontend appendMessage:
# * the assistant row was already INSERTed by
# ``persist_assistant_shell`` above, so this just UPDATEs
# it with the rich ContentPart[] from the builder.
# * token_usage uses INSERT ... ON CONFLICT DO NOTHING
# against migration 142's partial unique index, so a
# racing append_message recovery branch can never
# double-write.
# ``mark_interrupted`` closes any open text/reasoning blocks
# and flips running tool-calls (no result) to state=aborted
# so the persisted JSONB reflects a coherent end-state even
# on client disconnect.
# Never raises (best-effort, logs only).
if (
stream_result
and stream_result.turn_id
and stream_result.assistant_message_id
):
from app.tasks.chat.persistence import finalize_assistant_turn
builder_stats: dict[str, int] | None = None
if stream_result.content_builder is not None:
stream_result.content_builder.mark_interrupted()
# Snapshot stats BEFORE deepcopy in ``snapshot()`` so
# the perf log records the actual finalised payload
# (post-mark_interrupted), not the live-mutating
# builder state.
builder_stats = stream_result.content_builder.stats()
content_payload = stream_result.content_builder.snapshot()
else:
# Defensive fallback — we always set the builder
# alongside ``assistant_message_id`` above, so this
# branch only fires if a future refactor ever
# decouples them. Persist whatever accumulated
# text we captured so the row at least renders.
content_payload = [
{
"type": "text",
"text": stream_result.accumulated_text or "",
}
]
if builder_stats is not None:
_perf_log.info(
"[stream_new_chat] finalize_payload chat_id=%s "
"message_id=%s parts=%d bytes=%d text=%d "
"reasoning=%d tool_calls=%d "
"tool_calls_completed=%d tool_calls_aborted=%d "
"thinking_step_parts=%d step_separators=%d",
chat_id,
stream_result.assistant_message_id,
builder_stats["parts"],
builder_stats["bytes"],
builder_stats["text"],
builder_stats["reasoning"],
builder_stats["tool_calls"],
builder_stats["tool_calls_completed"],
builder_stats["tool_calls_aborted"],
builder_stats["thinking_step_parts"],
builder_stats["step_separators"],
)
await finalize_assistant_turn(
message_id=stream_result.assistant_message_id,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
turn_id=stream_result.turn_id,
content=content_payload,
accumulator=accumulator,
)
# Persist any sandbox-produced files to local storage so they
# remain downloadable after the Daytona sandbox auto-deletes.
if stream_result and stream_result.sandbox_files:
@ -3937,6 +4236,50 @@ async def stream_resume_chat(
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
# Pre-write a fresh assistant row for this resume turn. The
# original (interrupted) ``stream_new_chat`` invocation already
# persisted its own assistant row anchored to a different
# ``turn_id``; resume allocates a new ``turn_id`` (above) so we
# need a separate row keyed on the same ``(thread_id, turn_id,
# ASSISTANT)`` invariant. Idempotent against migration 141's
# partial unique index — recovers existing id on retry.
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.persistence import persist_assistant_shell
assistant_message_id = await persist_assistant_shell(
chat_id=chat_id,
user_id=user_id,
turn_id=stream_result.turn_id,
)
if assistant_message_id is None:
yield _emit_stream_error(
message=(
"We couldn't initialize the assistant message. Please try again."
),
error_kind="server_error",
error_code="MESSAGE_PERSIST_FAILED",
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
# Emit canonical assistant message id BEFORE any LLM streaming
# so the FE can rename ``pendingInterrupt.assistantMsgId`` to
# ``msg-{assistant_message_id}`` immediately. Resume does NOT
# emit ``data-user-message-id`` because the user row is from
# the original interrupted turn (different ``turn_id``) and is
# never re-persisted here. See B5 in the
# ``sse-based_message_id_handshake`` plan.
yield streaming_service.format_data(
"assistant-message-id",
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
)
stream_result.assistant_message_id = assistant_message_id
stream_result.content_builder = AssistantContentBuilder()
# Resume path doesn't carry new ``mentioned_document_ids`` —
# those are seeded in the original turn. We still pass a
# context so future middleware extensions (Phase 2) can rely on
@ -3968,6 +4311,7 @@ async def stream_resume_chat(
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
content_builder=stream_result.content_builder,
):
if not _first_event_logged:
_perf_log.info(
@ -4219,6 +4563,64 @@ async def stream_resume_chat(
with contextlib.suppress(Exception):
await session.close()
# Server-side assistant-message + token_usage finalization for
# the resume flow. The original user message was persisted by
# the original (interrupted) ``stream_new_chat`` invocation;
# the resume's own ``persist_assistant_shell`` write lives at
# the new ``turn_id`` above. This finalize updates that row
# with the rich ContentPart[] from the builder and writes
# token_usage idempotently via migration 142's partial
# unique index. Best-effort, never raises.
if (
stream_result
and stream_result.turn_id
and stream_result.assistant_message_id
):
from app.tasks.chat.persistence import finalize_assistant_turn
builder_stats: dict[str, int] | None = None
if stream_result.content_builder is not None:
stream_result.content_builder.mark_interrupted()
builder_stats = stream_result.content_builder.stats()
content_payload = stream_result.content_builder.snapshot()
else:
content_payload = [
{
"type": "text",
"text": stream_result.accumulated_text or "",
}
]
if builder_stats is not None:
_perf_log.info(
"[stream_resume] finalize_payload chat_id=%s "
"message_id=%s parts=%d bytes=%d text=%d "
"reasoning=%d tool_calls=%d "
"tool_calls_completed=%d tool_calls_aborted=%d "
"thinking_step_parts=%d step_separators=%d",
chat_id,
stream_result.assistant_message_id,
builder_stats["parts"],
builder_stats["bytes"],
builder_stats["text"],
builder_stats["reasoning"],
builder_stats["tool_calls"],
builder_stats["tool_calls_completed"],
builder_stats["tool_calls_aborted"],
builder_stats["thinking_step_parts"],
builder_stats["step_separators"],
)
await finalize_assistant_turn(
message_id=stream_result.assistant_message_id,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
turn_id=stream_result.turn_id,
content=content_payload,
accumulator=accumulator,
)
agent = llm = connector_service = None
stream_result = None
session = None

View file

@ -0,0 +1,573 @@
"""Integration tests for the cross-writer integration between the
streaming chat task and the legacy ``POST /threads/{id}/messages``
(``append_message``) round-trip.
Two scenarios anchor the contract introduced by the server-side
persistence rework:
(a) **Tool-heavy turn streamed to completion.**
Drives :class:`AssistantContentBuilder` with synthetic SSE events
that mirror what ``_stream_agent_events`` emits for a turn that
interleaves text, reasoning, a tool call (start/delta/available/
output), and a final text block. Then runs
:func:`finalize_assistant_turn` and asserts:
* ``new_chat_messages.content`` JSONB matches the
``ContentPart[]`` shape the FE history loader expects, with full
``args``/``argsText``/``result``/``langchainToolCallId`` for the
tool call.
* Exactly one ``token_usage`` row exists keyed on the assistant
``message_id``.
(b) **Stale FE ``appendMessage`` after server finalize.**
Verifies the recovery branch of the ``append_message`` route now
returns the SERVER's authoritative ``ContentPart[]`` (not the FE's
stale payload) when the partial unique index from migration 141
blocks the FE's INSERT, and that the ``ON CONFLICT DO NOTHING``
clause from migration 142 stops the route from writing a duplicate
``token_usage`` row.
"""
from __future__ import annotations
import json
from contextlib import asynccontextmanager
import pytest
import pytest_asyncio
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import (
ChatVisibility,
NewChatMessage,
NewChatMessageRole,
NewChatThread,
SearchSpace,
TokenUsage,
User,
)
from app.routes import new_chat_routes
from app.services.token_tracking_service import TurnTokenAccumulator
from app.tasks.chat import persistence as persistence_module
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.persistence import (
finalize_assistant_turn,
persist_assistant_shell,
)
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest_asyncio.fixture
async def db_thread(
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
) -> NewChatThread:
thread = NewChatThread(
title="Test Chat",
search_space_id=db_search_space.id,
created_by_id=db_user.id,
visibility=ChatVisibility.PRIVATE,
)
db_session.add(thread)
await db_session.flush()
return thread
@pytest.fixture
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
"""Route persistence helpers to the test's savepoint-bound session.
Mirrors the helper from ``test_persistence.py`` so the helpers'
internal ``ws.commit()`` / ``ws.rollback()`` resolve to SAVEPOINT
operations on the test transaction instead of touching real
autocommit boundaries.
"""
@asynccontextmanager
async def _fake_shielded_session():
yield db_session
monkeypatch.setattr(
persistence_module,
"shielded_async_session",
_fake_shielded_session,
)
return db_session
@pytest.fixture
def bypass_permission_checks(monkeypatch):
"""Replace RBAC + thread access checks with no-ops.
The append_message route under test calls ``check_permission`` and
``check_thread_access``; those rely on a SearchSpaceMembership row
that the existing integration fixtures don't create. The contract
we want to verify here is the ``IntegrityError`` -> recovery branch,
not the RBAC plumbing so stub them.
"""
async def _allow(*_args, **_kwargs):
return True
monkeypatch.setattr(new_chat_routes, "check_permission", _allow)
monkeypatch.setattr(new_chat_routes, "check_thread_access", _allow)
return None
class _FakeRequest:
"""Minimal Request stand-in used by ``append_message``.
The route only calls ``await request.json()`` keep the surface
area tight so this doesn't accidentally hide future signature
changes that we *would* want to break the test.
"""
def __init__(self, body: dict):
self._body = body
async def json(self) -> dict:
return self._body
def _build_tool_heavy_content() -> list[dict]:
"""Drive ``AssistantContentBuilder`` through a tool-heavy turn.
Produces the same ``ContentPart[]`` shape the streaming layer would
persist if ``_stream_agent_events`` ran a turn with: opening
reasoning -> text -> tool call (input start/delta/available/output)
-> closing text. Centralised here so the (a) and (b) scenarios use
the same authoritative payload.
"""
builder = AssistantContentBuilder()
builder.on_reasoning_start("r1")
builder.on_reasoning_delta("r1", "Let me look up ")
builder.on_reasoning_delta("r1", "the file listing.")
builder.on_reasoning_end("r1")
builder.on_text_start("t1")
builder.on_text_delta("t1", "Sure, listing files in ")
builder.on_text_delta("t1", "/.")
builder.on_text_end("t1")
builder.on_tool_input_start(
"tool_call_ui_1",
tool_name="ls",
langchain_tool_call_id="lc_call_xyz",
)
builder.on_tool_input_delta("tool_call_ui_1", '{"path"')
builder.on_tool_input_delta("tool_call_ui_1", ': "/"}')
builder.on_tool_input_available(
"tool_call_ui_1",
tool_name="ls",
args={"path": "/"},
langchain_tool_call_id="lc_call_xyz",
)
builder.on_tool_output_available(
"tool_call_ui_1",
output={"files": ["a.txt", "b.txt"]},
langchain_tool_call_id="lc_call_xyz",
)
builder.on_text_start("t2")
builder.on_text_delta("t2", "Found 2 files: a.txt and b.txt.")
builder.on_text_end("t2")
return builder.snapshot()
def _accumulator_with_one_call() -> TurnTokenAccumulator:
acc = TurnTokenAccumulator()
acc.add(
model="gpt-4o-mini",
prompt_tokens=200,
completion_tokens=80,
total_tokens=280,
cost_micros=22222,
)
return acc
# ---------------------------------------------------------------------------
# (a) Tool-heavy stream finalize
# ---------------------------------------------------------------------------
class TestToolHeavyTurnFinalize:
async def test_full_tool_call_persisted_and_one_token_usage_row(
self,
db_session,
db_user,
db_thread,
db_search_space,
patched_shielded_session,
):
"""End-to-end seam: builder snapshot -> finalize -> DB row.
Matches the production flow's *content* invariant: whatever
``AssistantContentBuilder.snapshot()`` produces is what the
streaming layer hands to ``finalize_assistant_turn``, so this
test catches any drift between the JSONB shape the builder
emits and the one the FE history loader expects.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
search_space_id = db_search_space.id
turn_id = f"{thread_id}:tool_heavy"
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert msg_id is not None
snapshot = _build_tool_heavy_content()
# Sanity-check the snapshot before we hand it to the DB so a
# builder regression surfaces here, not deep inside an opaque
# JSONB diff.
assert any(p.get("type") == "reasoning" for p in snapshot)
text_parts = [p for p in snapshot if p.get("type") == "text"]
assert len(text_parts) == 2
tool_parts = [p for p in snapshot if p.get("type") == "tool-call"]
assert len(tool_parts) == 1
tool_part = tool_parts[0]
assert tool_part["toolCallId"] == "tool_call_ui_1"
assert tool_part["toolName"] == "ls"
assert tool_part["args"] == {"path": "/"}
# ``argsText`` ends up as the pretty-printed final args (the
# ``tool-input-available`` event replaces the streamed deltas
# with ``json.dumps(args, indent=2)`` to match the FE's
# post-stream rendering).
assert tool_part["argsText"] == '{\n "path": "/"\n}'
assert tool_part["result"] == {"files": ["a.txt", "b.txt"]}
# ``langchainToolCallId`` is the agent-side correlation id used
# by the regenerate path; a missing one breaks
# edit-from-tool-call later.
assert tool_part["langchainToolCallId"] == "lc_call_xyz"
await finalize_assistant_turn(
message_id=msg_id,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id=turn_id,
content=snapshot,
accumulator=_accumulator_with_one_call(),
)
# ``content`` must round-trip byte-for-byte through the JSONB
# column. SQLAlchemy doesn't auto-refresh the row that survived
# the savepoint commit, so refresh explicitly.
row = await db_session.get(NewChatMessage, msg_id)
await db_session.refresh(row)
# The history loader reads ``content`` straight into the FE's
# parts array, so a strict equality comparison is the right
# invariant here.
assert row.content == snapshot
# Tool-call parts must JSON-serialise cleanly — nothing in
# ``args`` / ``argsText`` / ``result`` should accidentally be a
# non-JSON-safe value (datetime, set, custom class).
assert json.dumps(row.content)
usage_count = (
await db_session.execute(
select(func.count())
.select_from(TokenUsage)
.where(TokenUsage.message_id == msg_id)
)
).scalar_one()
assert usage_count == 1
usage = (
await db_session.execute(
select(TokenUsage).where(TokenUsage.message_id == msg_id)
)
).scalar_one()
assert usage.usage_type == "chat"
assert usage.prompt_tokens == 200
assert usage.completion_tokens == 80
assert usage.total_tokens == 280
assert usage.cost_micros == 22222
assert usage.thread_id == thread_id
assert usage.search_space_id == search_space_id
# ---------------------------------------------------------------------------
# (b) FE appendMessage after server finalize
# ---------------------------------------------------------------------------
class TestAppendMessageRecoveryAfterFinalize:
async def test_returns_server_content_and_does_not_duplicate_token_usage(
self,
db_session,
db_user,
db_thread,
db_search_space,
patched_shielded_session,
bypass_permission_checks,
):
"""FE's stale ``appendMessage`` after server finalize.
The frontend used to be the authoritative writer for assistant
``content``. Now the server is. When the legacy FE round-trip
fires *after* the server has already finalized:
* the route's INSERT trips the (thread_id, turn_id, role)
partial unique index from migration 141,
* the recovery branch fetches the existing row and returns
*its* ``content`` discarding the FE payload so the
history loader reads the rich server payload (full tool
args, argsText, langchainToolCallId, etc.) on next page
reload,
* the route's optional ``token_usage`` insert is keyed on the
partial unique index from migration 142 so it silently
no-ops if the server already wrote one.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
search_space_id = db_search_space.id
turn_id = f"{thread_id}:fe_late_append"
# Step 1: server stream completes. Server-built rich content is
# finalized.
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert msg_id is not None
server_content = _build_tool_heavy_content()
await finalize_assistant_turn(
message_id=msg_id,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id=turn_id,
content=server_content,
accumulator=_accumulator_with_one_call(),
)
# Step 2: simulate the legacy FE ``appendMessage`` round-trip
# arriving with stale, lossy content (missing tool args, etc.)
# plus a ``token_usage`` body.
fe_stale_content = [
{"type": "text", "text": "Found 2 files: a.txt and b.txt."},
]
fe_request_body = {
"role": "assistant",
"content": fe_stale_content,
"turn_id": turn_id,
"token_usage": {
"prompt_tokens": 999,
"completion_tokens": 999,
"total_tokens": 1998,
"cost_micros": 88888,
"usage": {"any": "thing"},
"call_details": {"calls": []},
},
}
request = _FakeRequest(fe_request_body)
# ``db_user`` is bound to ``db_session``. The route's
# IntegrityError branch calls ``session.rollback()``, which
# expires every ORM row attached to the session including this
# user — historically causing ``user.id`` to lazy-load
# out-of-greenlet and crash the request with ``MissingGreenlet``
# (observed in production logs at /api/v1/threads/531/messages
# 2026-05-04). The route now captures ``user.id`` to a primitive
# UUID at the top of the handler, so the rollback can't reach
# it. Pass the *attached* user here on purpose — that's the
# production scenario, and this test is the regression guard
# against that bug returning.
response = await new_chat_routes.append_message(
thread_id=thread_id,
request=request,
session=db_session,
user=db_user,
)
# Response must echo the SERVER's rich payload, not the FE's
# stale snapshot. This is the user-visible part of the
# contract: history reload + ThreadHistoryAdapter.append both
# read from the same authoritative source.
assert response.id == msg_id
assert response.role == NewChatMessageRole.ASSISTANT
assert response.turn_id == turn_id
assert response.content == server_content
assert response.content != fe_stale_content
# The on-disk row must agree with the response.
row = await db_session.get(NewChatMessage, msg_id)
await db_session.refresh(row)
assert row.content == server_content
# ``token_usage``: exactly one row, with the *server's* values
# (the FE's much larger token counts must not have overwritten
# them).
usage_count = (
await db_session.execute(
select(func.count())
.select_from(TokenUsage)
.where(TokenUsage.message_id == msg_id)
)
).scalar_one()
assert usage_count == 1
usage = (
await db_session.execute(
select(TokenUsage).where(TokenUsage.message_id == msg_id)
)
).scalar_one()
assert usage.cost_micros == 22222 # Server's value, not 88888
assert usage.total_tokens == 280 # Server's value, not 1998
async def test_legacy_fe_first_appendmessage_then_server_no_dupe(
self,
db_session,
db_user,
db_thread,
db_search_space,
patched_shielded_session,
bypass_permission_checks,
):
"""Inverse race: legacy FE writes first, server finalize second.
Some clients still post ``appendMessage`` before the streaming
``finally`` runs. The contract is symmetric: whichever writer
loses the (thread_id, turn_id, role) race silently lets the
winner keep its content. In particular the *server's*
finalize must NOT raise it must look up the existing row and
UPDATE its content with the server-built payload (which is
always richer/more authoritative than whatever the FE
snapshot held).
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
search_space_id = db_search_space.id
turn_id = f"{thread_id}:fe_first"
# Step 1: legacy FE appendMessage lands first. No prior shell
# row exists; the route does the INSERT itself.
fe_request_body = {
"role": "assistant",
"content": [{"type": "text", "text": "early FE write"}],
"turn_id": turn_id,
}
fe_response = await new_chat_routes.append_message(
thread_id=thread_id,
request=_FakeRequest(fe_request_body),
session=db_session,
user=db_user,
)
assert fe_response.role == NewChatMessageRole.ASSISTANT
# Step 2: server stream's persist_assistant_shell now races
# behind. It must adopt the existing row id, not raise.
adopted_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert adopted_id == fe_response.id
# Step 3: server finalize then overwrites the FE's stub with
# the rich content (which is the correct, more authoritative
# payload).
server_content = _build_tool_heavy_content()
await finalize_assistant_turn(
message_id=adopted_id,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id=turn_id,
content=server_content,
accumulator=_accumulator_with_one_call(),
)
# Final state: one row, server content, one token_usage row.
msg_count = (
await db_session.execute(
select(func.count())
.select_from(NewChatMessage)
.where(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
)
)
).scalar_one()
assert msg_count == 1
row = await db_session.get(NewChatMessage, adopted_id)
await db_session.refresh(row)
assert row.content == server_content
usage_count = (
await db_session.execute(
select(func.count())
.select_from(TokenUsage)
.where(TokenUsage.message_id == adopted_id)
)
).scalar_one()
assert usage_count == 1
async def test_appendmessage_without_turn_id_legacy_400(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
bypass_permission_checks,
):
"""Defensive: a bare appendMessage with no turn_id and no
existing row is just a normal INSERT must succeed. But if a
row with the same role already exists in this thread *without*
turn_id collisions, the route should fall through to the
legacy 400 path on a foreign-key / unrelated IntegrityError
(we don't ship that bug today, but pin the behaviour so a
future schema change can't silently regress it).
"""
thread_id = db_thread.id
# Bare appendMessage with no turn_id — should just succeed
# without invoking the recovery branch.
ok_response = await new_chat_routes.append_message(
thread_id=thread_id,
request=_FakeRequest(
{
"role": "user",
"content": [{"type": "text", "text": "hi"}],
}
),
session=db_session,
user=db_user,
)
assert ok_response.role == NewChatMessageRole.USER
assert ok_response.turn_id is None
# Sanity: the route did NOT silently swallow the missing
# turn_id by routing through the unique-index recovery branch
# — it took the happy path.
msg_count = (
await db_session.execute(
select(func.count())
.select_from(NewChatMessage)
.where(
NewChatMessage.thread_id == thread_id,
NewChatMessage.role == NewChatMessageRole.USER,
)
)
).scalar_one()
assert msg_count == 1

View file

@ -0,0 +1,332 @@
"""Integration tests for the SSE-based message ID handshake.
The streaming generators (``stream_new_chat`` / ``stream_resume_chat``)
emit two new events after their respective persistence helpers resolve
the canonical ``new_chat_messages.id``:
* ``data-user-message-id`` emitted only by ``stream_new_chat``,
AFTER ``persist_user_turn`` and BEFORE any LLM streaming.
* ``data-assistant-message-id`` emitted by both
``stream_new_chat`` and ``stream_resume_chat``, AFTER
``persist_assistant_shell`` and BEFORE any LLM streaming.
The frontend renames its optimistic ``msg-user-XXX`` /
``msg-assistant-XXX`` placeholder ids to ``msg-{db_id}`` upon receiving
these events. This test suite anchors three contracts:
1. ``format_data`` produces SSE bytes in the precise shape
``data: {"type":"data-<suffix>","data":{...}}\\n\\n`` that the FE's
``readSSEStream`` consumer parses (matches ``surfsense_web/lib/chat/streaming-state.ts``).
2. The ``message_id`` carried in the SSE payload exactly equals the
primary key the persistence helper inserted into
``new_chat_messages`` so the FE rename produces ``msg-{real_pk}``,
which in turn unlocks DB-id-gated UI (comments, edit-from-message).
3. The same ``message_id`` is used for the ``token_usage.message_id``
foreign key, so ``finalize_assistant_turn``'s row binds correctly.
Direct end-to-end testing of ``stream_new_chat`` requires a fully
mocked agent + LLM stack (out-of-scope here); those flows are covered
by the harness-driven integration tests under
``tests/integration/agents/new_chat/`` plus the assertion in
``test_persistence.py`` that the helpers themselves return ``int``
ids. The contracts above close the remaining gap between the persist
helpers and the bytes that ship to the FE.
"""
from __future__ import annotations
import json
from contextlib import asynccontextmanager
import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import (
ChatVisibility,
NewChatMessage,
NewChatMessageRole,
NewChatThread,
SearchSpace,
User,
)
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat import persistence as persistence_module
from app.tasks.chat.persistence import (
persist_assistant_shell,
persist_user_turn,
)
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Fixtures (mirror test_persistence.py)
# ---------------------------------------------------------------------------
@pytest_asyncio.fixture
async def db_thread(
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
) -> NewChatThread:
thread = NewChatThread(
title="Test Chat",
search_space_id=db_search_space.id,
created_by_id=db_user.id,
visibility=ChatVisibility.PRIVATE,
)
db_session.add(thread)
await db_session.flush()
return thread
@pytest.fixture
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
"""Route persistence helpers to the test's savepoint-bound session."""
@asynccontextmanager
async def _fake_shielded_session():
yield db_session
monkeypatch.setattr(
persistence_module,
"shielded_async_session",
_fake_shielded_session,
)
return db_session
# ---------------------------------------------------------------------------
# (1) SSE byte-shape contract
# ---------------------------------------------------------------------------
def _parse_sse_data_line(blob: str) -> dict:
"""Unwrap a single ``data: <json>\\n\\n`` SSE frame.
Raises if there's more than one frame or the prefix is wrong — keeps
the parser strict so a regression in ``format_data`` produces a
test failure here, not in a downstream consumer.
"""
assert blob.endswith("\n\n"), f"missing terminator: {blob!r}"
line = blob.removesuffix("\n\n")
assert line.startswith("data: "), f"missing data prefix: {line!r}"
return json.loads(line.removeprefix("data: "))
class TestSSEByteShape:
def test_data_user_message_id_byte_shape(self):
"""``format_data("user-message-id", {...})`` must produce the
exact wire format the FE's
``readSSEStream`` -> ``data-user-message-id`` case parses.
"""
svc = VercelStreamingService()
blob = svc.format_data(
"user-message-id",
{"message_id": 1843, "turn_id": "533:1762900000000"},
)
envelope = _parse_sse_data_line(blob)
assert envelope == {
"type": "data-user-message-id",
"data": {"message_id": 1843, "turn_id": "533:1762900000000"},
}
def test_data_assistant_message_id_byte_shape(self):
svc = VercelStreamingService()
blob = svc.format_data(
"assistant-message-id",
{"message_id": 1844, "turn_id": "533:1762900000000"},
)
envelope = _parse_sse_data_line(blob)
assert envelope == {
"type": "data-assistant-message-id",
"data": {"message_id": 1844, "turn_id": "533:1762900000000"},
}
# ---------------------------------------------------------------------------
# (2) Helper-id <-> DB-pk coherence
# ---------------------------------------------------------------------------
class TestHandshakeIdMatchesDB:
"""The SSE handshake's correctness hinges on the integer in
``data-{user,assistant}-message-id`` being the EXACT primary key
the persistence helper inserted. If they ever diverge, the FE
rename produces ``msg-{wrong_id}``, comments break (regex match
fails), and downstream features (edit, regenerate) target the
wrong row. Anchor it here.
"""
async def test_user_message_id_matches_new_chat_messages_pk(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:9000"
# The streaming generator passes this same value into
# ``streaming_service.format_data("user-message-id", {...})``.
msg_id_from_helper = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
user_query="hello",
)
assert isinstance(msg_id_from_helper, int)
# Look up the row the helper inserted via
# ``(thread_id, turn_id, role)`` — the same composite the FE
# uses to identify a turn — and confirm the PK matches.
row = (
await db_session.execute(
select(NewChatMessage).where(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.USER,
)
)
).scalar_one()
assert row.id == msg_id_from_helper
# The byte-stream the FE actually receives — confirms the
# round-trip from the helper return value to the SSE payload.
svc = VercelStreamingService()
envelope = _parse_sse_data_line(
svc.format_data(
"user-message-id",
{"message_id": msg_id_from_helper, "turn_id": turn_id},
)
)
assert envelope["data"]["message_id"] == row.id
async def test_assistant_message_id_matches_new_chat_messages_pk(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:9100"
msg_id_from_helper = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert isinstance(msg_id_from_helper, int)
row = (
await db_session.execute(
select(NewChatMessage).where(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
)
)
).scalar_one()
assert row.id == msg_id_from_helper
svc = VercelStreamingService()
envelope = _parse_sse_data_line(
svc.format_data(
"assistant-message-id",
{"message_id": msg_id_from_helper, "turn_id": turn_id},
)
)
assert envelope["data"]["message_id"] == row.id
async def test_handshake_ids_for_full_turn_are_distinct_and_paired(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
"""Sanity: a full new-chat turn's two SSE events carry two
DIFFERENT ids (user row PK assistant row PK), both anchored
to the SAME ``turn_id`` in the DB. This pairing is what
``finalize_assistant_turn`` and the regenerate / edit flows
rely on.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:9200"
user_msg_id = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
user_query="hi",
)
assistant_msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert user_msg_id is not None and assistant_msg_id is not None
assert user_msg_id != assistant_msg_id
rows = (
(
await db_session.execute(
select(NewChatMessage)
.where(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id,
)
.order_by(NewChatMessage.id)
)
)
.scalars()
.all()
)
assert len(rows) == 2
ids_by_role = {r.role: r.id for r in rows}
assert ids_by_role[NewChatMessageRole.USER] == user_msg_id
assert ids_by_role[NewChatMessageRole.ASSISTANT] == assistant_msg_id
# ---------------------------------------------------------------------------
# (3) Parse helpers used by the FE — sanity-check our payload shape
# ---------------------------------------------------------------------------
class TestPayloadShapeMatchesFEReader:
"""The FE's ``readStreamedMessageId`` (in
``surfsense_web/lib/chat/stream-side-effects.ts``) requires:
* ``message_id`` is a ``number`` (rejects null / string / NaN).
* ``turn_id`` is an optional non-empty string (else it's coerced
to ``null``).
These tests exercise the BE side of that contract by inspecting
``format_data`` output shapes that the FE consumes verbatim.
"""
def test_message_id_is_serialised_as_a_json_number(self):
svc = VercelStreamingService()
envelope = _parse_sse_data_line(
svc.format_data("user-message-id", {"message_id": 42, "turn_id": "t"})
)
assert isinstance(envelope["data"]["message_id"], int)
assert envelope["data"]["message_id"] == 42
def test_turn_id_round_trips_as_string(self):
svc = VercelStreamingService()
# The actual format used in production: f"{chat_id}:{int(time.time()*1000)}"
production_turn_id = "533:1762900000000"
envelope = _parse_sse_data_line(
svc.format_data(
"assistant-message-id",
{"message_id": 1, "turn_id": production_turn_id},
)
)
assert envelope["data"]["turn_id"] == production_turn_id

View file

@ -0,0 +1,747 @@
"""Integration tests for ``app.tasks.chat.persistence``.
Verifies the DB-side guarantees the streaming chat task relies on:
* ``persist_assistant_shell`` is idempotent against the
``(thread_id, turn_id, ASSISTANT)`` partial unique index from
migration 141. Two calls with the same ``turn_id`` return the SAME
``message_id`` and never create a duplicate ``new_chat_messages`` row.
* ``finalize_assistant_turn`` writes a status-marker payload when given
empty content, never raises, and is safe to call twice on the same
``message_id`` the partial unique index from migration 142
(``uq_token_usage_message_id``) prevents the second insert from
producing a duplicate ``token_usage`` row.
* The same ``ON CONFLICT DO NOTHING`` invariant covers the cross-writer
race where ``finalize_assistant_turn`` and the ``append_message``
recovery branch both target the same ``message_id``.
All tests run inside the conftest's outer-transaction-with-savepoint
fixture so commits inside the helpers (which open their own
``shielded_async_session``) are released as savepoints and rolled back
at test end. We monkey-patch ``shielded_async_session`` to yield the
same pooled test session so the integration transaction stays
in-scope.
"""
from __future__ import annotations
from contextlib import asynccontextmanager
import pytest
import pytest_asyncio
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import (
ChatVisibility,
NewChatMessage,
NewChatMessageRole,
NewChatThread,
SearchSpace,
TokenUsage,
User,
)
from app.services.token_tracking_service import TurnTokenAccumulator
from app.tasks.chat import persistence as persistence_module
from app.tasks.chat.persistence import (
finalize_assistant_turn,
persist_assistant_shell,
persist_user_turn,
)
pytestmark = pytest.mark.integration
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest_asyncio.fixture
async def db_thread(
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
) -> NewChatThread:
thread = NewChatThread(
title="Test Chat",
search_space_id=db_search_space.id,
created_by_id=db_user.id,
visibility=ChatVisibility.PRIVATE,
)
db_session.add(thread)
await db_session.flush()
return thread
@pytest.fixture
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
"""Route persistence helpers to the test's savepoint-bound session.
The persistence helpers use ``async with shielded_async_session() as
ws`` and call ``ws.commit()`` internally. Inside the conftest's
``join_transaction_mode="create_savepoint"`` setup, those commits
release a SAVEPOINT instead of committing the outer transaction
so the test session can see helper-staged rows immediately and the
outer rollback at end of test wipes them.
"""
@asynccontextmanager
async def _fake_shielded_session():
yield db_session
# Do NOT close — the outer fixture owns the session lifecycle.
monkeypatch.setattr(
persistence_module,
"shielded_async_session",
_fake_shielded_session,
)
return db_session
def _accumulator_with_one_call() -> TurnTokenAccumulator:
acc = TurnTokenAccumulator()
acc.add(
model="gpt-4o-mini",
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
cost_micros=12345,
)
return acc
async def _count_assistant_rows(
session: AsyncSession, thread_id: int, turn_id: str
) -> int:
result = await session.execute(
select(func.count())
.select_from(NewChatMessage)
.where(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
)
)
return int(result.scalar_one())
async def _count_token_usage_rows(session: AsyncSession, message_id: int) -> int:
result = await session.execute(
select(func.count())
.select_from(TokenUsage)
.where(TokenUsage.message_id == message_id)
)
return int(result.scalar_one())
# ---------------------------------------------------------------------------
# persist_assistant_shell
# ---------------------------------------------------------------------------
class TestPersistAssistantShell:
async def test_first_call_inserts_empty_shell_and_returns_id(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
# Capture primitive ids before any persistence helper runs:
# the helpers commit/rollback the shared test session, which
# can detach ORM rows mid-test.
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:1000"
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert msg_id is not None and isinstance(msg_id, int)
row = await db_session.get(NewChatMessage, msg_id)
assert row is not None
assert row.thread_id == thread_id
assert row.role == NewChatMessageRole.ASSISTANT
assert row.turn_id == turn_id
# Empty shell payload — finalize_assistant_turn overwrites later.
assert row.content == [{"type": "text", "text": ""}]
async def test_second_call_with_same_turn_id_returns_same_id(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
# Capture primitive ids before any persistence helper runs:
# the helpers commit/rollback the shared test session, which
# can detach ORM rows mid-test.
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:2000"
first_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
second_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert first_id is not None
assert first_id == second_id
# Exactly one row in the DB for this turn.
assert await _count_assistant_rows(db_session, thread_id, turn_id) == 1
async def test_missing_turn_id_returns_none(
self,
db_user,
db_thread,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id="",
)
assert msg_id is None
async def test_after_persist_user_turn_resolves_assistant_id(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:3000"
# The streaming layer always calls persist_user_turn first, so
# smoke-test the canonical sequence.
user_msg_id = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
user_query="hello",
)
assert isinstance(user_msg_id, int)
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert msg_id is not None
# User row + assistant shell row = 2 rows for this turn.
result = await db_session.execute(
select(func.count())
.select_from(NewChatMessage)
.where(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id,
)
)
assert result.scalar_one() == 2
async def test_double_call_with_same_turn_id_uses_on_conflict(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
"""Verifies the ON CONFLICT DO NOTHING path on the assistant
shell does not raise ``IntegrityError`` even when the second
writer races the first within a tight loop. ``test_second_call_with_same_turn_id_returns_same_id``
already covers the same-id semantics; this test additionally
asserts neither call raises so the debugger's
``raise-on-IntegrityError`` setting won't pause the streaming
path under contention.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:3500"
# Both calls go through ``pg_insert(...).on_conflict_do_nothing``;
# the second one returns RETURNING=∅ and falls into the SELECT
# branch. Neither path raises.
first_id = await persist_assistant_shell(
chat_id=thread_id, user_id=user_id_str, turn_id=turn_id
)
second_id = await persist_assistant_shell(
chat_id=thread_id, user_id=user_id_str, turn_id=turn_id
)
assert first_id is not None
assert first_id == second_id
# ---------------------------------------------------------------------------
# persist_user_turn
# ---------------------------------------------------------------------------
class TestPersistUserTurn:
async def test_returns_message_id_on_first_insert(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:8000"
msg_id = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
user_query="hello",
)
assert isinstance(msg_id, int) and msg_id > 0
row = await db_session.get(NewChatMessage, msg_id)
assert row is not None
assert row.thread_id == thread_id
assert row.role == NewChatMessageRole.USER
assert row.turn_id == turn_id
assert row.content == [{"type": "text", "text": "hello"}]
async def test_returns_existing_id_on_conflict(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:8100"
first_id = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
user_query="hello",
)
# Second call simulates a legacy FE ``appendMessage`` racing the
# SSE stream: ON CONFLICT DO NOTHING short-circuits at the DB
# level, the helper recovers the existing id via SELECT, and
# crucially does NOT raise ``IntegrityError`` (the debugger
# would otherwise pause on it).
second_id = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
user_query="ignored on conflict",
)
assert first_id is not None
assert first_id == second_id
# Exactly one user row for this turn.
count = await db_session.execute(
select(func.count())
.select_from(NewChatMessage)
.where(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id,
NewChatMessage.role == NewChatMessageRole.USER,
)
)
assert count.scalar_one() == 1
async def test_embeds_mentioned_documents_part(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
"""The full ``{id, title, document_type}`` triple forwarded by
the FE must round-trip into a single ``mentioned-documents``
ContentPart on the persisted user message the history loader
renders the chips on reload from this part directly.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:8200"
mentioned = [
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
]
msg_id = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
user_query="hello",
mentioned_documents=mentioned,
)
assert isinstance(msg_id, int)
row = await db_session.get(NewChatMessage, msg_id)
assert row is not None
# Content is a 2-part list: text + mentioned-documents.
assert isinstance(row.content, list)
assert row.content[0] == {"type": "text", "text": "hello"}
assert row.content[1] == {
"type": "mentioned-documents",
"documents": [
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
],
}
async def test_skips_mentioned_documents_when_empty_or_invalid(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
"""Empty list and entries missing required fields are dropped;
a ``mentioned-documents`` part is only emitted when at least
one normalised entry survived.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id_empty = f"{thread_id}:8300"
turn_id_invalid = f"{thread_id}:8301"
msg_id_empty = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id_empty,
user_query="hi",
mentioned_documents=[],
)
assert isinstance(msg_id_empty, int)
row_empty = await db_session.get(NewChatMessage, msg_id_empty)
assert row_empty is not None
assert row_empty.content == [{"type": "text", "text": "hi"}]
# Each entry missing one required field — all skipped.
msg_id_invalid = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id_invalid,
user_query="hi",
mentioned_documents=[
{"title": "no id", "document_type": "GENERAL"}, # missing id
{"id": 99, "document_type": "GENERAL"}, # missing title
{"id": 100, "title": "no type"}, # missing document_type
],
)
assert isinstance(msg_id_invalid, int)
row_invalid = await db_session.get(NewChatMessage, msg_id_invalid)
assert row_invalid is not None
assert row_invalid.content == [{"type": "text", "text": "hi"}]
async def test_missing_turn_id_returns_none(
self,
db_user,
db_thread,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
msg_id = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id="",
user_query="hello",
)
assert msg_id is None
# ---------------------------------------------------------------------------
# finalize_assistant_turn
# ---------------------------------------------------------------------------
class TestFinalizeAssistantTurn:
async def test_writes_content_and_token_usage(
self,
db_session,
db_user,
db_thread,
db_search_space,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_uuid = db_user.id
user_id_str = str(user_id_uuid)
search_space_id = db_search_space.id
turn_id = f"{thread_id}:4000"
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert msg_id is not None
rich_content = [
{"type": "text", "text": "Hello world"},
{
"type": "tool-call",
"toolCallId": "call_x",
"toolName": "ls",
"args": {"path": "/"},
"argsText": '{\n "path": "/"\n}',
"result": {"files": []},
"langchainToolCallId": "lc_x",
},
]
await finalize_assistant_turn(
message_id=msg_id,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id=turn_id,
content=rich_content,
accumulator=_accumulator_with_one_call(),
)
row = await db_session.get(NewChatMessage, msg_id)
await db_session.refresh(row)
assert row.content == rich_content
# Exactly one token_usage row keyed on this message_id.
usage_rows = (
(
await db_session.execute(
select(TokenUsage).where(TokenUsage.message_id == msg_id)
)
)
.scalars()
.all()
)
assert len(usage_rows) == 1
usage = usage_rows[0]
assert usage.usage_type == "chat"
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
assert usage.total_tokens == 150
assert usage.cost_micros == 12345
assert usage.thread_id == thread_id
assert usage.search_space_id == search_space_id
async def test_empty_content_writes_status_marker(
self,
db_session,
db_user,
db_thread,
db_search_space,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
search_space_id = db_search_space.id
turn_id = f"{thread_id}:5000"
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert msg_id is not None
# Pure tool-call turn that aborted before any output, or
# interrupt before any event arrived — empty list.
await finalize_assistant_turn(
message_id=msg_id,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id=turn_id,
content=[],
accumulator=None,
)
row = await db_session.get(NewChatMessage, msg_id)
await db_session.refresh(row)
assert row.content == [{"type": "status", "text": "(no text response)"}]
async def test_double_call_safe_via_on_conflict(
self,
db_session,
db_user,
db_thread,
db_search_space,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
search_space_id = db_search_space.id
turn_id = f"{thread_id}:6000"
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert msg_id is not None
first_acc = _accumulator_with_one_call()
await finalize_assistant_turn(
message_id=msg_id,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id=turn_id,
content=[{"type": "text", "text": "first finalize"}],
accumulator=first_acc,
)
# Simulate a follow-up finalize (e.g., resume retry within the
# shielded finally block firing twice). Different content, but
# ON CONFLICT DO NOTHING on token_usage means the cost from the
# first finalize stays authoritative.
second_acc = TurnTokenAccumulator()
second_acc.add(
model="gpt-4o-mini",
prompt_tokens=999,
completion_tokens=999,
total_tokens=1998,
cost_micros=99999,
)
await finalize_assistant_turn(
message_id=msg_id,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id=turn_id,
content=[{"type": "text", "text": "second finalize"}],
accumulator=second_acc,
)
# Content was overwritten by the second UPDATE.
row = await db_session.get(NewChatMessage, msg_id)
await db_session.refresh(row)
assert row.content == [{"type": "text", "text": "second finalize"}]
# But token_usage stayed at exactly one row, preserving the
# first finalize's authoritative cost.
assert await _count_token_usage_rows(db_session, msg_id) == 1
usage = (
await db_session.execute(
select(TokenUsage).where(TokenUsage.message_id == msg_id)
)
).scalar_one()
assert usage.cost_micros == 12345 # First finalize's value
async def test_append_message_style_insert_after_finalize_no_dupe(
self,
db_session,
db_user,
db_thread,
db_search_space,
patched_shielded_session,
):
"""Cross-writer race: ``append_message`` arrives after ``finalize_assistant_turn``.
Both target the same ``message_id``; the partial unique index
``uq_token_usage_message_id`` (migration 142) makes the second
insert a no-op via ``ON CONFLICT DO NOTHING``.
"""
from sqlalchemy import text as sa_text
thread_id = db_thread.id
user_uuid = db_user.id
user_id_str = str(user_uuid)
search_space_id = db_search_space.id
turn_id = f"{thread_id}:7000"
msg_id = await persist_assistant_shell(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
)
assert msg_id is not None
await finalize_assistant_turn(
message_id=msg_id,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id=turn_id,
content=[{"type": "text", "text": "from server"}],
accumulator=_accumulator_with_one_call(),
)
# Now simulate the FE's append_message branch firing AFTER —
# the same INSERT ... ON CONFLICT DO NOTHING shape used by the
# route handler, keyed on the migration-142 partial unique
# index.
late_insert = (
pg_insert(TokenUsage)
.values(
usage_type="chat",
prompt_tokens=42,
completion_tokens=42,
total_tokens=84,
cost_micros=1,
model_breakdown=None,
call_details=None,
thread_id=thread_id,
message_id=msg_id,
search_space_id=search_space_id,
user_id=user_uuid,
)
.on_conflict_do_nothing(
index_elements=["message_id"],
index_where=sa_text("message_id IS NOT NULL"),
)
)
await db_session.execute(late_insert)
await db_session.flush()
# Still exactly one row, with the original (server) cost value.
assert await _count_token_usage_rows(db_session, msg_id) == 1
usage = (
await db_session.execute(
select(TokenUsage).where(TokenUsage.message_id == msg_id)
)
).scalar_one()
assert usage.cost_micros == 12345
async def test_helper_never_raises_on_missing_message_id(
self,
db_session,
db_user,
db_thread,
db_search_space,
patched_shielded_session,
):
thread_id = db_thread.id
user_id_str = str(db_user.id)
search_space_id = db_search_space.id
# message_id that doesn't exist — finalize must log+return,
# never raise (called from shielded finally).
await finalize_assistant_turn(
message_id=999_999_999,
chat_id=thread_id,
search_space_id=search_space_id,
user_id=user_id_str,
turn_id="anything",
content=[{"type": "text", "text": "x"}],
accumulator=_accumulator_with_one_call(),
)
# If we got here without an exception, the test passes.
# Sanity: no token_usage row created (FK to message would have
# been rejected anyway, but ON CONFLICT path may swallow
# FK errors as well; check directly).
assert await _count_token_usage_rows(db_session, 999_999_999) == 0

View file

@ -0,0 +1,526 @@
"""Unit tests for ``AssistantContentBuilder``.
Pins the in-memory ``ContentPart[]`` projection so the JSONB the server
persists matches what the frontend renders live (see
``surfsense_web/lib/chat/streaming-state.ts``). Every test asserts both
the structural shape of ``snapshot()`` and that the snapshot is
``json.dumps``-safe (the streaming finally block writes it directly to
``new_chat_messages.content`` without an explicit serialization round
trip).
"""
from __future__ import annotations
import json
import pytest
from app.tasks.chat.content_builder import AssistantContentBuilder
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _assert_jsonb_safe(parts: list[dict]) -> None:
"""Sanity check: any snapshot must round-trip through ``json.dumps``."""
serialized = json.dumps(parts)
assert json.loads(serialized) == parts
# ---------------------------------------------------------------------------
# Text turns
# ---------------------------------------------------------------------------
class TestTextOnly:
def test_single_text_block_collapses_consecutive_deltas(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "Hello")
b.on_text_delta("text-1", " ")
b.on_text_delta("text-1", "world")
b.on_text_end("text-1")
snap = b.snapshot()
assert snap == [{"type": "text", "text": "Hello world"}]
assert not b.is_empty()
_assert_jsonb_safe(snap)
def test_empty_text_start_end_pair_leaves_no_part(self):
# Mirrors the FE: a text-start without any deltas should
# not materialise an empty ``{"type":"text","text":""}`` part.
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_end("text-1")
assert b.snapshot() == []
assert b.is_empty()
def test_text_after_text_end_starts_fresh_part(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "first")
b.on_text_end("text-1")
b.on_text_start("text-2")
b.on_text_delta("text-2", "second")
b.on_text_end("text-2")
snap = b.snapshot()
assert snap == [
{"type": "text", "text": "first"},
{"type": "text", "text": "second"},
]
class TestReasoningThenText:
def test_reasoning_followed_by_text_yields_two_parts_in_order(self):
b = AssistantContentBuilder()
b.on_reasoning_start("r-1")
b.on_reasoning_delta("r-1", "Considering options...")
b.on_reasoning_end("r-1")
b.on_text_start("text-1")
b.on_text_delta("text-1", "The answer is 42.")
b.on_text_end("text-1")
snap = b.snapshot()
assert snap == [
{"type": "reasoning", "text": "Considering options..."},
{"type": "text", "text": "The answer is 42."},
]
_assert_jsonb_safe(snap)
def test_text_delta_after_reasoning_implicitly_closes_reasoning(self):
# Mirrors FE ``appendText``: a text delta arriving while a
# reasoning part is "active" still produces a fresh text
# part, never appends into the reasoning block.
b = AssistantContentBuilder()
b.on_reasoning_start("r-1")
b.on_reasoning_delta("r-1", "thinking")
# No explicit reasoning_end — text delta should close it.
b.on_text_delta("text-1", "answer")
snap = b.snapshot()
assert snap == [
{"type": "reasoning", "text": "thinking"},
{"type": "text", "text": "answer"},
]
# ---------------------------------------------------------------------------
# Tool calls
# ---------------------------------------------------------------------------
class TestToolHeavyTurn:
def test_full_tool_lifecycle_produces_complete_tool_call_part(self):
b = AssistantContentBuilder()
# Some narration before the tool fires.
b.on_text_start("text-1")
b.on_text_delta("text-1", "Searching...")
b.on_text_end("text-1")
b.on_tool_input_start(
ui_id="call_run123",
tool_name="web_search",
langchain_tool_call_id="lc_tool_abc",
)
b.on_tool_input_delta("call_run123", '{"query":')
b.on_tool_input_delta("call_run123", '"surfsense"}')
b.on_tool_input_available(
ui_id="call_run123",
tool_name="web_search",
args={"query": "surfsense"},
langchain_tool_call_id="lc_tool_abc",
)
b.on_tool_output_available(
ui_id="call_run123",
output={"status": "completed", "citations": {}},
langchain_tool_call_id="lc_tool_abc",
)
snap = b.snapshot()
assert snap[0] == {"type": "text", "text": "Searching..."}
tool_part = snap[1]
assert tool_part["type"] == "tool-call"
assert tool_part["toolCallId"] == "call_run123"
assert tool_part["toolName"] == "web_search"
assert tool_part["args"] == {"query": "surfsense"}
# ``argsText`` is the pretty-printed final JSON, not the raw
# streaming buffer (FE ``stream-pipeline.ts:128``).
assert tool_part["argsText"] == json.dumps(
{"query": "surfsense"}, indent=2, ensure_ascii=False
)
assert tool_part["langchainToolCallId"] == "lc_tool_abc"
assert tool_part["result"] == {"status": "completed", "citations": {}}
_assert_jsonb_safe(snap)
def test_tool_input_available_without_prior_start_creates_card(self):
# Legacy / parity_v2-OFF path: tool-input-available may be
# emitted without a prior tool-input-start (no streamed
# tool_call_chunks). The card should still be created.
b = AssistantContentBuilder()
b.on_tool_input_available(
ui_id="call_run42",
tool_name="grep",
args={"pattern": "TODO"},
langchain_tool_call_id="lc_x",
)
b.on_tool_output_available(
ui_id="call_run42",
output={"matches": 3},
langchain_tool_call_id="lc_x",
)
snap = b.snapshot()
assert len(snap) == 1
part = snap[0]
assert part["type"] == "tool-call"
assert part["toolCallId"] == "call_run42"
assert part["args"] == {"pattern": "TODO"}
assert part["langchainToolCallId"] == "lc_x"
assert part["result"] == {"matches": 3}
def test_tool_input_start_idempotent_for_same_ui_id(self):
# parity_v2: tool-input-start can fire from BOTH the chunk
# registration path AND the canonical ``on_tool_start`` path.
# The second call must not create a duplicate part.
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "ls", "lc_x")
b.on_tool_input_start("call_x", "ls", "lc_x")
snap = b.snapshot()
assert len(snap) == 1
def test_tool_input_delta_without_prior_start_is_silently_dropped(self):
b = AssistantContentBuilder()
b.on_tool_input_delta("call_unknown", '{"orphan": "delta"}')
assert b.snapshot() == []
def test_langchain_tool_call_id_backfills_only_when_absent(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "ls", "lc_first")
# Late event must NOT clobber an already-set lc id.
b.on_tool_input_start("call_x", "ls", "lc_late")
snap = b.snapshot()
assert snap[0]["langchainToolCallId"] == "lc_first"
def test_args_text_streaming_buffer_reflects_concatenation(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "save_doc", "lc_y")
b.on_tool_input_delta("call_x", '{"title":')
b.on_tool_input_delta("call_x", '"Hi"}')
# Snapshot mid-stream should see the partial buffer (the FE
# tolerates invalid JSON and renders it as-is).
mid = b.snapshot()
assert mid[0]["argsText"] == '{"title":"Hi"}'
# Then tool-input-available replaces with pretty-printed.
b.on_tool_input_available(
"call_x",
"save_doc",
{"title": "Hi"},
"lc_y",
)
final = b.snapshot()
assert final[0]["argsText"] == json.dumps(
{"title": "Hi"}, indent=2, ensure_ascii=False
)
# ---------------------------------------------------------------------------
# Thinking steps & separators
# ---------------------------------------------------------------------------
class TestThinkingSteps:
def test_first_thinking_step_unshifts_singleton_to_index_zero(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "Hello")
b.on_text_end("text-1")
b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item-a"])
snap = b.snapshot()
# Singleton goes to index 0 (FE ``updateThinkingSteps`` unshift).
assert snap[0]["type"] == "data-thinking-steps"
assert snap[0]["data"]["steps"] == [
{
"id": "step-1",
"title": "Analyzing",
"status": "in_progress",
"items": ["item-a"],
}
]
assert snap[1] == {"type": "text", "text": "Hello"}
def test_subsequent_thinking_steps_mutate_the_singleton_in_place(self):
b = AssistantContentBuilder()
b.on_thinking_step("step-1", "Analyzing", "in_progress", [])
b.on_thinking_step("step-2", "Searching", "in_progress", ["q"])
b.on_thinking_step("step-1", "Analyzing", "completed", ["done"])
snap = b.snapshot()
assert len([p for p in snap if p["type"] == "data-thinking-steps"]) == 1
steps = snap[0]["data"]["steps"]
assert len(steps) == 2
assert steps[0]["id"] == "step-1"
assert steps[0]["status"] == "completed"
assert steps[0]["items"] == ["done"]
assert steps[1]["id"] == "step-2"
def test_thinking_step_with_text_continues_appending_to_text(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "first")
# Thinking step inserts at index 0, bumps text idx from 0 to 1.
b.on_thinking_step("step-1", "Working", "in_progress", [])
b.on_text_delta("text-1", " second")
snap = b.snapshot()
text_parts = [p for p in snap if p["type"] == "text"]
assert text_parts == [{"type": "text", "text": "first second"}]
def test_thinking_step_without_id_is_dropped(self):
b = AssistantContentBuilder()
b.on_thinking_step("", "noop", "in_progress", None)
assert b.snapshot() == []
assert b.is_empty()
class TestStepSeparators:
def test_separator_no_op_before_any_content(self):
b = AssistantContentBuilder()
b.on_step_separator()
assert b.snapshot() == []
def test_separator_after_text_appends_with_step_index_zero(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "first")
b.on_text_end("text-1")
b.on_step_separator()
snap = b.snapshot()
assert snap[-1] == {
"type": "data-step-separator",
"data": {"stepIndex": 0},
}
def test_consecutive_separators_collapse_to_one(self):
b = AssistantContentBuilder()
b.on_text_delta("text-1", "x")
b.on_step_separator()
b.on_step_separator() # No-op: previous part is already a separator.
snap = b.snapshot()
assert sum(1 for p in snap if p["type"] == "data-step-separator") == 1
def test_step_index_increments_across_separators(self):
b = AssistantContentBuilder()
b.on_text_delta("text-1", "a")
b.on_step_separator()
b.on_text_delta("text-2", "b")
b.on_step_separator()
snap = b.snapshot()
seps = [p for p in snap if p["type"] == "data-step-separator"]
assert [s["data"]["stepIndex"] for s in seps] == [0, 1]
# ---------------------------------------------------------------------------
# Interruption handling
# ---------------------------------------------------------------------------
class TestMarkInterrupted:
def test_running_tool_calls_get_state_aborted(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_a", "ls", "lc_a")
b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
# No tool-output-available — simulates client disconnect mid-tool.
b.mark_interrupted()
snap = b.snapshot()
assert snap[0]["state"] == "aborted"
assert "result" not in snap[0]
def test_completed_tool_calls_are_not_marked_aborted(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_a", "ls", "lc_a")
b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
b.on_tool_output_available("call_a", {"files": []}, "lc_a")
b.mark_interrupted()
snap = b.snapshot()
assert "state" not in snap[0]
assert snap[0]["result"] == {"files": []}
def test_open_text_block_keeps_accumulated_content(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "partial")
# No on_text_end — disconnect mid-stream.
b.mark_interrupted()
snap = b.snapshot()
assert snap == [{"type": "text", "text": "partial"}]
# ---------------------------------------------------------------------------
# is_empty / snapshot semantics
# ---------------------------------------------------------------------------
class TestIsEmpty:
def test_fresh_builder_is_empty(self):
assert AssistantContentBuilder().is_empty()
def test_text_part_breaks_emptiness(self):
b = AssistantContentBuilder()
b.on_text_delta("text-1", "x")
assert not b.is_empty()
def test_tool_call_breaks_emptiness(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "ls", None)
assert not b.is_empty()
def test_thinking_step_alone_does_not_break_emptiness(self):
# Mirrors the "status marker fallback" semantic: a turn that
# only emitted a thinking step before being interrupted should
# still be treated as empty for finalize_assistant_turn's
# status-marker substitution.
b = AssistantContentBuilder()
b.on_thinking_step("step-1", "Working", "in_progress", [])
assert b.is_empty()
def test_step_separator_alone_does_not_break_emptiness(self):
b = AssistantContentBuilder()
# Force a separator (it would normally no-op without content,
# but we simulate the underlying state to verify is_empty is
# not fooled by a stray separator).
b.parts.append({"type": "data-step-separator", "data": {"stepIndex": 0}})
assert b.is_empty()
class TestSnapshotSemantics:
def test_snapshot_is_deep_copied_so_mutations_do_not_leak(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "ls", "lc_x")
b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
snap = b.snapshot()
# Mutate the returned snapshot — original should be untouched.
snap[0]["args"]["mutated"] = True
snap[0]["state"] = "tampered"
again = b.snapshot()
assert "mutated" not in again[0]["args"]
assert "state" not in again[0]
def test_snapshot_round_trips_through_json(self):
b = AssistantContentBuilder()
b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item"])
b.on_text_delta("text-1", "answer")
b.on_tool_input_start("call_x", "ls", "lc_x")
b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
b.on_tool_output_available("call_x", {"files": ["a.txt"]}, "lc_x")
b.on_step_separator()
snap = b.snapshot()
encoded = json.dumps(snap)
assert json.loads(encoded) == snap
class TestStats:
"""``stats()`` is the perf-log handle for [PERF] [stream_*]
finalize_payload lines. Pin the schema so an ops dashboard can
rely on these keys being present and meaningful.
"""
def test_fresh_builder_reports_all_zeros(self):
b = AssistantContentBuilder()
s = b.stats()
assert s == {
"parts": 0,
"bytes": 2, # ``[]`` is two bytes
"text": 0,
"reasoning": 0,
"tool_calls": 0,
"tool_calls_completed": 0,
"tool_calls_aborted": 0,
"thinking_step_parts": 0,
"step_separators": 0,
}
def test_counts_each_part_type_independently(self):
b = AssistantContentBuilder()
b.on_text_start("t1")
b.on_text_delta("t1", "hi")
b.on_text_end("t1")
b.on_reasoning_start("r1")
b.on_reasoning_delta("r1", "thinking")
b.on_reasoning_end("r1")
b.on_thinking_step("step-1", "Analyzing", "completed", ["item"])
b.on_step_separator()
b.on_tool_input_start("call_done", "ls", "lc_done")
b.on_tool_input_available("call_done", "ls", {}, "lc_done")
b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
b.on_tool_input_start("call_running", "rm", "lc_running")
b.on_tool_input_available("call_running", "rm", {}, "lc_running")
s = b.stats()
assert s["text"] == 1
assert s["reasoning"] == 1
assert s["tool_calls"] == 2
assert s["tool_calls_completed"] == 1
assert s["tool_calls_aborted"] == 0
assert s["thinking_step_parts"] == 1
assert s["step_separators"] == 1
assert s["parts"] == sum(
[
s["text"],
s["reasoning"],
s["tool_calls"],
s["thinking_step_parts"],
s["step_separators"],
]
)
assert s["bytes"] > 0
def test_mark_interrupted_flips_running_calls_to_aborted_in_stats(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_done", "ls", "lc_done")
b.on_tool_input_available("call_done", "ls", {}, "lc_done")
b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
b.on_tool_input_start("call_running", "rm", "lc_running")
b.on_tool_input_available("call_running", "rm", {}, "lc_running")
# Pre-interrupt: one completed, one still running (no result).
pre = b.stats()
assert pre["tool_calls_completed"] == 1
assert pre["tool_calls_aborted"] == 0
b.mark_interrupted()
post = b.stats()
assert post["tool_calls_completed"] == 1
assert post["tool_calls_aborted"] == 1
assert post["tool_calls"] == 2
def test_bytes_reflects_jsonb_payload_size(self):
# Each text-delta adds bytes monotonically — useful for catching
# an unbounded delta buffer regression in the perf signal.
b = AssistantContentBuilder()
b.on_text_start("t1")
b.on_text_delta("t1", "x" * 10)
small = b.stats()["bytes"]
b.on_text_delta("t1", "x" * 1000)
large = b.stats()["bytes"]
assert large > small + 900

View file

@ -457,6 +457,9 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
source = page_path.read_text(encoding="utf-8")
# Each flow tracks accepted boundary and passes it into shared terminal handling.
# The acceptance boundary is still meaningful post-refactor: it gates
# local-state cleanup (onPreAcceptFailure path) and lets the shared
# terminal handler distinguish pre-accept aborts from in-stream errors.
assert "let newAccepted = false;" in source
assert "let resumeAccepted = false;" in source
assert "let regenerateAccepted = false;" in source
@ -464,12 +467,23 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
assert "accepted: resumeAccepted," in source
assert "accepted: regenerateAccepted," in source
# Pre-accept abort in resume/regenerate exits without persistence.
assert "if (!resumeAccepted) return;" in source
assert "if (!regenerateAccepted) return;" in source
# NOTE: The FE-side persistence guards previously asserted here
# ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
# "if (newAccepted && !userPersisted) {") have been intentionally
# removed by the SSE-based message-id handshake refactor. Persistence
# is now server-authoritative: persist_user_turn / persist_assistant_shell
# run inside stream_new_chat / stream_resume_chat unconditionally and
# the FE consumes data-user-message-id / data-assistant-message-id
# SSE events to learn the canonical primary keys. There is therefore
# no FE call-site to guard, and the shared terminal handler relies
# purely on the `accepted` field above (forwarded to onAbort /
# onAcceptedStreamError) to drive UI cleanup. See
# tests/integration/chat/test_message_id_sse.py for the new
# cross-tier ID coherence guarantees.
# New flow persists only when accepted and not already persisted.
assert "if (newAccepted && !userPersisted) {" in source
# The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
# of the persistence refactor and must still exist on every
# start-stream fetch.
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
assert "computeFallbackTurnCancellingRetryDelay" in source
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source

View file

@ -82,6 +82,7 @@ import {
mergeChatTurnIdIntoMessage,
mergeEditedInterruptAction,
readStreamedChatTurnId,
readStreamedMessageId,
} from "@/lib/chat/stream-side-effects";
import {
buildContentForPersistence,
@ -256,110 +257,17 @@ export default function NewChatPage() {
[tokenUsageStore]
);
const persistUserTurn = useCallback(
async ({
threadId,
userMsgId,
content,
mentionedDocs,
turnId,
logContext,
}: {
threadId: number | null;
userMsgId: string;
content: unknown;
mentionedDocs?: MentionedDocumentInfo[];
turnId?: string | null;
logContext: string;
}) => {
if (!threadId) return null;
try {
const normalizedContent = Array.isArray(content) ? ([...content] as unknown[]) : [content];
const hasMentionedDocumentsPart = normalizedContent.some(
(part) => MentionedDocumentsPartSchema.safeParse(part).success
);
if (mentionedDocs && mentionedDocs.length > 0 && !hasMentionedDocumentsPart) {
normalizedContent.push({
type: "mentioned-documents",
documents: mentionedDocs,
});
}
const savedUserMessage = await appendMessage(threadId, {
role: "user",
content: normalizedContent as AppendMessage["content"],
turn_id: turnId,
});
const newUserMsgId = `msg-${savedUserMessage.id}`;
setMessages((prev) =>
prev.map((m) =>
m.id === userMsgId
? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id)
: m
)
);
if (mentionedDocs && mentionedDocs.length > 0) {
setMessageDocumentsMap((prev) => {
const { [userMsgId]: _, ...rest } = prev;
return {
...rest,
[newUserMsgId]: mentionedDocs,
};
});
}
return newUserMsgId;
} catch (err) {
console.error(`Failed to persist ${logContext} user message:`, err);
return null;
}
},
[setMessageDocumentsMap]
);
const persistAssistantTurn = useCallback(
async ({
threadId,
assistantMsgId,
content,
tokenUsage,
turnId,
logContext,
onRemapped,
}: {
threadId: number | null;
assistantMsgId: string;
content: unknown;
tokenUsage?: TokenUsageData;
turnId?: string | null;
logContext: string;
onRemapped?: (newMsgId: string) => void;
}) => {
if (!threadId) return null;
try {
const savedMessage = await appendMessage(threadId, {
role: "assistant",
content: content as AppendMessage["content"],
token_usage: tokenUsage,
turn_id: turnId,
});
const newMsgId = `msg-${savedMessage.id}`;
tokenUsageStore.rename(assistantMsgId, newMsgId);
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
: m
)
);
onRemapped?.(newMsgId);
return newMsgId;
} catch (err) {
console.error(`Failed to persist ${logContext} assistant message:`, err);
return null;
}
},
[tokenUsageStore]
);
// NOTE: ``persistUserTurn`` / ``persistAssistantTurn`` callbacks
// were removed in the SSE-based message ID handshake refactor.
// ``stream_new_chat`` and ``stream_resume_chat`` now persist both
// the user and assistant rows server-side via
// ``persist_user_turn`` / ``persist_assistant_shell`` and emit
// ``data-user-message-id`` / ``data-assistant-message-id`` SSE
// events; the consumers below rename the optimistic ids in real
// time. ``persistAssistantErrorMessage`` (above) is intentionally
// kept — it is the pre-stream-error fallback fired when the
// server NEVER accepted the request, and the BE has nothing to
// persist in that case.
// Get disabled tools from the tool toggle UI
const disabledTools = useAtomValue(disabledToolsAtom);
@ -891,8 +799,13 @@ export default function NewChatPage() {
setPendingUserImageUrls((prev) => prev.filter((u) => !urlsSnapshot.includes(u)));
}
// Add user message to state
const userMsgId = `msg-user-${Date.now()}`;
// Add user message to state. Mutable because the SSE
// ``data-user-message-id`` handler (below) renames this
// optimistic id to the canonical ``msg-{db_id}`` once the
// backend's ``persist_user_turn`` resolves the row, and
// the in-stream flush / interrupt closures need to see
// the post-rename value via this live ``let`` binding.
let userMsgId = `msg-user-${Date.now()}`;
// Always include author metadata so the UI layer can decide visibility
const authorMetadata = currentUser
@ -958,22 +871,16 @@ export default function NewChatPage() {
}));
}
const persistContent: unknown[] = [...userDisplayContent];
if (allMentionedDocs.length > 0) {
persistContent.push({
type: "mentioned-documents",
documents: allMentionedDocs,
});
}
// Start streaming response
setIsRunning(true);
const controller = new AbortController();
abortControllerRef.current = controller;
// Prepare assistant message
const assistantMsgId = `msg-assistant-${Date.now()}`;
// Prepare assistant message. Mutable for the same reason
// as ``userMsgId`` above — the ``data-assistant-message-id``
// SSE handler reassigns this once
// ``persist_assistant_shell`` returns its canonical id.
let assistantMsgId = `msg-assistant-${Date.now()}`;
const currentThinkingSteps = new Map<string, ThinkingStepData>();
const contentPartsState: ContentPartsState = {
contentParts: [],
@ -983,11 +890,7 @@ export default function NewChatPage() {
};
const { contentParts } = contentPartsState;
let wasInterrupted = false;
let tokenUsageData: TokenUsageData | null = null;
let newAccepted = false;
let userPersisted = false;
// Captured from ``data-turn-info`` at stream start.
let streamedChatTurnId: string | null = null;
let streamBatcher: FrameBatchedUpdater | null = null;
try {
@ -1047,6 +950,18 @@ export default function NewChatPage() {
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
? mentionedDocumentIds.surfsense_doc_ids
: undefined,
// Full mention metadata so the BE can embed a
// ``mentioned-documents`` ContentPart on the
// persisted user message (replaces the old FE-side
// injection in ``persistUserTurn``).
mentioned_documents:
allMentionedDocs.length > 0
? allMentionedDocs.map((d) => ({
id: d.id,
title: d.title,
document_type: d.document_type,
}))
: undefined,
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
...(userImages.length > 0 ? { user_images: userImages } : {}),
}),
@ -1089,7 +1004,6 @@ export default function NewChatPage() {
scheduleFlush,
forceFlush,
onTokenUsage: (data) => {
tokenUsageData = data;
tokenUsageStore.set(assistantMsgId, data);
},
onTurnStatus: (data) => {
@ -1189,7 +1103,6 @@ export default function NewChatPage() {
case "data-turn-info": {
const turnId = readStreamedChatTurnId(parsed.data);
streamedChatTurnId = turnId;
if (turnId) {
setMessages((prev) =>
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
@ -1197,46 +1110,96 @@ export default function NewChatPage() {
}
break;
}
case "data-user-message-id": {
// Server-authoritative user message id resolved by
// ``persist_user_turn`` (or recovered via ON CONFLICT).
// Rename the optimistic ``msg-user-XXX`` placeholder to
// the canonical ``msg-{db_id}`` so DB-id-gated UI
// (comments, edit-from-this-message) unlocks immediately,
// migrate the local mentioned-documents map, and reassign
// the closure variable so all downstream
// ``m.id === userMsgId`` checks see the new value.
const parsedMsg = readStreamedMessageId(parsed.data);
if (!parsedMsg) break;
const newUserMsgId = `msg-${parsedMsg.messageId}`;
const oldUserMsgId = userMsgId;
setMessages((prev) =>
prev.map((m) =>
m.id === oldUserMsgId
? mergeChatTurnIdIntoMessage(
{ ...m, id: newUserMsgId },
parsedMsg.turnId
)
: m
)
);
if (allMentionedDocs.length > 0) {
setMessageDocumentsMap((prev) => {
if (!(oldUserMsgId in prev)) {
return { ...prev, [newUserMsgId]: allMentionedDocs };
}
const { [oldUserMsgId]: _removed, ...rest } = prev;
return { ...rest, [newUserMsgId]: allMentionedDocs };
});
}
userMsgId = newUserMsgId;
if (isNewThread) {
// First user-side row landed in ``new_chat_messages``;
// refresh the sidebar so the freshly-bumped
// ``thread.updated_at`` reorders this thread.
queryClient.invalidateQueries({
queryKey: ["threads", String(searchSpaceId)],
});
}
break;
}
case "data-assistant-message-id": {
// Server-authoritative assistant message id resolved
// by ``persist_assistant_shell``. Rename the optimistic
// id, migrate ``tokenUsageStore`` so any pending
// ``data-token-usage`` payload binds to the new id,
// remap any in-flight ``pendingInterrupt`` reference,
// and reassign the closure variable so the in-stream
// flush callback (line ~1074) keeps writing to the
// renamed message.
const parsedMsg = readStreamedMessageId(parsed.data);
if (!parsedMsg) break;
const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
const oldAssistantMsgId = assistantMsgId;
tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
setMessages((prev) =>
prev.map((m) =>
m.id === oldAssistantMsgId
? mergeChatTurnIdIntoMessage(
{ ...m, id: newAssistantMsgId },
parsedMsg.turnId
)
: m
)
);
setPendingInterrupt((prev) =>
prev && prev.assistantMsgId === oldAssistantMsgId
? { ...prev, assistantMsgId: newAssistantMsgId }
: prev
);
assistantMsgId = newAssistantMsgId;
break;
}
}
});
batcher.flush();
// Skip persistence for interrupted messages -- handleResume will persist the final version
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
// Server-authoritative persistence: ``stream_new_chat``
// already wrote the user row in ``persist_user_turn``
// (the FE renamed the optimistic id mid-stream via
// ``data-user-message-id``) and finalises the assistant
// row in ``finalize_assistant_turn`` from a shielded
// ``finally`` block. Nothing left for the FE to persist
// here — track the response and unblock the UI.
if (contentParts.length > 0 && !wasInterrupted) {
if (!userPersisted) {
const persistedUserMsgId = await persistUserTurn({
threadId: currentThreadId,
userMsgId,
content: persistContent,
mentionedDocs: allMentionedDocs,
turnId: streamedChatTurnId,
logContext: "new chat",
});
userPersisted = Boolean(persistedUserMsgId);
if (userPersisted && isNewThread) {
queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] });
}
}
await persistAssistantTurn({
threadId: currentThreadId,
assistantMsgId,
content: finalContent,
tokenUsage: tokenUsageData ?? undefined,
turnId: streamedChatTurnId,
logContext: "new chat",
onRemapped: (newMsgId) => {
setPendingInterrupt((prev) =>
prev && prev.assistantMsgId === assistantMsgId
? { ...prev, assistantMsgId: newMsgId }
: prev
);
},
});
// Track successful response
trackChatResponseReceived(searchSpaceId, currentThreadId);
}
} catch (error) {
@ -1247,51 +1210,21 @@ export default function NewChatPage() {
threadId: currentThreadId,
assistantMsgId,
accepted: newAccepted,
onAbort: async () => {
if (newAccepted && !userPersisted) {
const persistedUserMsgId = await persistUserTurn({
threadId: currentThreadId,
userMsgId,
content: persistContent,
mentionedDocs: allMentionedDocs,
turnId: streamedChatTurnId,
logContext: "new chat (aborted)",
});
userPersisted = Boolean(persistedUserMsgId);
if (userPersisted && isNewThread) {
queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] });
}
}
const hasContent = hasPersistableContent(contentParts, toolsWithUI);
if (hasContent && currentThreadId) {
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
await persistAssistantTurn({
threadId: currentThreadId,
assistantMsgId,
content: partialContent,
turnId: streamedChatTurnId,
logContext: "partial new chat",
});
}
},
onAcceptedStreamError: async () => {
if (!userPersisted) {
const persistedUserMsgId = await persistUserTurn({
threadId: currentThreadId,
userMsgId,
content: persistContent,
mentionedDocs: allMentionedDocs,
turnId: streamedChatTurnId,
logContext: "new chat (stream error)",
});
userPersisted = Boolean(persistedUserMsgId);
if (userPersisted && isNewThread) {
queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] });
}
}
},
// Server-side ``finalize_assistant_turn`` runs from a
// shielded ``anyio.CancelScope(shield=True)`` finally
// block, so partial content (incl. abort-mid-stream)
// is already persisted by the BE for the assistant
// row, and ``persist_user_turn`` ran before any LLM
// call. The FE's only remaining responsibility on
// abort / accepted-stream-error is to surface the
// error toast (handled by ``handleStreamTerminalError``
// itself).
onPreAcceptFailure: async () => {
// Pre-accept failure means the BE never accepted the
// request — no server-side persistence ran. Roll
// back the optimistic UI insertions we made before
// the fetch so the user message and any local
// mentioned-docs metadata don't linger.
setMessages((prev) => prev.filter((m) => m.id !== userMsgId));
setMessageDocumentsMap((prev) => {
if (!(userMsgId in prev)) return prev;
@ -1325,8 +1258,6 @@ export default function NewChatPage() {
fetchWithTurnCancellingRetry,
handleStreamTerminalError,
handleChatFailure,
persistAssistantTurn,
persistUserTurn,
]
);
@ -1339,7 +1270,12 @@ export default function NewChatPage() {
}>
) => {
if (!pendingInterrupt) return;
const { threadId: resumeThreadId, assistantMsgId } = pendingInterrupt;
const { threadId: resumeThreadId } = pendingInterrupt;
// Destructured separately as ``let`` so the SSE
// ``data-assistant-message-id`` handler (resume always
// allocates a fresh server-side row) can rename it to
// the canonical ``msg-{db_id}`` mid-stream.
let assistantMsgId = pendingInterrupt.assistantMsgId;
setPendingInterrupt(null);
setIsRunning(true);
@ -1362,10 +1298,7 @@ export default function NewChatPage() {
toolCallIndices: new Map(),
};
const { contentParts, toolCallIndices } = contentPartsState;
let tokenUsageData: TokenUsageData | null = null;
let resumeAccepted = false;
// Captured from ``data-turn-info`` at stream start.
let streamedChatTurnId: string | null = null;
let streamBatcher: FrameBatchedUpdater | null = null;
const existingMsg = messages.find((m) => m.id === assistantMsgId);
@ -1466,7 +1399,6 @@ export default function NewChatPage() {
scheduleFlush,
forceFlush,
onTokenUsage: (data) => {
tokenUsageData = data;
tokenUsageStore.set(assistantMsgId, data);
},
onTurnStatus: (data) => {
@ -1514,7 +1446,6 @@ export default function NewChatPage() {
case "data-turn-info": {
const turnId = readStreamedChatTurnId(parsed.data);
streamedChatTurnId = turnId;
if (turnId) {
setMessages((prev) =>
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
@ -1522,22 +1453,44 @@ export default function NewChatPage() {
}
break;
}
case "data-assistant-message-id": {
// Resume always allocates a fresh ``new_chat_messages``
// row anchored to a new ``turn_id`` (the original
// interrupted turn's row stays as-is), so this is a
// real id swap. Rename the optimistic placeholder to
// ``msg-{db_id}`` and reassign closure state. Resume
// does NOT emit ``data-user-message-id`` — the user
// row belongs to the original interrupted turn.
const parsedMsg = readStreamedMessageId(parsed.data);
if (!parsedMsg) break;
const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
const oldAssistantMsgId = assistantMsgId;
tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
setMessages((prev) =>
prev.map((m) =>
m.id === oldAssistantMsgId
? mergeChatTurnIdIntoMessage(
{ ...m, id: newAssistantMsgId },
parsedMsg.turnId
)
: m
)
);
assistantMsgId = newAssistantMsgId;
break;
}
}
});
batcher.flush();
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
if (contentParts.length > 0) {
await persistAssistantTurn({
threadId: resumeThreadId,
assistantMsgId,
content: finalContent,
tokenUsage: tokenUsageData ?? undefined,
turnId: streamedChatTurnId,
logContext: "resumed chat",
});
}
// Server-authoritative persistence: ``stream_resume_chat``
// finalises the assistant row in
// ``finalize_assistant_turn`` from a shielded
// ``finally`` block (covers both happy-path and
// abort-mid-stream). FE has no remaining persistence
// work here.
} catch (error) {
streamBatcher?.dispose();
await handleStreamTerminalError({
@ -1546,19 +1499,6 @@ export default function NewChatPage() {
threadId: resumeThreadId,
assistantMsgId,
accepted: resumeAccepted,
onAbort: async () => {
if (!resumeAccepted) return;
const hasContent = hasPersistableContent(contentParts, toolsWithUI);
if (!hasContent) return;
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
await persistAssistantTurn({
threadId: resumeThreadId,
assistantMsgId,
content: partialContent,
turnId: streamedChatTurnId,
logContext: "partial resumed chat",
});
},
});
} finally {
setIsRunning(false);
@ -1574,7 +1514,6 @@ export default function NewChatPage() {
tokenUsageStore,
fetchWithTurnCancellingRetry,
handleStreamTerminalError,
persistAssistantTurn,
]
);
@ -1715,9 +1654,12 @@ export default function NewChatPage() {
const controller = new AbortController();
abortControllerRef.current = controller;
// Add placeholder user message if we have a new query (edit mode)
const userMsgId = `msg-user-${Date.now()}`;
const assistantMsgId = `msg-assistant-${Date.now()}`;
// Add placeholder user message if we have a new query (edit mode).
// Mutable for the same reason as in ``onNew`` — both ids are
// renamed mid-stream by the new ``data-user-message-id`` /
// ``data-assistant-message-id`` SSE handlers below.
let userMsgId = `msg-user-${Date.now()}`;
let assistantMsgId = `msg-assistant-${Date.now()}`;
const currentThinkingSteps = new Map<string, ThinkingStepData>();
const contentPartsState: ContentPartsState = {
@ -1727,13 +1669,7 @@ export default function NewChatPage() {
toolCallIndices: new Map(),
};
const { contentParts } = contentPartsState;
let tokenUsageData: TokenUsageData | null = null;
let regenerateAccepted = false;
let userPersisted = false;
// Captured from ``data-turn-info`` at stream start; stamped
// onto persisted messages so future edits can locate the
// right LangGraph checkpoint.
let streamedChatTurnId: string | null = null;
let streamBatcher: FrameBatchedUpdater | null = null;
// Add placeholder messages to UI
@ -1747,9 +1683,6 @@ export default function NewChatPage() {
createdAt: new Date(),
metadata: isEdit ? undefined : originalUserMessageMetadata,
};
const userContentToPersist = isEdit
? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }])
: originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }];
const sourceMentionedDocs =
sourceUserMessageId && messageDocumentsMap[sourceUserMessageId]
? messageDocumentsMap[sourceUserMessageId]
@ -1765,6 +1698,18 @@ export default function NewChatPage() {
filesystem_mode: selection.filesystem_mode,
client_platform: selection.client_platform,
local_filesystem_mounts: selection.local_filesystem_mounts,
// Full mention metadata for the regenerate-specific
// source list. Only meaningful for edit (the BE only
// re-persists a user row when ``user_query`` is set);
// reload reuses the original turn's mentioned_documents.
mentioned_documents:
sourceMentionedDocs.length > 0
? sourceMentionedDocs.map((d) => ({
id: d.id,
title: d.title,
document_type: d.document_type,
}))
: undefined,
};
if (isEdit) {
requestBody.user_images = editExtras?.userImages ?? [];
@ -1852,7 +1797,6 @@ export default function NewChatPage() {
scheduleFlush,
forceFlush,
onTokenUsage: (data) => {
tokenUsageData = data;
tokenUsageStore.set(assistantMsgId, data);
},
onTurnStatus: (data) => {
@ -1897,7 +1841,6 @@ export default function NewChatPage() {
case "data-turn-info": {
const turnId = readStreamedChatTurnId(parsed.data);
streamedChatTurnId = turnId;
if (turnId) {
setMessages((prev) =>
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
@ -1906,6 +1849,57 @@ export default function NewChatPage() {
break;
}
case "data-user-message-id": {
// Same role as in ``onNew`` but the regenerate-specific
// mention metadata (``sourceMentionedDocs``) is the
// list to migrate onto the canonical id key.
const parsedMsg = readStreamedMessageId(parsed.data);
if (!parsedMsg) break;
const newUserMsgId = `msg-${parsedMsg.messageId}`;
const oldUserMsgId = userMsgId;
setMessages((prev) =>
prev.map((m) =>
m.id === oldUserMsgId
? mergeChatTurnIdIntoMessage(
{ ...m, id: newUserMsgId },
parsedMsg.turnId
)
: m
)
);
if (sourceMentionedDocs.length > 0) {
setMessageDocumentsMap((prev) => {
if (!(oldUserMsgId in prev)) {
return { ...prev, [newUserMsgId]: sourceMentionedDocs };
}
const { [oldUserMsgId]: _removed, ...rest } = prev;
return { ...rest, [newUserMsgId]: sourceMentionedDocs };
});
}
userMsgId = newUserMsgId;
break;
}
case "data-assistant-message-id": {
const parsedMsg = readStreamedMessageId(parsed.data);
if (!parsedMsg) break;
const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
const oldAssistantMsgId = assistantMsgId;
tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
setMessages((prev) =>
prev.map((m) =>
m.id === oldAssistantMsgId
? mergeChatTurnIdIntoMessage(
{ ...m, id: newAssistantMsgId },
parsedMsg.turnId
)
: m
)
);
assistantMsgId = newAssistantMsgId;
break;
}
case "data-revert-results": {
const summary = parsed.data;
// failureCount must include every "not undone" bucket
@ -1946,28 +1940,14 @@ export default function NewChatPage() {
batcher.flush();
// Persist messages after streaming completes
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
// Server-authoritative persistence: ``stream_new_chat``
// (regenerate flow) wrote the user row in
// ``persist_user_turn`` and finalises the assistant row
// in ``finalize_assistant_turn`` from a shielded
// ``finally`` block (covers both happy-path and
// abort-mid-stream). FE only needs to track the
// successful response here.
if (contentParts.length > 0) {
const persistedUserMsgId = await persistUserTurn({
threadId,
userMsgId,
content: userContentToPersist,
mentionedDocs: sourceMentionedDocs,
turnId: streamedChatTurnId,
logContext: "regenerated",
});
userPersisted = Boolean(persistedUserMsgId);
await persistAssistantTurn({
threadId,
assistantMsgId,
content: finalContent,
tokenUsage: tokenUsageData ?? undefined,
turnId: streamedChatTurnId,
logContext: "regenerated",
});
trackChatResponseReceived(searchSpaceId, threadId);
}
} catch (error) {
@ -1978,44 +1958,6 @@ export default function NewChatPage() {
threadId,
assistantMsgId,
accepted: regenerateAccepted,
onAbort: async () => {
if (!regenerateAccepted) return;
if (!userPersisted) {
const persistedUserMsgId = await persistUserTurn({
threadId,
userMsgId,
content: userContentToPersist,
mentionedDocs: sourceMentionedDocs,
turnId: streamedChatTurnId,
logContext: "regenerated (aborted)",
});
userPersisted = Boolean(persistedUserMsgId);
}
const hasContent = hasPersistableContent(contentParts, toolsWithUI);
if (!hasContent) return;
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
await persistAssistantTurn({
threadId,
assistantMsgId,
content: partialContent,
tokenUsage: tokenUsageData ?? undefined,
turnId: streamedChatTurnId,
logContext: "partial regenerated chat",
});
},
onAcceptedStreamError: async () => {
if (!userPersisted) {
const persistedUserMsgId = await persistUserTurn({
threadId,
userMsgId,
content: userContentToPersist,
mentionedDocs: sourceMentionedDocs,
turnId: streamedChatTurnId,
logContext: "regenerated (stream error)",
});
userPersisted = Boolean(persistedUserMsgId);
}
},
});
} finally {
setIsRunning(false);
@ -2034,8 +1976,6 @@ export default function NewChatPage() {
tokenUsageStore,
fetchWithTurnCancellingRetry,
handleStreamTerminalError,
persistAssistantTurn,
persistUserTurn,
]
);

View file

@ -114,6 +114,29 @@ export function readStreamedChatTurnId(data: unknown): string | null {
return typeof value === "string" && value.length > 0 ? value : null;
}
/**
* Parse the payload of `data-user-message-id` / `data-assistant-message-id`
* SSE events emitted by `stream_new_chat` and `stream_resume_chat` after
* `persist_user_turn` / `persist_assistant_shell` resolve a canonical
* `new_chat_messages.id`. Mirrors {@link readStreamedChatTurnId}.
*
* Returns `null` when the payload is malformed (missing or non-numeric
* `message_id`); callers should treat this as "ignore the event" so a
* malformed BE payload never overwrites the optimistic id with a bogus
* value.
*/
export function readStreamedMessageId(
data: unknown
): { messageId: number; turnId: string | null } | null {
if (typeof data !== "object" || data === null) return null;
const obj = data as { message_id?: unknown; turn_id?: unknown };
if (typeof obj.message_id !== "number" || !Number.isFinite(obj.message_id)) {
return null;
}
const turnId = typeof obj.turn_id === "string" && obj.turn_id.length > 0 ? obj.turn_id : null;
return { messageId: obj.message_id, turnId };
}
export function applyTurnIdToAssistantMessageList(
messages: ThreadMessageLike[],
assistantMsgId: string,

View file

@ -487,6 +487,37 @@ export type SSEEvent =
type: "data-turn-info";
data: { chat_turn_id: string };
}
| {
/**
* Emitted by ``stream_new_chat`` AFTER ``data-turn-info`` /
* ``data-turn-status`` and BEFORE any LLM streaming events,
* once ``persist_user_turn`` has resolved the canonical
* ``new_chat_messages.id`` for the user-side row of the
* current turn. The frontend renames its optimistic
* ``msg-user-XXX`` placeholder id to ``msg-{message_id}``
* so DB-id-gated UI (comments, edit-from-this-message)
* unlocks immediately. Not emitted by ``stream_resume_chat``
* (resume reuses the original turn's user message).
*/
type: "data-user-message-id";
data: { message_id: number; turn_id: string };
}
| {
/**
* Emitted by ``stream_new_chat`` AND ``stream_resume_chat``
* AFTER ``data-turn-info`` / ``data-turn-status`` and BEFORE
* any LLM streaming events, once ``persist_assistant_shell``
* has resolved the canonical ``new_chat_messages.id`` for
* the assistant-side row of the current turn. The frontend
* renames its optimistic ``msg-assistant-XXX`` placeholder
* id, migrates the local ``tokenUsageStore`` and
* ``pendingInterrupt`` references, and binds the running
* mutable ``assistantMsgId`` closure variable to the
* canonical id for the rest of the stream.
*/
type: "data-assistant-message-id";
data: { message_id: number; turn_id: string };
}
| {
/**
* Best-effort revert pass that ran BEFORE this regeneration.

View file

@ -144,6 +144,17 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory
* via ``data-turn-info``. Persisting it lets later edits locate the
* matching LangGraph checkpoint without HumanMessage scanning. Older
* callers can still omit it for back-compat.
*
* @deprecated Replaced by the SSE-based message ID handshake. The
* streaming generator (`stream_new_chat` / `stream_resume_chat`) now
* persists both the user and assistant rows server-side via
* `persist_user_turn` / `persist_assistant_shell` and emits
* `data-user-message-id` / `data-assistant-message-id` SSE events so
* the UI renames its optimistic IDs in real time. The only remaining
* caller is `persistAssistantErrorMessage` (pre-stream error fallback
* for requests the server never accepted the server has nothing to
* persist in that case). After the legacy route is removed in a
* follow-up PR this function will be deleted entirely.
*/
export async function appendMessage(
threadId: number,