feat: moved chat persistence to Server Side

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-04 03:06:15 -07:00
parent 2e1b9b5582
commit 19b6e0a025
19 changed files with 4515 additions and 390 deletions

View file

@ -675,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin):
__tablename__ = "new_chat_messages"
# Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL.
# Mirrors alembic migration 141. Lets the streaming agent and the
# legacy frontend appendMessage call coexist idempotently — the second
# writer trips the unique and recovers without creating a duplicate row.
# Partial so legacy NULL turn_id rows and clone/snapshot inserts in
# app/services/public_chat_service.py (which omit turn_id) are unaffected.
__table_args__ = (
Index(
"uq_new_chat_messages_thread_turn_role",
"thread_id",
"turn_id",
"role",
unique=True,
postgresql_where=text("turn_id IS NOT NULL"),
),
)
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
# Content stored as JSONB to support rich content (text, tool calls, etc.)
content = Column(JSONB, nullable=False)
@ -728,6 +745,22 @@ class TokenUsage(BaseModel, TimestampMixin):
__tablename__ = "token_usage"
# Partial unique index on (message_id) where message_id IS NOT NULL.
# Mirrors alembic migration 142. Lets the streaming agent's
# ``finalize_assistant_turn`` and the legacy frontend ``append_message``
# recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without
# racing on a SELECT-then-INSERT window. Partial so non-chat usage rows
# (indexing, image generation, podcasts) — which keep ``message_id`` NULL
# because there is no per-message anchor — are unaffected.
__table_args__ = (
Index(
"uq_token_usage_message_id",
"message_id",
unique=True,
postgresql_where=text("message_id IS NOT NULL"),
),
)
prompt_tokens = Column(Integer, nullable=False, default=0)
completion_tokens = Column(Integer, nullable=False, default=0)
total_tokens = Column(Integer, nullable=False, default=0)

View file

@ -17,7 +17,8 @@ from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy import func, or_, text as sa_text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
@ -44,6 +45,7 @@ from app.db import (
NewChatThread,
Permission,
SearchSpace,
TokenUsage,
User,
get_async_session,
shielded_async_session,
@ -69,9 +71,9 @@ from app.schemas.new_chat import (
TokenUsageSummary,
TurnStatusResponse,
)
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.perf import get_perf_logger
from app.utils.rbac import check_permission
from app.utils.user_message_multimodal import (
split_langchain_human_content,
@ -79,6 +81,7 @@ from app.utils.user_message_multimodal import (
)
_logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
_background_tasks: set[asyncio.Task] = set()
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
@ -1287,6 +1290,24 @@ async def append_message(
user: User = Depends(current_active_user),
):
"""
.. deprecated:: 2026-05
Replaced by the **SSE-based message ID handshake**. The streaming
generator (`stream_new_chat` / `stream_resume_chat`) now persists
both the user and assistant rows server-side via
``persist_user_turn`` / ``persist_assistant_shell`` and emits
``data-user-message-id`` / ``data-assistant-message-id`` SSE events
so the frontend can rename its optimistic IDs in real time. The
new FE bundle no longer calls this route.
This handler is retained as a **silent no-op for legacy / cached
FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING``
pattern means a stale bundle hitting this route after the SSE
handshake already wrote the row simply returns the existing row
(200 OK) without raising or duplicating data. After a 2-week soak
(target: ``[persist_user_turn] outcome=race_recovered`` rate ~0)
this entire route and the FE ``appendMessage`` function is
earmarked for removal.
Append a message to a thread.
This is used by ThreadHistoryAdapter.append() to persist messages.
@ -1297,6 +1318,22 @@ async def append_message(
Requires CHATS_UPDATE permission.
"""
try:
# Capture ``user.id`` as a primitive UUID up front. The
# ``current_active_user`` dependency hands us a ``User`` ORM
# row bound to ``session``; if the outer ``except
# IntegrityError`` block below ever fires (an unexpected
# constraint like a foreign key violation — the common
# ``(thread_id, turn_id, role)`` race is now handled silently
# by ``ON CONFLICT DO NOTHING`` so it never raises) it calls
# ``session.rollback()``, which expires every attached ORM
# row including this user. Any later ``user.id`` access would
# then trigger a lazy PK reload — which on async SQLAlchemy
# fails with ``MissingGreenlet`` because the reload happens
# outside the awaitable greenlet boundary. Reading ``id``
# once here pins the value as a plain UUID so all downstream
# uses (TokenUsage insert, response build) are immune.
user_uuid = user.id
# Parse raw body - extract only role and content, ignoring extra fields
raw_body = await request.json()
role = raw_body.get("role")
@ -1351,42 +1388,166 @@ async def append_message(
else None
)
db_message = NewChatMessage(
thread_id=thread_id,
role=message_role,
content=content,
author_id=user.id,
turn_id=turn_id_value,
)
session.add(db_message)
# Update thread's updated_at timestamp
# Update thread's updated_at timestamp (always — both insert
# and recovery paths represent thread activity).
thread.updated_at = datetime.now(UTC)
# flush assigns the PK/defaults without a round-trip SELECT
await session.flush()
# Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING``
# keyed on the ``(thread_id, turn_id, role)`` partial unique
# index from migration 141 (``WHERE turn_id IS NOT NULL``).
#
# Why ON CONFLICT instead of ``session.add() + flush() + except
# IntegrityError``:
# 1. The conflict between this legacy FE ``appendMessage``
# round-trip and the server-side
# ``finalize_assistant_turn`` writer is a NORMAL,
# *expected* race — every assistant turn fires it. Using
# catch-and-recover means asyncpg raises
# ``UniqueViolationError`` -> SQLAlchemy wraps it as
# ``IntegrityError`` -> our handler catches and recovers.
# Functionally fine, but every ``raise`` event lights up
# VS Code's debugger (debugpy's ``justMyCode=false`` mode
# loses track of the catch frame across SQLAlchemy's
# async greenlet boundary, so even ``Raised Exceptions``
# being unchecked doesn't reliably suppress the pause).
# ON CONFLICT pushes the conflict resolution into Postgres
# where no Python exception is constructed at all.
# 2. No ``session.rollback()`` -> no expiring of attached
# ORM rows -> no risk of ``MissingGreenlet`` from
# lazy-loading expired user/thread state later in the
# handler.
# 3. Cleaner production logs (no SQLAlchemy ``IntegrityError``
# tracebacks emitted by uvicorn's logger between the
# ``raise`` and our ``except``).
#
# When ``turn_id_value`` is ``None`` the partial index doesn't
# apply and the INSERT proceeds normally. Other constraint
# violations (FK, NOT NULL, etc.) still raise ``IntegrityError``
# and are caught by the outer ``except IntegrityError`` block
# to preserve the legacy 400 behavior.
#
# Note on ``content``: when we recover the existing row, we
# intentionally discard the FE's ``content`` payload from
# ``raw_body`` and return the row's existing ``content``. The
# streaming task is now the *authoritative writer* for
# assistant ``ContentPart[]`` shape (mid-stream
# ``AssistantContentBuilder`` -> ``finalize_assistant_turn``)
# so the FE's later ``appendMessage`` is just a stale snapshot
# of the same data — keeping the server-built rich content
# (with full tool-call args / argsText / langchainToolCallId)
# is correct, not lossy.
insert_stmt = (
pg_insert(NewChatMessage)
.values(
thread_id=thread_id,
role=message_role,
content=content,
author_id=user_uuid,
turn_id=turn_id_value,
)
.on_conflict_do_nothing(
index_elements=["thread_id", "turn_id", "role"],
index_where=sa_text("turn_id IS NOT NULL"),
)
.returning(NewChatMessage.id)
)
inserted_id = (await session.execute(insert_stmt)).scalar()
if inserted_id is None:
# Conflict on partial unique index — server-side stream
# already wrote this row. Look it up and reuse it.
if turn_id_value is None:
# Defensive: ON CONFLICT only fires for ``turn_id IS
# NOT NULL`` rows, so this branch should be
# unreachable. Preserve the legacy 400 just in case
# Postgres ever surprises us.
raise HTTPException(
status_code=400,
detail="Database constraint violation. Please check your input data.",
) from None
lookup = await session.execute(
select(NewChatMessage).filter(
NewChatMessage.thread_id == thread_id,
NewChatMessage.turn_id == turn_id_value,
NewChatMessage.role == message_role,
)
)
existing_message = lookup.scalars().first()
if existing_message is None:
# Conflict reported but the row vanished between
# INSERT and SELECT — extremely unlikely (would
# require a concurrent DELETE within the same
# transaction visibility), but preserve safe
# behavior.
raise HTTPException(
status_code=400,
detail="Database constraint violation. Please check your input data.",
) from None
db_message = existing_message
# Perf signal: counts how often the legacy FE round-trip
# races the server-side ``finalize_assistant_turn``. A
# rising rate after the rework is OK (it's exactly the
# ghost-thread fix's recovery path firing); a sudden drop
# to zero would mean the FE isn't posting appendMessage
# at all (different bug).
_perf_log.info(
"[append_message] outcome=recovered_via_unique_index "
"thread_id=%s turn_id=%s role=%s message_id=%s",
thread_id,
turn_id_value,
message_role.value,
db_message.id,
)
else:
# INSERT succeeded — load the full ORM row so the
# response can include server-side-defaulted columns
# (``created_at``, etc.) and the relationship surface
# stays consistent with the recovery path.
inserted_row = await session.get(NewChatMessage, inserted_id)
if inserted_row is None:
# Should be impossible: we just inserted it in this
# same transaction. Fail loud if it happens.
raise HTTPException(
status_code=500,
detail="Inserted message could not be loaded.",
) from None
db_message = inserted_row
# Persist token usage if provided (for assistant messages).
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
# forwarded by the FE through the appendMessage round-trip so
# the historical TokenUsage row matches the credit debit applied
# at finalize time.
#
# De-dup: ``finalize_assistant_turn`` may also race to write a
# token_usage row for this same ``message_id`` (cross-session,
# cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed
# on the ``uq_token_usage_message_id`` partial unique index
# (migration 142). The loser silently drops its insert; exactly
# one row results regardless of which writer commits first.
token_usage_data = raw_body.get("token_usage")
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
await record_token_usage(
session,
usage_type="chat",
search_space_id=thread.search_space_id,
user_id=user.id,
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,
message_id=db_message.id,
insert_stmt = (
pg_insert(TokenUsage)
.values(
usage_type="chat",
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
completion_tokens=token_usage_data.get("completion_tokens", 0),
total_tokens=token_usage_data.get("total_tokens", 0),
cost_micros=token_usage_data.get("cost_micros", 0),
model_breakdown=token_usage_data.get("usage"),
call_details=token_usage_data.get("call_details"),
thread_id=thread_id,
message_id=db_message.id,
search_space_id=thread.search_space_id,
user_id=user_uuid,
)
.on_conflict_do_nothing(
index_elements=["message_id"],
index_where=sa_text("message_id IS NOT NULL"),
)
)
await session.execute(insert_stmt)
await session.commit()
@ -1406,6 +1567,9 @@ async def append_message(
except HTTPException:
raise
except IntegrityError:
# Any IntegrityError that escaped the inline handler above
# comes from a *different* constraint (foreign key, etc.) —
# preserve the legacy 400 path.
await session.rollback()
raise HTTPException(
status_code=400,
@ -1599,6 +1763,12 @@ async def handle_new_chat(
else None
)
mentioned_documents_payload = (
[doc.model_dump() for doc in request.mentioned_documents]
if request.mentioned_documents
else None
)
return StreamingResponse(
stream_new_chat(
user_query=request.user_query,
@ -1608,6 +1778,7 @@ async def handle_new_chat(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
mentioned_documents=mentioned_documents_payload,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
current_user_display_name=user.display_name or "A team member",
@ -2078,6 +2249,11 @@ async def regenerate_response(
"data": revert_results,
}
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
mentioned_documents_payload = (
[doc.model_dump() for doc in request.mentioned_documents]
if request.mentioned_documents
else None
)
try:
async for chunk in stream_new_chat(
user_query=str(user_query_to_use),
@ -2087,6 +2263,7 @@ async def regenerate_response(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
mentioned_documents=mentioned_documents_payload,
checkpoint_id=target_checkpoint_id,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,

View file

@ -200,6 +200,21 @@ class NewChatUserImagePart(BaseModel):
return to_data_url(self.media_type, self.data)
class MentionedDocumentInfo(BaseModel):
    """Display metadata for a single ``@``-mentioned document.

    The full triple ``{id, title, document_type}`` is forwarded by the
    frontend mention chip so the server can embed it in the persisted
    user message ``ContentPart[]`` (single ``mentioned-documents`` part).
    The history loader then renders the chips on reload without an extra
    fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape.
    """

    # Database id of the mentioned document.
    id: int
    # Human-readable title rendered in the chip; bounded to keep the
    # persisted ContentPart payload small.
    title: str = Field(..., min_length=1, max_length=500)
    # Document-type tag (drives the chip icon on the frontend).
    document_type: str = Field(..., min_length=1, max_length=100)
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
@ -213,6 +228,17 @@ class NewChatRequest(BaseModel):
mentioned_surfsense_doc_ids: list[int] | None = (
None # Optional SurfSense documentation IDs mentioned with @ in the chat
)
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata (id, title, document_type) for every "
"@-mentioned document. Persisted as a ``mentioned-documents`` "
"ContentPart on the user message so reload renders chips "
"without an extra fetch. Optional and additive — when None "
"the user message is persisted without a mentioned-documents "
"part."
),
)
disabled_tools: list[str] | None = (
None # Optional list of tool names the user has disabled from the UI
)
@ -264,6 +290,16 @@ class RegenerateRequest(BaseModel):
)
mentioned_document_ids: list[int] | None = None
mentioned_surfsense_doc_ids: list[int] | None = None
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata (id, title, document_type) for every "
"@-mentioned document on the edited user turn. Only used "
"when ``user_query`` is non-None (edit). Persisted as a "
"``mentioned-documents`` ContentPart on the new user "
"message. None means no chip metadata."
),
)
disabled_tools: list[str] | None = None
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
@ -334,6 +370,16 @@ class ResumeRequest(BaseModel):
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata forwarded for symmetry with /new_chat and "
"/regenerate. Resume reuses the original interrupted user "
"turn so the server does not write a new user message. "
"Currently unused but accepted to keep request bodies "
"uniform across the three streaming entrypoints."
),
)
class CancelActiveTurnResponse(BaseModel):

View file

@ -0,0 +1,515 @@
"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection.
Background
----------
The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields
SSE events that the frontend folds into a ``ContentPartsState`` (see
``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in
``stream-pipeline.ts``). When a turn ends, the frontend calls
``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]``
JSONB to ``POST /threads/{id}/messages``, which is what was historically
written to ``new_chat_messages.content``.
After the ghost-thread fix moved persistence server-side, the assistant
row is written by ``finalize_assistant_turn`` in the streaming finally
block. The frontend's later ``appendMessage`` is now a no-op (recovers
via the ``(thread_id, turn_id, role)`` partial unique index added in
migration 141), which means the *server* is now responsible for
producing the rich ``ContentPart[]`` shape the FE expects on history
reload — text + reasoning + tool-call cards (with ``args``, ``argsText``,
``result``, ``langchainToolCallId``) + thinking-step buckets +
step-separators.
This module is the in-memory accumulator that mirrors the FE state for
exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*``
/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` /
``mark_interrupted`` at the same call sites it yields the matching
``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list
stays in lockstep with what the FE's pipeline would have produced live.
``snapshot()`` is then taken once in the ``finally`` block and persisted
in a single UPDATE.
Pure synchronous state — no DB I/O, no async, no flush callbacks. The
streaming code is responsible for driving lifecycle methods; this class
is a thin projection helper.
"""
from __future__ import annotations
import copy
import json
import logging
from typing import Any
logger = logging.getLogger(__name__)
# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``:
# only text/reasoning/tool-call parts count as "meaningful". data-thinking-steps
# and data-step-separator decorate the meaningful parts but never stand alone
# in a successful turn.
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
class AssistantContentBuilder:
"""Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly
matches the FE ``ContentPart`` union::
| { type: "text"; text: string }
| { type: "reasoning"; text: string }
| { type: "tool-call"; toolCallId: str; toolName: str;
args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
state?: "aborted" }
| { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
| { type: "data-step-separator"; data: { stepIndex: int } }
Order matches the wire order of the SSE events that drive the lifecycle
methods, with two FE-mirrored exceptions:
1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the
first time we see a ``data-thinking-step`` SSE event (the FE's
``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent
thinking-step updates mutate that singleton in place.
2. ``data-step-separator`` is appended only when the message already has
meaningful content and the previous part isn't itself a separator
(so the FIRST step of a turn doesn't generate a leading divider).
"""
def __init__(self) -> None:
    """Initialise an empty accumulator with no active text/reasoning part."""
    # The ordered ContentPart[] projection that snapshot() deep-copies.
    self.parts: list[dict[str, Any]] = []
    # Index of the active text/reasoning part within ``parts`` while
    # streaming is open; -1 means "no active part" and the next delta
    # opens a fresh one. Mirrors ``ContentPartsState.currentTextPartIndex``.
    self._current_text_idx: int = -1
    self._current_reasoning_idx: int = -1
    # ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
    # synthetic ``call_<run_id>`` (legacy) or the LangChain
    # ``tool_call.id`` (parity_v2) — same key the streaming layer
    # threads through every ``tool-input-*`` / ``tool-output-*`` event.
    self._tool_call_idx_by_ui_id: dict[str, int] = {}
    # Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
    # so we can reproduce the FE's ``appendToolInputDelta`` behaviour
    # before ``tool-input-available`` overwrites it with the
    # pretty-printed final JSON.
    self._args_text_by_ui_id: dict[str, str] = {}
# ------------------------------------------------------------------
# Text
# ------------------------------------------------------------------
def on_text_start(self, text_id: str) -> None:
    """Begin a fresh text block.

    Symmetric to FE ``appendText``: opening text closes any reasoning
    block that is still active, so the renderer treats them as separate
    parts. No text part is materialised here — the first
    ``on_text_delta`` lazily creates it, so an empty start/end pair
    leaves no trace (the FE pipeline has no ``text-start`` handler).
    """
    reasoning_is_open = self._current_reasoning_idx >= 0
    if reasoning_is_open:
        self._current_reasoning_idx = -1
def on_text_delta(self, text_id: str, delta: str) -> None:
    """Fold a streamed text chunk into the active text part.

    Appends to the currently-open ``text`` part when one exists,
    otherwise opens a new one. A text delta arriving while reasoning
    is open implicitly closes the reasoning block (FE ``appendText``
    behaviour). Empty deltas are ignored.
    """
    if not delta:
        return
    # FE parity: text after reasoning closes the reasoning block.
    if self._current_reasoning_idx >= 0:
        self._current_reasoning_idx = -1
    active = self._current_text_idx
    can_extend = (
        0 <= active < len(self.parts)
        and self.parts[active].get("type") == "text"
    )
    if can_extend:
        self.parts[active]["text"] += delta
    else:
        self.parts.append({"type": "text", "text": delta})
        self._current_text_idx = len(self.parts) - 1
def on_text_end(self, text_id: str) -> None:
    """Close the active text block.

    Mirrors the wire-level ``text-end`` boundary emitted before tool
    calls / reasoning / step boundaries. The FE closes implicitly via
    ``currentTextPartIndex = -1`` inside ``addToolCall`` /
    ``appendReasoning`` / ``addStepSeparator``; doing it explicitly
    here frees callers from maintaining that invariant per call site.
    """
    self._current_text_idx = -1
# ------------------------------------------------------------------
# Reasoning
# ------------------------------------------------------------------
def on_reasoning_start(self, reasoning_id: str) -> None:
    """Begin a reasoning block: any open text part is closed first."""
    text_is_open = self._current_text_idx >= 0
    if text_is_open:
        self._current_text_idx = -1
def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None:
    """Fold a streamed reasoning chunk into the active reasoning part.

    Symmetric to ``on_text_delta``: extends the open ``reasoning`` part
    or starts a fresh one, closing any active text block first. Empty
    deltas are ignored.
    """
    if not delta:
        return
    if self._current_text_idx >= 0:
        self._current_text_idx = -1
    active = self._current_reasoning_idx
    can_extend = (
        0 <= active < len(self.parts)
        and self.parts[active].get("type") == "reasoning"
    )
    if can_extend:
        self.parts[active]["text"] += delta
    else:
        self.parts.append({"type": "reasoning", "text": delta})
        self._current_reasoning_idx = len(self.parts) - 1
def on_reasoning_end(self, reasoning_id: str) -> None:
    """Close the active reasoning block (next delta opens a fresh part)."""
    self._current_reasoning_idx = -1
# ------------------------------------------------------------------
# Tool calls
# ------------------------------------------------------------------
def on_tool_input_start(
self,
ui_id: str,
tool_name: str,
langchain_tool_call_id: str | None,
) -> None:
"""Register a tool-call card. Args are filled in by later events."""
if not ui_id:
return
# Skip duplicate registration: parity_v2 may emit
# ``tool-input-start`` from both ``on_chat_model_stream``
# (when tool_call_chunks register a name) and ``on_tool_start``
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
# we mirror that here.
if ui_id in self._tool_call_idx_by_ui_id:
if langchain_tool_call_id:
idx = self._tool_call_idx_by_ui_id[ui_id]
part = self.parts[idx]
if not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
return
part: dict[str, Any] = {
"type": "tool-call",
"toolCallId": ui_id,
"toolName": tool_name,
"args": {},
}
if langchain_tool_call_id:
part["langchainToolCallId"] = langchain_tool_call_id
self.parts.append(part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
self._current_text_idx = -1
self._current_reasoning_idx = -1
def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None:
    """Append a streamed args-delta chunk to the matching card's argsText.

    Mirrors FE ``appendToolInputDelta``: a no-op when no card has been
    registered yet for the given ``ui_id`` — the deltas have nowhere
    safe to land. Empty ids/chunks are ignored.
    """
    if not ui_id or not args_chunk:
        return
    idx = self._tool_call_idx_by_ui_id.get(ui_id)
    if idx is None or not (0 <= idx < len(self.parts)):
        return
    card = self.parts[idx]
    if card.get("type") != "tool-call":
        return
    accumulated = (card.get("argsText") or "") + args_chunk
    card["argsText"] = accumulated
    self._args_text_by_ui_id[ui_id] = accumulated
def on_tool_input_available(
self,
ui_id: str,
tool_name: str,
args: dict[str, Any],
langchain_tool_call_id: str | None,
) -> None:
"""Finalize the tool-call card's input.
Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText``
with ``json.dumps(input, indent=2)`` so the post-stream card renders
pretty-printed JSON, sets the full ``args`` dict, and backfills
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
Also creates the card if no prior ``tool-input-start`` registered it
(legacy parity_v2-OFF / late-registration paths).
"""
if not ui_id:
return
try:
final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False)
except (TypeError, ValueError):
# Defensive: ``args`` should already be JSON-safe (the
# streaming layer sanitizes it before emitting), but if a
# caller hands us a non-serializable value we still want
# to record the call without breaking the snapshot.
final_args_text = str(args)
idx = self._tool_call_idx_by_ui_id.get(ui_id)
if idx is not None and 0 <= idx < len(self.parts):
part = self.parts[idx]
if part.get("type") == "tool-call":
part["args"] = args or {}
part["argsText"] = final_args_text
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
return
# No prior tool-input-start: register the card now.
new_part: dict[str, Any] = {
"type": "tool-call",
"toolCallId": ui_id,
"toolName": tool_name,
"args": args or {},
"argsText": final_args_text,
}
if langchain_tool_call_id:
new_part["langchainToolCallId"] = langchain_tool_call_id
self.parts.append(new_part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
self._current_text_idx = -1
self._current_reasoning_idx = -1
def on_tool_output_available(
self,
ui_id: str,
output: Any,
langchain_tool_call_id: str | None,
) -> None:
"""Attach the tool's output (``result``) to the matching card.
Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId``
only if not already set (a NULL late-arriving value never blows
away an earlier known good one).
"""
if not ui_id:
return
idx = self._tool_call_idx_by_ui_id.get(ui_id)
if idx is None or not (0 <= idx < len(self.parts)):
return
part = self.parts[idx]
if part.get("type") != "tool-call":
return
part["result"] = output
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
# ------------------------------------------------------------------
# Thinking steps & step separators
# ------------------------------------------------------------------
def on_thinking_step(
self,
step_id: str,
title: str,
status: str,
items: list[str] | None,
) -> None:
"""Update / insert the singleton ``data-thinking-steps`` part.
Mirrors FE ``updateThinkingSteps``: maintain a single
``data-thinking-steps`` part anchored at index 0, replacing or
unshifting on first sight. Each ``on_thinking_step`` call
replaces the entry in the steps list keyed by ``step_id`` (or
appends if new).
"""
if not step_id:
return
new_step = {
"id": step_id,
"title": title or "",
"status": status or "in_progress",
"items": list(items) if items else [],
}
# Find existing data-thinking-steps part.
existing_idx = -1
for i, p in enumerate(self.parts):
if p.get("type") == "data-thinking-steps":
existing_idx = i
break
if existing_idx >= 0:
current_steps = self.parts[existing_idx].get("data", {}).get("steps") or []
replaced = False
for i, step in enumerate(current_steps):
if step.get("id") == step_id:
current_steps[i] = new_step
replaced = True
break
if not replaced:
current_steps.append(new_step)
self.parts[existing_idx] = {
"type": "data-thinking-steps",
"data": {"steps": current_steps},
}
return
# First sight: unshift to position 0 (FE parity).
self.parts.insert(
0,
{
"type": "data-thinking-steps",
"data": {"steps": [new_step]},
},
)
# Bump tracked indices since we inserted at the head.
if self._current_text_idx >= 0:
self._current_text_idx += 1
if self._current_reasoning_idx >= 0:
self._current_reasoning_idx += 1
for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()):
self._tool_call_idx_by_ui_id[ui_id] = idx + 1
def on_step_separator(self) -> None:
    """Append a ``data-step-separator`` between consecutive model steps.

    Mirrors FE ``addStepSeparator``: skipped while the message has no
    meaningful content yet, and never doubled up directly after an
    existing separator (so the first step of a turn gets no leading
    divider). ``stepIndex`` is the count of separators already present.
    Closes any open text/reasoning block.
    """
    if not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts):
        return
    if self.parts and self.parts[-1].get("type") == "data-step-separator":
        return
    prior_separators = [
        p for p in self.parts if p.get("type") == "data-step-separator"
    ]
    self.parts.append(
        {
            "type": "data-step-separator",
            "data": {"stepIndex": len(prior_separators)},
        }
    )
    self._current_text_idx = -1
    self._current_reasoning_idx = -1
# ------------------------------------------------------------------
# Interruption handling
# ------------------------------------------------------------------
def mark_interrupted(self) -> None:
    """Close any open text/reasoning and flip unfinished tools to aborted.

    Called from the streaming ``finally`` block before ``snapshot()`` so
    the persisted JSONB reflects a coherent end-state even when the
    client disconnected mid-turn or the agent hit a fatal error.

    - Active text/reasoning blocks just lose their "active" marker;
      whatever was already streamed is kept as-is.
    - Tool-call parts that never received a ``result`` get
      ``state="aborted"`` so the FE history loader renders them as
      "interrupted" rather than "still running".
    """
    self._current_text_idx = -1
    self._current_reasoning_idx = -1
    for card in self.parts:
        if card.get("type") == "tool-call" and "result" not in card:
            card["state"] = "aborted"
# ------------------------------------------------------------------
# Snapshot & introspection
# ------------------------------------------------------------------
def snapshot(self) -> list[dict[str, Any]]:
"""Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps.
Deep-copied so callers that finalize from the shielded ``finally``
block can't accidentally mutate the persisted payload while the
SQL UPDATE is in flight (the streaming layer doesn't touch the
builder after this call, but defensive copies are cheap and cheap
is what we want in a finally block).
"""
return copy.deepcopy(self.parts)
def is_empty(self) -> bool:
    """True when no meaningful content was captured.

    ``data-thinking-steps`` and ``data-step-separator`` parts decorate
    meaningful content but don't count on their own — a turn that only
    emitted a thinking step before being interrupted should still be
    treated as empty for the status-marker fallback.
    """
    for part in self.parts:
        if part.get("type") in _MEANINGFUL_PART_TYPES:
            return False
    return True
def stats(self) -> dict[str, int]:
    """Return per-part-type counts plus a rough serialized size.

    Consumed by the streaming layer's perf logger: an ops dashboard can
    correlate finalize latency with payload size, and a regression that
    quietly stops emitting tool-call parts (or starts emitting hundreds)
    shows up in a [PERF] grep instead of only as a "history reload looks
    weird" bug report.

    ``bytes`` is the length of the JSON-serialised payload — what
    actually crosses the wire to PostgreSQL's JSONB column. It is
    computed with ``ensure_ascii=False`` to approximate the JSONB
    encoder's UTF-8 on-disk layout closely enough for sizing estimates.

    Defensive: if ``json.dumps`` fails (a non-serializable value slipped
    past the streaming layer's sanitization) the size is reported as
    ``bytes=-1`` rather than raised — perf logging must never be the
    thing that breaks the streaming ``finally`` block.
    """
    # Independent per-kind tallies so any one count can spike without
    # affecting the others.
    kind_counts = {
        "text": 0,
        "reasoning": 0,
        "tool-call": 0,
        "data-thinking-steps": 0,
        "data-step-separator": 0,
    }
    completed = 0
    aborted = 0
    for part in self.parts:
        kind = part.get("type")
        if kind in kind_counts:
            kind_counts[kind] += 1
        if kind == "tool-call":
            # A part is "aborted" when mark_interrupted flagged it;
            # otherwise it counts as completed only once a result landed.
            if part.get("state") == "aborted":
                aborted += 1
            elif "result" in part:
                completed += 1
    try:
        serialized = json.dumps(self.parts, ensure_ascii=False, default=str)
        payload_bytes = len(serialized)
    except (TypeError, ValueError):
        payload_bytes = -1
    return {
        "parts": len(self.parts),
        "bytes": payload_bytes,
        "text": kind_counts["text"],
        "reasoning": kind_counts["reasoning"],
        "tool_calls": kind_counts["tool-call"],
        "tool_calls_completed": completed,
        "tool_calls_aborted": aborted,
        "thinking_step_parts": kind_counts["data-thinking-steps"],
        "step_separators": kind_counts["data-step-separator"],
    }

View file

@ -0,0 +1,534 @@
"""Server-side message persistence helpers for the streaming chat agent.
Historically the streaming task (``stream_new_chat``/``stream_resume_chat``)
left ``new_chat_messages`` empty and relied on the frontend to round-trip
``POST /threads/{id}/messages`` afterwards. That gave authenticated clients
a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens
without leaving an audit trail. These helpers move both writes (the user
turn that triggered the stream and the assistant turn the stream produced)
into the server itself, idempotent against the partial unique index
``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do*
keep posting via ``appendMessage`` simply hit the unique-index recovery
path on the second writer instead of creating duplicates.
Assistant turn lifecycle
------------------------
The assistant side is split into two helpers so we can capture the row id
*before* the stream produces any output:
* ``persist_assistant_shell`` runs immediately after ``persist_user_turn``
and INSERTs an empty assistant row anchored to ``(thread_id, turn_id,
ASSISTANT)``. Returns the row id so the streaming layer can correlate
later writes (token_usage, AgentActionLog future-correlation) against
a stable PK from the start of the turn.
* ``finalize_assistant_turn`` runs from the streaming ``finally`` block.
It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot
produced server-side by ``AssistantContentBuilder`` and writes the
``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against
the ``uq_token_usage_message_id`` partial unique index from migration
142, hard-eliminating any race against ``append_message``'s recovery
branch.
Defensive contract
------------------
* Every helper runs inside ``shielded_async_session()`` so ``session.close()``
survives starlette's mid-stream cancel scope on client disconnect.
* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON
CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id,
role)`` partial unique index. On conflict the insert silently no-ops at
  the DB level — no Python ``IntegrityError`` is constructed, which
eliminates spurious debugger pauses and keeps logs clean. On conflict a
follow-up ``SELECT`` resolves the existing row id so the streaming layer
can correlate writes against a stable PK.
* ``finalize_assistant_turn`` is best-effort: it never raises. The
streaming ``finally`` block calls it from within
``anyio.CancelScope(shield=True)`` and any raised exception there
would mask the real error.
"""
from __future__ import annotations
import logging
import time
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import text as sa_text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.future import select
from app.db import (
NewChatMessage,
NewChatMessageRole,
NewChatThread,
TokenUsage,
shielded_async_session,
)
from app.services.token_tracking_service import (
TurnTokenAccumulator,
)
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
# Empty initial assistant content. ``finalize_assistant_turn`` overwrites
# this in a single UPDATE at end-of-stream with the full ``ContentPart[]``
# snapshot produced by ``AssistantContentBuilder``. We persist a one-element
# list with an empty text part so a crash between shell-INSERT and finalize
# leaves the row in a FE-renderable shape (blank bubble) instead of
# blowing up the history loader.
_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}]
# Substituted content for genuinely empty turns (no text, no reasoning,
# no tool calls). The streaming layer flips to this when
# ``AssistantContentBuilder.is_empty()`` returns True so the persisted
# row is at least somewhat self-describing instead of an empty text
# bubble. The FE's ``ContentPart`` union doesn't include ``status``
# yet, so the history loader will silently drop this part and render
# a blank bubble (matches today's behaviour for empty turns); a follow-up
# FE PR adds the explicit "no response" rendering.
_STATUS_NO_RESPONSE: list[dict[str, Any]] = [
{"type": "status", "text": "(no text response)"}
]
def _build_user_content(
user_query: str,
user_image_data_urls: list[str] | None,
mentioned_documents: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
"""Build the persisted user-message ``content`` (assistant-ui v2 parts).
Mirrors the shape the existing frontend posts via
``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``):
[{"type": "text", "text": "..."},
{"type": "image", "image": "data:..."},
{"type": "mentioned-documents", "documents": [{"id": int,
"title": str, "document_type": str}, ...]}]
The companion reader is
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
which expects exactly this shape keep them in sync.
``mentioned_documents``: optional list of ``{id, title, document_type}``
dicts. When non-empty (and a ``mentioned-documents`` part is not already
in some other input shape), a single ``{"type": "mentioned-documents",
"documents": [...]}`` part is appended. Mirrors the FE injection at
``page.tsx:281-286`` (``persistUserTurn``).
"""
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
for url in user_image_data_urls or ():
if isinstance(url, str) and url:
parts.append({"type": "image", "image": url})
if mentioned_documents:
normalized: list[dict[str, Any]] = []
for doc in mentioned_documents:
if not isinstance(doc, dict):
continue
doc_id = doc.get("id")
title = doc.get("title")
document_type = doc.get("document_type")
if doc_id is None or title is None or document_type is None:
continue
normalized.append(
{
"id": doc_id,
"title": str(title),
"document_type": str(document_type),
}
)
if normalized:
parts.append({"type": "mentioned-documents", "documents": normalized})
return parts
async def persist_user_turn(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
    user_query: str,
    user_image_data_urls: list[str] | None = None,
    mentioned_documents: list[dict[str, Any]] | None = None,
) -> int | None:
    """Persist the user-side row for a chat turn and return its ``id``.

    Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on
    the ``(thread_id, turn_id, role)`` partial unique index from
    migration 141 (``WHERE turn_id IS NOT NULL``). On conflict the insert
    silently no-ops at the DB level — no Python ``IntegrityError`` is
    constructed, which eliminates the debugger pause that
    ``justMyCode=false`` + async greenlet interactions used to produce,
    and keeps production logs clean.

    Args:
        chat_id: ``new_chat_threads.id`` the message belongs to.
        user_id: Author UUID as a string; persisted as NULL (anonymous)
            when missing or unparseable.
        turn_id: Idempotency key for the turn; required (guarded below).
        user_query: Raw user text, stored as the leading text part.
        user_image_data_urls: Optional data-URLs stored as image parts.
        mentioned_documents: Optional ``{id, title, document_type}``
            dicts stored as a single ``mentioned-documents`` part.

    Returns:
        The ``id`` of the row that exists for this turn after the call:
        the freshly inserted id on the happy path, or the existing id
        when a previous writer (legacy FE ``appendMessage`` racing the
        SSE stream, redelivered request, etc.) already wrote it.
        ``None`` only on genuine DB failure; the caller should yield a
        streaming error and abort the turn so we never produce a
        title/assistant row that isn't anchored to a persisted user
        message.

    Note:
        Only the ``(thread_id, turn_id, role)`` collision is silenced.
        Other constraint violations (FK, NOT NULL, etc.) still raise
        inside the session and are converted to the generic
        log-and-return-``None`` path by the ``except`` below.
    """
    if not turn_id:
        # Defensive: turn_id is always populated by the streaming path
        # before this helper is called. Without it we cannot be
        # idempotent against the unique index — refuse to write rather
        # than create a row the unique index can't dedupe.
        logger.error(
            "persist_user_turn called without a turn_id (chat_id=%s); skipping",
            chat_id,
        )
        return None

    t0 = time.perf_counter()
    # ``outcome`` / ``resolved_id`` are overwritten on every path and
    # reported by the perf log in the ``finally`` block.
    outcome = "failed"
    resolved_id: int | None = None
    try:
        async with shielded_async_session() as ws:
            # Re-attach the thread row so we can also bump updated_at
            # in the same write — keeps the sidebar ordering accurate
            # when a user fires off a turn but never reaches the
            # legacy appendMessage.
            thread = await ws.get(NewChatThread, chat_id)

            author_uuid: UUID | None = None
            if user_id:
                try:
                    author_uuid = UUID(user_id)
                except (TypeError, ValueError):
                    # Malformed UUID string — persist anonymously instead
                    # of failing the whole turn.
                    logger.warning(
                        "persist_user_turn: invalid user_id=%r, persisting as anonymous",
                        user_id,
                    )

            content_payload = _build_user_content(
                user_query, user_image_data_urls, mentioned_documents
            )
            insert_stmt = (
                pg_insert(NewChatMessage)
                .values(
                    thread_id=chat_id,
                    role=NewChatMessageRole.USER,
                    content=content_payload,
                    author_id=author_uuid,
                    turn_id=turn_id,
                )
                .on_conflict_do_nothing(
                    index_elements=["thread_id", "turn_id", "role"],
                    index_where=sa_text("turn_id IS NOT NULL"),
                )
                .returning(NewChatMessage.id)
            )
            # RETURNING yields the new id on insert, or no row on conflict.
            inserted_id = (await ws.execute(insert_stmt)).scalar()
            if inserted_id is None:
                # Conflict on partial unique index — another writer
                # (legacy FE appendMessage, redelivered request, etc.)
                # already persisted this row. Look it up and reuse.
                lookup = await ws.execute(
                    select(NewChatMessage.id).where(
                        NewChatMessage.thread_id == chat_id,
                        NewChatMessage.turn_id == turn_id,
                        NewChatMessage.role == NewChatMessageRole.USER,
                    )
                )
                existing_id = lookup.scalars().first()
                if existing_id is None:
                    # Conflict reported but no row found — extremely
                    # unlikely (concurrent DELETE). Surface as failure.
                    logger.warning(
                        "persist_user_turn: conflict but no matching row "
                        "(chat_id=%s, turn_id=%s)",
                        chat_id,
                        turn_id,
                    )
                    outcome = "integrity_no_match"
                    return None
                resolved_id = int(existing_id)
                outcome = "race_recovered"
            else:
                resolved_id = int(inserted_id)
                outcome = "inserted"
                # Bump thread.updated_at only on a real insert — when
                # we recovered an existing row the prior writer
                # already touched the thread.
                if thread is not None:
                    thread.updated_at = datetime.now(UTC)
            await ws.commit()
            return resolved_id
    except Exception:
        # Best-effort contract: log with traceback and signal failure via
        # ``None`` so the caller can abort instead of raising mid-stream.
        logger.exception(
            "persist_user_turn failed (chat_id=%s, turn_id=%s)",
            chat_id,
            turn_id,
        )
        return None
    finally:
        # Exactly one perf line per call, regardless of which path
        # returned, so dashboards can correlate latency and outcomes.
        _perf_log.info(
            "[persist_user_turn] outcome=%s chat_id=%s turn_id=%s "
            "message_id=%s query_len=%d images=%d mentioned_docs=%d "
            "in %.3fs",
            outcome,
            chat_id,
            turn_id,
            resolved_id,
            len(user_query or ""),
            len(user_image_data_urls or ()),
            len(mentioned_documents or ()),
            time.perf_counter() - t0,
        )
async def persist_assistant_shell(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
) -> int | None:
    """Pre-create an empty assistant row for the turn and return its id.

    Writes a placeholder ``new_chat_messages`` row (empty text content)
    so the streaming layer has a stable ``message_id`` to correlate
    against for the rest of the turn; ``finalize_assistant_turn``
    overwrites the ``content`` field at end-of-stream with the rich
    ``ContentPart[]`` snapshot produced by ``AssistantContentBuilder``.

    Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial
    unique index from migration 141: when a row already exists (resume
    retry, racing legacy frontend, redelivered request, etc.) it is
    looked up by ``(thread_id, turn_id, role)`` and its id returned, so
    the streaming layer UPDATEs that same row at finalize time.

    Args:
        chat_id: ``new_chat_threads.id`` of the thread being streamed.
        user_id: Accepted for signature parity with ``persist_user_turn``;
            assistant rows always persist ``author_id=None``.
        turn_id: Idempotency key for the turn; required (guarded below).

    Returns:
        The row id on success, ``None`` on genuine DB failure (the
        caller should abort the turn rather than stream into a void).
    """
    if not turn_id:
        # Without a turn_id the unique index cannot dedupe — refuse.
        logger.error(
            "persist_assistant_shell called without a turn_id (chat_id=%s); skipping",
            chat_id,
        )
        return None

    started = time.perf_counter()
    result_tag = "failed"  # reported by the perf log in ``finally``
    shell_id: int | None = None
    try:
        async with shielded_async_session() as session:
            stmt = (
                pg_insert(NewChatMessage)
                .values(
                    thread_id=chat_id,
                    role=NewChatMessageRole.ASSISTANT,
                    content=_EMPTY_SHELL_CONTENT,
                    author_id=None,
                    turn_id=turn_id,
                )
                .on_conflict_do_nothing(
                    index_elements=["thread_id", "turn_id", "role"],
                    index_where=sa_text("turn_id IS NOT NULL"),
                )
                .returning(NewChatMessage.id)
            )
            fresh_id = (await session.execute(stmt)).scalar()
            if fresh_id is not None:
                shell_id = int(fresh_id)
                result_tag = "inserted"
            else:
                # Conflict — another writer (legacy FE appendMessage,
                # resume retry, redelivered request) wrote the
                # (thread_id, turn_id, ASSISTANT) row first. Resolve its
                # id so the streaming layer can UPDATE the same row at
                # finalize time.
                prior = await session.execute(
                    select(NewChatMessage.id).where(
                        NewChatMessage.thread_id == chat_id,
                        NewChatMessage.turn_id == turn_id,
                        NewChatMessage.role == NewChatMessageRole.ASSISTANT,
                    )
                )
                prior_id = prior.scalars().first()
                if prior_id is None:
                    # Conflict reported but nothing to recover — surface
                    # as failure so the caller aborts the turn.
                    logger.warning(
                        "persist_assistant_shell: conflict but no matching "
                        "(thread_id, turn_id, role) row found "
                        "(chat_id=%s, turn_id=%s)",
                        chat_id,
                        turn_id,
                    )
                    result_tag = "integrity_no_match"
                    return None
                shell_id = int(prior_id)
                result_tag = "race_recovered"
            await session.commit()
            return shell_id
    except Exception:
        logger.exception(
            "persist_assistant_shell failed (chat_id=%s, turn_id=%s)",
            chat_id,
            turn_id,
        )
        return None
    finally:
        _perf_log.info(
            "[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s "
            "message_id=%s in %.3fs",
            result_tag,
            chat_id,
            turn_id,
            shell_id,
            time.perf_counter() - started,
        )
async def finalize_assistant_turn(
    *,
    message_id: int,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    turn_id: str,
    content: list[dict[str, Any]],
    accumulator: TurnTokenAccumulator | None,
) -> None:
    """Finalize the assistant row and write its token_usage.

    Two writes in a single shielded session:

    1. ``UPDATE new_chat_messages SET content = :c, updated_at = now()
       WHERE id = :id`` overwrites the placeholder
       ``persist_assistant_shell`` wrote with the full ``ContentPart[]``
       snapshot produced server-side.
    2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT
       (message_id) WHERE message_id IS NOT NULL DO NOTHING`` uses the
       partial unique index ``uq_token_usage_message_id`` from migration
       142 to make the insert idempotent against ``append_message``'s
       recovery branch (which uses the same ON CONFLICT clause).

    Substitutes the status-marker payload when ``content`` is empty
    (pure tool-call turn that aborted before any output, or interrupt
    before any event arrived). The status marker is preferable to a
    blank text bubble because token accounting still runs and an ops
    dashboard can flag the row.

    Args:
        message_id: Row id returned by ``persist_assistant_shell``;
            required (guarded below).
        chat_id: Thread id, used as ``token_usage.thread_id``.
        search_space_id: Scope for the ``token_usage`` row.
        user_id: Owner UUID string; the token_usage write is skipped
            when missing or unparseable.
        turn_id: Idempotency key for the turn; required (guarded below).
        content: ``AssistantContentBuilder.snapshot()`` output.
        accumulator: Per-turn token totals; ``None`` or call-less
            accumulators skip the token_usage write.

    Best-effort — never raises. The streaming ``finally`` calls this
    from within ``anyio.CancelScope(shield=True)``; any raised exception
    here would mask the real error that triggered the cleanup.
    """
    if not turn_id:
        logger.error(
            "finalize_assistant_turn called without turn_id "
            "(chat_id=%s, message_id=%s); skipping",
            chat_id,
            message_id,
        )
        return
    if not message_id:
        logger.error(
            "finalize_assistant_turn called without message_id "
            "(chat_id=%s, turn_id=%s); skipping",
            chat_id,
            turn_id,
        )
        return

    # Choose the persisted payload: the real snapshot when there is one,
    # otherwise the self-describing "(no text response)" status marker.
    payload: list[dict[str, Any]]
    is_status_marker = False
    if content:
        payload = content
    else:
        payload = _STATUS_NO_RESPONSE
        is_status_marker = True

    t0 = time.perf_counter()
    outcome = "failed"
    # Recorded for the perf log: whether the token_usage branch below
    # would even be entered (accumulator with calls plus a user_id).
    token_usage_attempted = bool(
        accumulator is not None and accumulator.calls and user_id
    )
    try:
        async with shielded_async_session() as ws:
            assistant_row = await ws.get(NewChatMessage, message_id)
            if assistant_row is None:
                # Shell row vanished (e.g. thread deleted mid-stream) —
                # nothing to finalize; log and bail without raising.
                logger.warning(
                    "finalize_assistant_turn: row not found "
                    "(chat_id=%s, message_id=%s, turn_id=%s); skipping",
                    chat_id,
                    message_id,
                    turn_id,
                )
                outcome = "row_missing"
                return
            assistant_row.content = payload
            assistant_row.updated_at = datetime.now(UTC)

            # Token usage. ``record_token_usage`` (used elsewhere) does
            # SELECT-then-INSERT in two statements which races with
            # ``append_message``. Switch to a single INSERT ... ON
            # CONFLICT DO NOTHING keyed on the migration-142 partial
            # unique index so the loser silently drops its write at
            # the DB level — exactly one row per ``message_id``,
            # regardless of which session committed first.
            if accumulator is not None and accumulator.calls and user_id:
                try:
                    user_uuid = UUID(user_id)
                except (TypeError, ValueError):
                    # Malformed owner id — skip accounting rather than
                    # fail the finalize.
                    logger.warning(
                        "finalize_assistant_turn: invalid user_id=%r, "
                        "skipping token_usage row",
                        user_id,
                    )
                else:
                    insert_stmt = (
                        pg_insert(TokenUsage)
                        .values(
                            usage_type="chat",
                            prompt_tokens=accumulator.total_prompt_tokens,
                            completion_tokens=accumulator.total_completion_tokens,
                            total_tokens=accumulator.grand_total,
                            cost_micros=accumulator.total_cost_micros,
                            model_breakdown=accumulator.per_message_summary(),
                            call_details={"calls": accumulator.serialized_calls()},
                            thread_id=chat_id,
                            message_id=message_id,
                            search_space_id=search_space_id,
                            user_id=user_uuid,
                        )
                        .on_conflict_do_nothing(
                            index_elements=["message_id"],
                            index_where=sa_text("message_id IS NOT NULL"),
                        )
                    )
                    await ws.execute(insert_stmt)
            # Single commit covers both the content UPDATE and the
            # token_usage INSERT.
            await ws.commit()
            outcome = "ok"
    except Exception:
        # Best-effort contract: swallow after logging with traceback.
        logger.exception(
            "finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)",
            chat_id,
            message_id,
            turn_id,
        )
    finally:
        # One perf line per call so finalize latency, payload size and
        # failure modes are all greppable from [PERF] logs.
        _perf_log.info(
            "[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s "
            "turn_id=%s parts=%d status_marker=%s "
            "token_usage_attempted=%s in %.3fs",
            outcome,
            chat_id,
            message_id,
            turn_id,
            len(payload),
            is_status_marker,
            token_usage_attempted,
            time.perf_counter() - t0,
        )

View file

@ -25,7 +25,6 @@ from uuid import UUID
import anyio
from langchain_core.messages import HumanMessage
from sqlalchemy import func
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
@ -314,6 +313,19 @@ class StreamResult:
verification_succeeded: bool = False
commit_gate_passed: bool = True
commit_gate_reason: str = ""
# Pre-allocated assistant ``new_chat_messages.id`` for this turn,
# captured by ``persist_assistant_shell`` right after the user row is
# persisted. ``None`` for the legacy / anonymous code paths that don't
# opt in to server-side ``ContentPart[]`` projection.
assistant_message_id: int | None = None
# In-memory mirror of the FE's assistant-ui ``ContentPartsState``,
# populated by the lifecycle methods called from ``_stream_agent_events``
# at each ``streaming_service.format_*`` yield site. Snapshot in the
# streaming ``finally`` to produce the rich JSONB persisted by
# ``finalize_assistant_turn``. ``repr=False`` keeps the
# log-on-error path (``StreamResult`` is logged in some error
# branches) from dumping a potentially-large parts list.
content_builder: Any | None = field(default=None, repr=False)
def _safe_float(value: Any, default: float = 0.0) -> float:
@ -721,6 +733,7 @@ async def _stream_agent_events(
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None,
runtime_context: Any = None,
content_builder: Any | None = None,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.
@ -737,6 +750,15 @@ async def _stream_agent_events(
initial_step_id: If set, the helper inherits an already-active thinking step.
initial_step_title: Title of the inherited thinking step.
initial_step_items: Items of the inherited thinking step.
content_builder: Optional ``AssistantContentBuilder``. When set, every
``streaming_service.format_*`` yield site also drives the matching
builder lifecycle method (``on_text_*``, ``on_reasoning_*``,
``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the
in-memory ``ContentPart[]`` projection stays in lockstep with what
the FE renders live. Pure in-memory accumulation no DB I/O
consumed by the streaming ``finally`` to produce the rich JSONB
persisted via ``finalize_assistant_turn``. ``None`` (the default)
is used by the anonymous / legacy code paths and is a no-op.
Yields:
SSE-formatted strings for each event.
@ -801,12 +823,46 @@ async def _stream_agent_events(
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
def _emit_tool_output(call_id: str, output: Any) -> str:
# Drive the builder before formatting the SSE so the in-memory
# ContentPart[] mirror sees the result attached to the same
# card the FE will render. Builder method is a no-op when
# ``content_builder`` is None (anonymous / legacy paths).
if content_builder is not None:
content_builder.on_tool_output_available(
call_id, output, current_lc_tool_call_id["value"]
)
return streaming_service.format_tool_output_available(
call_id,
output,
langchain_tool_call_id=current_lc_tool_call_id["value"],
)
def _emit_thinking_step(
*,
step_id: str,
title: str,
status: str = "in_progress",
items: list[str] | None = None,
) -> str:
"""Format a thinking-step SSE event and notify the builder.
Single helper used at every ``format_thinking_step`` yield site
in this generator. Drives ``AssistantContentBuilder.on_thinking_step``
first so the FE-mirror state lands the update before the SSE
carrying the same data leaves the wire — order matches the FE
pipeline (``processSharedStreamEvent`` updates state, then
flushes). Builder call is a no-op when ``content_builder`` is
None (anonymous / legacy paths).
"""
if content_builder is not None:
content_builder.on_thinking_step(step_id, title, status, items)
return streaming_service.format_thinking_step(
step_id=step_id,
title=title,
status=status,
items=items,
)
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
thinking_step_counter += 1
@ -816,7 +872,7 @@ async def _stream_agent_events(
nonlocal last_active_step_id
if last_active_step_id and last_active_step_id not in completed_step_ids:
completed_step_ids.add(last_active_step_id)
event = streaming_service.format_thinking_step(
event = _emit_thinking_step(
step_id=last_active_step_id,
title=last_active_step_title,
status="completed",
@ -861,6 +917,8 @@ async def _stream_agent_events(
if parity_v2 and reasoning_delta:
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is None:
completion_event = complete_current_step()
@ -873,13 +931,21 @@ async def _stream_agent_events(
just_finished_tool = False
current_reasoning_id = streaming_service.generate_reasoning_id()
yield streaming_service.format_reasoning_start(current_reasoning_id)
if content_builder is not None:
content_builder.on_reasoning_start(current_reasoning_id)
yield streaming_service.format_reasoning_delta(
current_reasoning_id, reasoning_delta
)
if content_builder is not None:
content_builder.on_reasoning_delta(
current_reasoning_id, reasoning_delta
)
if text_delta:
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(current_reasoning_id)
if content_builder is not None:
content_builder.on_reasoning_end(current_reasoning_id)
current_reasoning_id = None
if current_text_id is None:
completion_event = complete_current_step()
@ -892,8 +958,12 @@ async def _stream_agent_events(
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
if content_builder is not None:
content_builder.on_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta
if content_builder is not None:
content_builder.on_text_delta(current_text_id, text_delta)
# Live tool-call argument streaming. Runs AFTER text/reasoning
# processing so chunks containing both stay in their natural
@ -925,11 +995,17 @@ async def _stream_agent_events(
# within the same stream window.
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
current_reasoning_id
)
if content_builder is not None:
content_builder.on_reasoning_end(
current_reasoning_id
)
current_reasoning_id = None
index_to_meta[idx] = {
@ -942,6 +1018,8 @@ async def _stream_agent_events(
name,
langchain_tool_call_id=lc_id,
)
if content_builder is not None:
content_builder.on_tool_input_start(ui_id, name, lc_id)
# Emit args delta for any chunk at a registered
# index (including idless continuations). Once an
@ -957,6 +1035,10 @@ async def _stream_agent_events(
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
if content_builder is not None:
content_builder.on_tool_input_delta(
meta["ui_id"], args_chunk
)
else:
pending_tool_call_chunks.append(tcc)
@ -974,6 +1056,8 @@ async def _stream_agent_events(
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
current_text_id = None
if last_active_step_title != "Synthesizing response":
@ -994,7 +1078,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Listing files"
last_active_step_items = [ls_path]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Listing files",
status="in_progress",
@ -1009,7 +1093,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Reading file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Reading file",
status="in_progress",
@ -1024,7 +1108,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Writing file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Writing file",
status="in_progress",
@ -1039,7 +1123,7 @@ async def _stream_agent_events(
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Editing file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Editing file",
status="in_progress",
@ -1056,7 +1140,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Searching files"
last_active_step_items = [f"{pat} in {base_path}"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Searching files",
status="in_progress",
@ -1076,7 +1160,7 @@ async def _stream_agent_events(
last_active_step_items = [
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Searching content",
status="in_progress",
@ -1091,7 +1175,7 @@ async def _stream_agent_events(
display_path = rm_path if len(rm_path) <= 80 else "" + rm_path[-77:]
last_active_step_title = "Deleting file"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Deleting file",
status="in_progress",
@ -1108,7 +1192,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Deleting folder"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Deleting folder",
status="in_progress",
@ -1125,7 +1209,7 @@ async def _stream_agent_events(
)
last_active_step_title = "Creating folder"
last_active_step_items = [display_path] if display_path else []
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Creating folder",
status="in_progress",
@ -1148,7 +1232,7 @@ async def _stream_agent_events(
last_active_step_items = (
[f"{display_src}{display_dst}"] if src or dst else []
)
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Moving file",
status="in_progress",
@ -1165,7 +1249,7 @@ async def _stream_agent_events(
if todo_count
else []
)
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Planning tasks",
status="in_progress",
@ -1180,7 +1264,7 @@ async def _stream_agent_events(
display_title = doc_title[:60] + ("" if len(doc_title) > 60 else "")
last_active_step_title = "Saving document"
last_active_step_items = [display_title]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Saving document",
status="in_progress",
@ -1196,7 +1280,7 @@ async def _stream_agent_events(
last_active_step_items = [
f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Generating image",
status="in_progress",
@ -1212,7 +1296,7 @@ async def _stream_agent_events(
last_active_step_items = [
f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Scraping webpage",
status="in_progress",
@ -1235,7 +1319,7 @@ async def _stream_agent_events(
f"Content: {content_len:,} characters",
"Preparing audio generation...",
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Generating podcast",
status="in_progress",
@ -1256,7 +1340,7 @@ async def _stream_agent_events(
f"Topic: {report_topic}",
"Analyzing source content...",
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title=step_title,
status="in_progress",
@ -1271,7 +1355,7 @@ async def _stream_agent_events(
display_cmd = cmd[:80] + ("" if len(cmd) > 80 else "")
last_active_step_title = "Running command"
last_active_step_items = [f"$ {display_cmd}"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title="Running command",
status="in_progress",
@ -1288,7 +1372,7 @@ async def _stream_agent_events(
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
last_active_step_items = []
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=tool_step_id,
title=last_active_step_title,
status="in_progress",
@ -1349,6 +1433,10 @@ async def _stream_agent_events(
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
if content_builder is not None:
content_builder.on_tool_input_start(
tool_call_id, tool_name, langchain_tool_call_id
)
if run_id:
ui_tool_call_id_by_run[run_id] = tool_call_id
@ -1371,6 +1459,13 @@ async def _stream_agent_events(
_safe_input,
langchain_tool_call_id=langchain_tool_call_id,
)
if content_builder is not None:
content_builder.on_tool_input_available(
tool_call_id,
tool_name,
_safe_input,
langchain_tool_call_id,
)
elif event_type == "on_tool_end":
active_tool_depth = max(0, active_tool_depth - 1)
@ -1443,70 +1538,70 @@ async def _stream_agent_events(
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
if tool_name == "read_file":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Reading file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_file":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Writing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "edit_file":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Editing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "glob":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Searching files",
status="completed",
items=last_active_step_items,
)
elif tool_name == "grep":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Searching content",
status="completed",
items=last_active_step_items,
)
elif tool_name == "rm":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Deleting file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "rmdir":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Deleting folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "mkdir":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Creating folder",
status="completed",
items=last_active_step_items,
)
elif tool_name == "move_file":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Moving file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_todos":
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Planning tasks",
status="completed",
@ -1523,7 +1618,7 @@ async def _stream_agent_events(
*last_active_step_items,
result_str[:80] if is_error else "Saved to knowledge base",
]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Saving document",
status="completed",
@ -1542,7 +1637,7 @@ async def _stream_agent_events(
else "Generation failed"
)
completed_items = [*last_active_step_items, f"Error: {error_msg}"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Generating image",
status="completed",
@ -1566,7 +1661,7 @@ async def _stream_agent_events(
]
else:
completed_items = [*last_active_step_items, "Content extracted"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Scraping webpage",
status="completed",
@ -1612,7 +1707,7 @@ async def _stream_agent_events(
]
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Generating podcast",
status="completed",
@ -1647,7 +1742,7 @@ async def _stream_agent_events(
]
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Generating video presentation",
status="completed",
@ -1695,7 +1790,7 @@ async def _stream_agent_events(
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title=step_title,
status="completed",
@ -1721,7 +1816,7 @@ async def _stream_agent_events(
]
else:
completed_items = [*last_active_step_items, "Finished"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Running command",
status="completed",
@ -1761,7 +1856,7 @@ async def _stream_agent_events(
completed_items.append(f"(+{len(file_names) - 4} more)")
else:
completed_items = ["No files found"]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title="Listing files",
status="completed",
@ -1773,7 +1868,7 @@ async def _stream_agent_events(
fallback_title = (
tool_name.replace("_", " ").strip().capitalize() or tool_name
)
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=original_step_id,
title=fallback_title,
status="completed",
@ -2113,7 +2208,7 @@ async def _stream_agent_events(
# Phase transitions: replace everything after topic
last_active_step_items = [*topic_items, message]
yield streaming_service.format_thinking_step(
yield _emit_thinking_step(
step_id=last_active_step_id,
title=last_active_step_title,
status="in_progress",
@ -2155,10 +2250,14 @@ async def _stream_agent_events(
elif event_type in ("on_chain_end", "on_agent_end"):
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
current_text_id = None
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
if content_builder is not None:
content_builder.on_text_end(current_text_id)
completion_event = complete_current_step()
if completion_event:
@ -2243,8 +2342,14 @@ async def _stream_agent_events(
)
gate_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(gate_text_id)
if content_builder is not None:
content_builder.on_text_start(gate_text_id)
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
if content_builder is not None:
content_builder.on_text_delta(gate_text_id, gate_notice)
yield streaming_service.format_text_end(gate_text_id)
if content_builder is not None:
content_builder.on_text_end(gate_text_id)
yield streaming_service.format_terminal_info(gate_notice, "error")
accumulated_text = gate_notice
else:
@ -2270,6 +2375,7 @@ async def stream_new_chat(
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
mentioned_surfsense_doc_ids: list[int] | None = None,
mentioned_documents: list[dict[str, Any]] | None = None,
checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False,
thread_visibility: ChatVisibility | None = None,
@ -2949,6 +3055,96 @@ async def stream_new_chat(
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
# Persist the user-side row for this turn before any expensive
# work runs. Closes the "ghost-thread" abuse vector
# (authenticated client hits POST /new_chat then never calls
# /messages — empty new_chat_messages, free LLM completion).
# Idempotent against the unique index in migration 141 so the
# legacy frontend appendMessage call is a no-op on the second
# writer. Hard failure aborts the turn so we never produce a
# title or assistant row that isn't anchored to a persisted
# user message.
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.persistence import (
persist_assistant_shell,
persist_user_turn,
)
user_message_id = await persist_user_turn(
chat_id=chat_id,
user_id=user_id,
turn_id=stream_result.turn_id,
user_query=user_query,
user_image_data_urls=user_image_data_urls,
mentioned_documents=mentioned_documents,
)
if user_message_id is None:
yield _emit_stream_error(
message=(
"We couldn't save your message. Please try again in a moment."
),
error_kind="server_error",
error_code="MESSAGE_PERSIST_FAILED",
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
# Emit canonical user message id BEFORE any LLM streaming so the
# FE can rename its optimistic ``msg-user-XXX`` placeholder to
# ``msg-{user_message_id}`` and unlock features gated on a real
# DB id (comments, edit-from-this-message). See B4 in
# ``sse-based_message_id_handshake`` plan.
yield streaming_service.format_data(
"user-message-id",
{"message_id": user_message_id, "turn_id": stream_result.turn_id},
)
# Pre-write the assistant row for this turn so we have a stable
# ``message_id`` to anchor mid-stream metadata (token_usage,
# future agent_action_log.message_id correlation) and a
# write-once UPDATE target at finalize time. Idempotent against
# the (thread_id, turn_id, ASSISTANT) partial unique index from
# migration 141 — if the legacy frontend appendMessage races
# this, we recover the existing row's id.
assistant_message_id = await persist_assistant_shell(
chat_id=chat_id,
user_id=user_id,
turn_id=stream_result.turn_id,
)
if assistant_message_id is None:
# Genuine DB failure — abort the turn rather than stream
# into a void. The user row is already persisted so the
# legacy "ghost-thread" gate isn't reopened.
yield _emit_stream_error(
message=(
"We couldn't initialize the assistant message. Please try again."
),
error_kind="server_error",
error_code="MESSAGE_PERSIST_FAILED",
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
# Emit canonical assistant message id BEFORE any LLM streaming
# so the FE can rename its optimistic ``msg-assistant-XXX``
# placeholder to ``msg-{assistant_message_id}`` and bind
# ``tokenUsageStore`` / ``pendingInterrupt`` to the real id
# immediately. See B4 in ``sse-based_message_id_handshake``
# plan.
yield streaming_service.format_data(
"assistant-message-id",
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
)
stream_result.assistant_message_id = assistant_message_id
stream_result.content_builder = AssistantContentBuilder()
# Initial thinking step - analyzing the request
if mentioned_surfsense_docs:
initial_title = "Analyzing referenced content"
@ -2981,6 +3177,15 @@ async def stream_new_chat(
initial_items = [f"{action_verb}: {' '.join(processing_parts)}"]
initial_step_id = "thinking-1"
# Drive the builder for this initial thinking step too — the
# ``_emit_thinking_step`` helper lives inside ``_stream_agent_events``
# so it isn't in scope here, but the FE folds this step into
# the same singleton ``data-thinking-steps`` part as everything
# the agent stream emits later. Mirror that fold server-side.
if stream_result.content_builder is not None:
stream_result.content_builder.on_thinking_step(
initial_step_id, initial_title, "in_progress", initial_items
)
yield streaming_service.format_thinking_step(
step_id=initial_step_id,
title=initial_title,
@ -2997,16 +3202,34 @@ async def stream_new_chat(
# Check if this is the first assistant response so we can generate
# a title in parallel with the agent stream (better UX than waiting
# until after the full response).
assistant_count_result = await session.execute(
select(func.count(NewChatMessage.id)).filter(
# Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because
# this is now a hot path executed on every turn, and COUNT scales
# with thread length (server-side persistence can grow rows
# quickly under power users).
#
# IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already
# inserted THIS turn's assistant row. We must therefore exclude
# it from the probe — otherwise the gate fires on every turn
# except the very first, and title generation never runs for new
# threads. Excluding by primary key (``id != assistant_message_id``)
# is bulletproof regardless of ``turn_id`` shape (legacy NULLs,
# resume turns, etc.).
first_assistant_probe = await session.execute(
select(NewChatMessage.id)
.filter(
NewChatMessage.thread_id == chat_id,
NewChatMessage.role == "assistant",
NewChatMessage.id != assistant_message_id,
)
.limit(1)
)
is_first_response = (assistant_count_result.scalar() or 0) == 0
is_first_response = first_assistant_probe.scalars().first() is None
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
if is_first_response:
# Gate title generation on a persisted user message so a stream
# that fails before persistence (we abort above) can never leave
# behind a thread with a generated title and no anchoring rows.
if is_first_response and user_message_id is not None:
async def _generate_title() -> tuple[str | None, dict | None]:
"""Generate a short title via litellm.acompletion.
@ -3138,6 +3361,7 @@ async def stream_new_chat(
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
content_builder=stream_result.content_builder,
):
if not _first_event_logged:
_perf_log.info(
@ -3493,6 +3717,81 @@ async def stream_new_chat(
with contextlib.suppress(Exception):
await session.close()
# Server-side assistant-message + token_usage finalization.
# Runs after the main session has been closed (uses its own
# shielded session) so we don't fight the same DB connection.
# Idempotent against the legacy frontend appendMessage:
# * the assistant row was already INSERTed by
# ``persist_assistant_shell`` above, so this just UPDATEs
# it with the rich ContentPart[] from the builder.
# * token_usage uses INSERT ... ON CONFLICT DO NOTHING
# against migration 142's partial unique index, so a
# racing append_message recovery branch can never
# double-write.
# ``mark_interrupted`` closes any open text/reasoning blocks
# and flips running tool-calls (no result) to state=aborted
# so the persisted JSONB reflects a coherent end-state even
# on client disconnect.
# Never raises (best-effort, logs only).
if (
stream_result
and stream_result.turn_id
and stream_result.assistant_message_id
):
from app.tasks.chat.persistence import finalize_assistant_turn
builder_stats: dict[str, int] | None = None
if stream_result.content_builder is not None:
stream_result.content_builder.mark_interrupted()
# Snapshot stats BEFORE deepcopy in ``snapshot()`` so
# the perf log records the actual finalised payload
# (post-mark_interrupted), not the live-mutating
# builder state.
builder_stats = stream_result.content_builder.stats()
content_payload = stream_result.content_builder.snapshot()
else:
# Defensive fallback — we always set the builder
# alongside ``assistant_message_id`` above, so this
# branch only fires if a future refactor ever
# decouples them. Persist whatever accumulated
# text we captured so the row at least renders.
content_payload = [
{
"type": "text",
"text": stream_result.accumulated_text or "",
}
]
if builder_stats is not None:
_perf_log.info(
"[stream_new_chat] finalize_payload chat_id=%s "
"message_id=%s parts=%d bytes=%d text=%d "
"reasoning=%d tool_calls=%d "
"tool_calls_completed=%d tool_calls_aborted=%d "
"thinking_step_parts=%d step_separators=%d",
chat_id,
stream_result.assistant_message_id,
builder_stats["parts"],
builder_stats["bytes"],
builder_stats["text"],
builder_stats["reasoning"],
builder_stats["tool_calls"],
builder_stats["tool_calls_completed"],
builder_stats["tool_calls_aborted"],
builder_stats["thinking_step_parts"],
builder_stats["step_separators"],
)
await finalize_assistant_turn(
message_id=stream_result.assistant_message_id,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
turn_id=stream_result.turn_id,
content=content_payload,
accumulator=accumulator,
)
# Persist any sandbox-produced files to local storage so they
# remain downloadable after the Daytona sandbox auto-deletes.
if stream_result and stream_result.sandbox_files:
@ -3937,6 +4236,50 @@ async def stream_resume_chat(
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
# Pre-write a fresh assistant row for this resume turn. The
# original (interrupted) ``stream_new_chat`` invocation already
# persisted its own assistant row anchored to a different
# ``turn_id``; resume allocates a new ``turn_id`` (above) so we
# need a separate row keyed on the same ``(thread_id, turn_id,
# ASSISTANT)`` invariant. Idempotent against migration 141's
# partial unique index — recovers existing id on retry.
from app.tasks.chat.content_builder import AssistantContentBuilder
from app.tasks.chat.persistence import persist_assistant_shell
assistant_message_id = await persist_assistant_shell(
chat_id=chat_id,
user_id=user_id,
turn_id=stream_result.turn_id,
)
if assistant_message_id is None:
yield _emit_stream_error(
message=(
"We couldn't initialize the assistant message. Please try again."
),
error_kind="server_error",
error_code="MESSAGE_PERSIST_FAILED",
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
return
# Emit canonical assistant message id BEFORE any LLM streaming
# so the FE can rename ``pendingInterrupt.assistantMsgId`` to
# ``msg-{assistant_message_id}`` immediately. Resume does NOT
# emit ``data-user-message-id`` because the user row is from
# the original interrupted turn (different ``turn_id``) and is
# never re-persisted here. See B5 in the
# ``sse-based_message_id_handshake`` plan.
yield streaming_service.format_data(
"assistant-message-id",
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
)
stream_result.assistant_message_id = assistant_message_id
stream_result.content_builder = AssistantContentBuilder()
# Resume path doesn't carry new ``mentioned_document_ids`` —
# those are seeded in the original turn. We still pass a
# context so future middleware extensions (Phase 2) can rely on
@ -3968,6 +4311,7 @@ async def stream_resume_chat(
),
fallback_commit_thread_id=chat_id,
runtime_context=runtime_context,
content_builder=stream_result.content_builder,
):
if not _first_event_logged:
_perf_log.info(
@ -4219,6 +4563,64 @@ async def stream_resume_chat(
with contextlib.suppress(Exception):
await session.close()
# Server-side assistant-message + token_usage finalization for
# the resume flow. The original user message was persisted by
# the original (interrupted) ``stream_new_chat`` invocation;
# the resume's own ``persist_assistant_shell`` write lives at
# the new ``turn_id`` above. This finalize updates that row
# with the rich ContentPart[] from the builder and writes
# token_usage idempotently via migration 142's partial
# unique index. Best-effort, never raises.
if (
stream_result
and stream_result.turn_id
and stream_result.assistant_message_id
):
from app.tasks.chat.persistence import finalize_assistant_turn
builder_stats: dict[str, int] | None = None
if stream_result.content_builder is not None:
stream_result.content_builder.mark_interrupted()
builder_stats = stream_result.content_builder.stats()
content_payload = stream_result.content_builder.snapshot()
else:
content_payload = [
{
"type": "text",
"text": stream_result.accumulated_text or "",
}
]
if builder_stats is not None:
_perf_log.info(
"[stream_resume] finalize_payload chat_id=%s "
"message_id=%s parts=%d bytes=%d text=%d "
"reasoning=%d tool_calls=%d "
"tool_calls_completed=%d tool_calls_aborted=%d "
"thinking_step_parts=%d step_separators=%d",
chat_id,
stream_result.assistant_message_id,
builder_stats["parts"],
builder_stats["bytes"],
builder_stats["text"],
builder_stats["reasoning"],
builder_stats["tool_calls"],
builder_stats["tool_calls_completed"],
builder_stats["tool_calls_aborted"],
builder_stats["thinking_step_parts"],
builder_stats["step_separators"],
)
await finalize_assistant_turn(
message_id=stream_result.assistant_message_id,
chat_id=chat_id,
search_space_id=search_space_id,
user_id=user_id,
turn_id=stream_result.turn_id,
content=content_payload,
accumulator=accumulator,
)
agent = llm = connector_service = None
stream_result = None
session = None