mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 14:52:39 +02:00
feat: moved chat persistance to Server Side
This commit is contained in:
parent
2e1b9b5582
commit
19b6e0a025
19 changed files with 4515 additions and 390 deletions
|
|
@ -675,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
|||
|
||||
__tablename__ = "new_chat_messages"
|
||||
|
||||
# Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL.
|
||||
# Mirrors alembic migration 141. Lets the streaming agent and the
|
||||
# legacy frontend appendMessage call coexist idempotently — the second
|
||||
# writer trips the unique and recovers without creating a duplicate row.
|
||||
# Partial so legacy NULL turn_id rows and clone/snapshot inserts in
|
||||
# app/services/public_chat_service.py (which omit turn_id) are unaffected.
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"uq_new_chat_messages_thread_turn_role",
|
||||
"thread_id",
|
||||
"turn_id",
|
||||
"role",
|
||||
unique=True,
|
||||
postgresql_where=text("turn_id IS NOT NULL"),
|
||||
),
|
||||
)
|
||||
|
||||
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
|
||||
# Content stored as JSONB to support rich content (text, tool calls, etc.)
|
||||
content = Column(JSONB, nullable=False)
|
||||
|
|
@ -728,6 +745,22 @@ class TokenUsage(BaseModel, TimestampMixin):
|
|||
|
||||
__tablename__ = "token_usage"
|
||||
|
||||
# Partial unique index on (message_id) where message_id IS NOT NULL.
|
||||
# Mirrors alembic migration 142. Lets the streaming agent's
|
||||
# ``finalize_assistant_turn`` and the legacy frontend ``append_message``
|
||||
# recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without
|
||||
# racing on a SELECT-then-INSERT window. Partial so non-chat usage rows
|
||||
# (indexing, image generation, podcasts) — which keep ``message_id`` NULL
|
||||
# because there is no per-message anchor — are unaffected.
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"uq_token_usage_message_id",
|
||||
"message_id",
|
||||
unique=True,
|
||||
postgresql_where=text("message_id IS NOT NULL"),
|
||||
),
|
||||
)
|
||||
|
||||
prompt_tokens = Column(Integer, nullable=False, default=0)
|
||||
completion_tokens = Column(Integer, nullable=False, default=0)
|
||||
total_tokens = Column(Integer, nullable=False, default=0)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ from datetime import UTC, datetime
|
|||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy import func, or_, text as sa_text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
|
@ -44,6 +45,7 @@ from app.db import (
|
|||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
TokenUsage,
|
||||
User,
|
||||
get_async_session,
|
||||
shielded_async_session,
|
||||
|
|
@ -69,9 +71,9 @@ from app.schemas.new_chat import (
|
|||
TokenUsageSummary,
|
||||
TurnStatusResponse,
|
||||
)
|
||||
from app.services.token_tracking_service import record_token_usage
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.perf import get_perf_logger
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.user_message_multimodal import (
|
||||
split_langchain_human_content,
|
||||
|
|
@ -79,6 +81,7 @@ from app.utils.user_message_multimodal import (
|
|||
)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||
|
|
@ -1287,6 +1290,24 @@ async def append_message(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
.. deprecated:: 2026-05
|
||||
Replaced by the **SSE-based message ID handshake**. The streaming
|
||||
generator (`stream_new_chat` / `stream_resume_chat`) now persists
|
||||
both the user and assistant rows server-side via
|
||||
``persist_user_turn`` / ``persist_assistant_shell`` and emits
|
||||
``data-user-message-id`` / ``data-assistant-message-id`` SSE events
|
||||
so the frontend can rename its optimistic IDs in real time. The
|
||||
new FE bundle no longer calls this route.
|
||||
|
||||
This handler is retained as a **silent no-op for legacy / cached
|
||||
FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING``
|
||||
pattern means a stale bundle hitting this route after the SSE
|
||||
handshake already wrote the row simply returns the existing row
|
||||
(200 OK) without raising or duplicating data. After a 2-week soak
|
||||
(target: ``[persist_user_turn] outcome=race_recovered`` rate ~0)
|
||||
this entire route — and the FE ``appendMessage`` function — is
|
||||
earmarked for removal.
|
||||
|
||||
Append a message to a thread.
|
||||
This is used by ThreadHistoryAdapter.append() to persist messages.
|
||||
|
||||
|
|
@ -1297,6 +1318,22 @@ async def append_message(
|
|||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
# Capture ``user.id`` as a primitive UUID up front. The
|
||||
# ``current_active_user`` dependency hands us a ``User`` ORM
|
||||
# row bound to ``session``; if the outer ``except
|
||||
# IntegrityError`` block below ever fires (an unexpected
|
||||
# constraint like a foreign key violation — the common
|
||||
# ``(thread_id, turn_id, role)`` race is now handled silently
|
||||
# by ``ON CONFLICT DO NOTHING`` so it never raises) it calls
|
||||
# ``session.rollback()``, which expires every attached ORM
|
||||
# row including this user. Any later ``user.id`` access would
|
||||
# then trigger a lazy PK reload — which on async SQLAlchemy
|
||||
# fails with ``MissingGreenlet`` because the reload happens
|
||||
# outside the awaitable greenlet boundary. Reading ``id``
|
||||
# once here pins the value as a plain UUID so all downstream
|
||||
# uses (TokenUsage insert, response build) are immune.
|
||||
user_uuid = user.id
|
||||
|
||||
# Parse raw body - extract only role and content, ignoring extra fields
|
||||
raw_body = await request.json()
|
||||
role = raw_body.get("role")
|
||||
|
|
@ -1351,42 +1388,166 @@ async def append_message(
|
|||
else None
|
||||
)
|
||||
|
||||
db_message = NewChatMessage(
|
||||
thread_id=thread_id,
|
||||
role=message_role,
|
||||
content=content,
|
||||
author_id=user.id,
|
||||
turn_id=turn_id_value,
|
||||
)
|
||||
session.add(db_message)
|
||||
|
||||
# Update thread's updated_at timestamp
|
||||
# Update thread's updated_at timestamp (always — both insert
|
||||
# and recovery paths represent thread activity).
|
||||
thread.updated_at = datetime.now(UTC)
|
||||
|
||||
# flush assigns the PK/defaults without a round-trip SELECT
|
||||
await session.flush()
|
||||
# Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING``
|
||||
# keyed on the ``(thread_id, turn_id, role)`` partial unique
|
||||
# index from migration 141 (``WHERE turn_id IS NOT NULL``).
|
||||
#
|
||||
# Why ON CONFLICT instead of ``session.add() + flush() + except
|
||||
# IntegrityError``:
|
||||
# 1. The conflict between this legacy FE ``appendMessage``
|
||||
# round-trip and the server-side
|
||||
# ``finalize_assistant_turn`` writer is a NORMAL,
|
||||
# *expected* race — every assistant turn fires it. Using
|
||||
# catch-and-recover means asyncpg raises
|
||||
# ``UniqueViolationError`` -> SQLAlchemy wraps it as
|
||||
# ``IntegrityError`` -> our handler catches and recovers.
|
||||
# Functionally fine, but every ``raise`` event lights up
|
||||
# VS Code's debugger (debugpy's ``justMyCode=false`` mode
|
||||
# loses track of the catch frame across SQLAlchemy's
|
||||
# async greenlet boundary, so even ``Raised Exceptions``
|
||||
# being unchecked doesn't reliably suppress the pause).
|
||||
# ON CONFLICT pushes the conflict resolution into Postgres
|
||||
# where no Python exception is constructed at all.
|
||||
# 2. No ``session.rollback()`` -> no expiring of attached
|
||||
# ORM rows -> no risk of ``MissingGreenlet`` from
|
||||
# lazy-loading expired user/thread state later in the
|
||||
# handler.
|
||||
# 3. Cleaner production logs (no SQLAlchemy ``IntegrityError``
|
||||
# tracebacks emitted by uvicorn's logger between the
|
||||
# ``raise`` and our ``except``).
|
||||
#
|
||||
# When ``turn_id_value`` is ``None`` the partial index doesn't
|
||||
# apply and the INSERT proceeds normally. Other constraint
|
||||
# violations (FK, NOT NULL, etc.) still raise ``IntegrityError``
|
||||
# and are caught by the outer ``except IntegrityError`` block
|
||||
# to preserve the legacy 400 behavior.
|
||||
#
|
||||
# Note on ``content``: when we recover the existing row, we
|
||||
# intentionally discard the FE's ``content`` payload from
|
||||
# ``raw_body`` and return the row's existing ``content``. The
|
||||
# streaming task is now the *authoritative writer* for
|
||||
# assistant ``ContentPart[]`` shape (mid-stream
|
||||
# ``AssistantContentBuilder`` -> ``finalize_assistant_turn``)
|
||||
# so the FE's later ``appendMessage`` is just a stale snapshot
|
||||
# of the same data — keeping the server-built rich content
|
||||
# (with full tool-call args / argsText / langchainToolCallId)
|
||||
# is correct, not lossy.
|
||||
insert_stmt = (
|
||||
pg_insert(NewChatMessage)
|
||||
.values(
|
||||
thread_id=thread_id,
|
||||
role=message_role,
|
||||
content=content,
|
||||
author_id=user_uuid,
|
||||
turn_id=turn_id_value,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["thread_id", "turn_id", "role"],
|
||||
index_where=sa_text("turn_id IS NOT NULL"),
|
||||
)
|
||||
.returning(NewChatMessage.id)
|
||||
)
|
||||
inserted_id = (await session.execute(insert_stmt)).scalar()
|
||||
|
||||
if inserted_id is None:
|
||||
# Conflict on partial unique index — server-side stream
|
||||
# already wrote this row. Look it up and reuse it.
|
||||
if turn_id_value is None:
|
||||
# Defensive: ON CONFLICT only fires for ``turn_id IS
|
||||
# NOT NULL`` rows, so this branch should be
|
||||
# unreachable. Preserve the legacy 400 just in case
|
||||
# Postgres ever surprises us.
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
lookup = await session.execute(
|
||||
select(NewChatMessage).filter(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.turn_id == turn_id_value,
|
||||
NewChatMessage.role == message_role,
|
||||
)
|
||||
)
|
||||
existing_message = lookup.scalars().first()
|
||||
if existing_message is None:
|
||||
# Conflict reported but the row vanished between
|
||||
# INSERT and SELECT — extremely unlikely (would
|
||||
# require a concurrent DELETE within the same
|
||||
# transaction visibility), but preserve safe
|
||||
# behavior.
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
db_message = existing_message
|
||||
# Perf signal: counts how often the legacy FE round-trip
|
||||
# races the server-side ``finalize_assistant_turn``. A
|
||||
# rising rate after the rework is OK (it's exactly the
|
||||
# ghost-thread fix's recovery path firing); a sudden drop
|
||||
# to zero would mean the FE isn't posting appendMessage
|
||||
# at all (different bug).
|
||||
_perf_log.info(
|
||||
"[append_message] outcome=recovered_via_unique_index "
|
||||
"thread_id=%s turn_id=%s role=%s message_id=%s",
|
||||
thread_id,
|
||||
turn_id_value,
|
||||
message_role.value,
|
||||
db_message.id,
|
||||
)
|
||||
else:
|
||||
# INSERT succeeded — load the full ORM row so the
|
||||
# response can include server-side-defaulted columns
|
||||
# (``created_at``, etc.) and the relationship surface
|
||||
# stays consistent with the recovery path.
|
||||
inserted_row = await session.get(NewChatMessage, inserted_id)
|
||||
if inserted_row is None:
|
||||
# Should be impossible: we just inserted it in this
|
||||
# same transaction. Fail loud if it happens.
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Inserted message could not be loaded.",
|
||||
) from None
|
||||
db_message = inserted_row
|
||||
|
||||
# Persist token usage if provided (for assistant messages).
|
||||
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
||||
# forwarded by the FE through the appendMessage round-trip so
|
||||
# the historical TokenUsage row matches the credit debit applied
|
||||
# at finalize time.
|
||||
#
|
||||
# De-dup: ``finalize_assistant_turn`` may also race to write a
|
||||
# token_usage row for this same ``message_id`` (cross-session,
|
||||
# cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed
|
||||
# on the ``uq_token_usage_message_id`` partial unique index
|
||||
# (migration 142). The loser silently drops its insert; exactly
|
||||
# one row results regardless of which writer commits first.
|
||||
token_usage_data = raw_body.get("token_usage")
|
||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||
await record_token_usage(
|
||||
session,
|
||||
usage_type="chat",
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user.id,
|
||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
thread_id=thread_id,
|
||||
message_id=db_message.id,
|
||||
insert_stmt = (
|
||||
pg_insert(TokenUsage)
|
||||
.values(
|
||||
usage_type="chat",
|
||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||
model_breakdown=token_usage_data.get("usage"),
|
||||
call_details=token_usage_data.get("call_details"),
|
||||
thread_id=thread_id,
|
||||
message_id=db_message.id,
|
||||
search_space_id=thread.search_space_id,
|
||||
user_id=user_uuid,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=["message_id"],
|
||||
index_where=sa_text("message_id IS NOT NULL"),
|
||||
)
|
||||
)
|
||||
await session.execute(insert_stmt)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
|
@ -1406,6 +1567,9 @@ async def append_message(
|
|||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
# Any IntegrityError that escaped the inline handler above
|
||||
# comes from a *different* constraint (foreign key, etc.) —
|
||||
# preserve the legacy 400 path.
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
@ -1599,6 +1763,12 @@ async def handle_new_chat(
|
|||
else None
|
||||
)
|
||||
|
||||
mentioned_documents_payload = (
|
||||
[doc.model_dump() for doc in request.mentioned_documents]
|
||||
if request.mentioned_documents
|
||||
else None
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=request.user_query,
|
||||
|
|
@ -1608,6 +1778,7 @@ async def handle_new_chat(
|
|||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
mentioned_documents=mentioned_documents_payload,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
|
|
@ -2078,6 +2249,11 @@ async def regenerate_response(
|
|||
"data": revert_results,
|
||||
}
|
||||
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
|
||||
mentioned_documents_payload = (
|
||||
[doc.model_dump() for doc in request.mentioned_documents]
|
||||
if request.mentioned_documents
|
||||
else None
|
||||
)
|
||||
try:
|
||||
async for chunk in stream_new_chat(
|
||||
user_query=str(user_query_to_use),
|
||||
|
|
@ -2087,6 +2263,7 @@ async def regenerate_response(
|
|||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
mentioned_documents=mentioned_documents_payload,
|
||||
checkpoint_id=target_checkpoint_id,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
|
|
|
|||
|
|
@ -200,6 +200,21 @@ class NewChatUserImagePart(BaseModel):
|
|||
return to_data_url(self.media_type, self.data)
|
||||
|
||||
|
||||
class MentionedDocumentInfo(BaseModel):
    """Display metadata for a single ``@``-mentioned document.

    The full triple ``{id, title, document_type}`` is forwarded by the
    frontend mention chip so the server can embed it in the persisted
    user message ``ContentPart[]`` (single ``mentioned-documents`` part).
    The history loader then renders the chips on reload without an extra
    fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape.
    """

    # Primary key of the mentioned document row.
    id: int
    # Chip label; non-empty, capped to keep the persisted JSONB bounded.
    title: str = Field(..., min_length=1, max_length=500)
    # Type discriminator used by the FE to pick the chip icon.
    document_type: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
|
||||
"""Request schema for the deep agent chat endpoint."""
|
||||
|
||||
|
|
@ -213,6 +228,17 @@ class NewChatRequest(BaseModel):
|
|||
mentioned_surfsense_doc_ids: list[int] | None = (
|
||||
None # Optional SurfSense documentation IDs mentioned with @ in the chat
|
||||
)
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Display metadata (id, title, document_type) for every "
|
||||
"@-mentioned document. Persisted as a ``mentioned-documents`` "
|
||||
"ContentPart on the user message so reload renders chips "
|
||||
"without an extra fetch. Optional and additive — when None "
|
||||
"the user message is persisted without a mentioned-documents "
|
||||
"part."
|
||||
),
|
||||
)
|
||||
disabled_tools: list[str] | None = (
|
||||
None # Optional list of tool names the user has disabled from the UI
|
||||
)
|
||||
|
|
@ -264,6 +290,16 @@ class RegenerateRequest(BaseModel):
|
|||
)
|
||||
mentioned_document_ids: list[int] | None = None
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Display metadata (id, title, document_type) for every "
|
||||
"@-mentioned document on the edited user turn. Only used "
|
||||
"when ``user_query`` is non-None (edit). Persisted as a "
|
||||
"``mentioned-documents`` ContentPart on the new user "
|
||||
"message. None means no chip metadata."
|
||||
),
|
||||
)
|
||||
disabled_tools: list[str] | None = None
|
||||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||
client_platform: Literal["web", "desktop"] = "web"
|
||||
|
|
@ -334,6 +370,16 @@ class ResumeRequest(BaseModel):
|
|||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||
client_platform: Literal["web", "desktop"] = "web"
|
||||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Display metadata forwarded for symmetry with /new_chat and "
|
||||
"/regenerate. Resume reuses the original interrupted user "
|
||||
"turn so the server does not write a new user message. "
|
||||
"Currently unused but accepted to keep request bodies "
|
||||
"uniform across the three streaming entrypoints."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CancelActiveTurnResponse(BaseModel):
|
||||
|
|
|
|||
515
surfsense_backend/app/tasks/chat/content_builder.py
Normal file
515
surfsense_backend/app/tasks/chat/content_builder.py
Normal file
|
|
@ -0,0 +1,515 @@
|
|||
"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection.
|
||||
|
||||
Background
|
||||
----------
|
||||
The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields
|
||||
SSE events that the frontend folds into a ``ContentPartsState`` (see
|
||||
``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in
|
||||
``stream-pipeline.ts``). When a turn ends, the frontend calls
|
||||
``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]``
|
||||
JSONB to ``POST /threads/{id}/messages``, which is what was historically
|
||||
written to ``new_chat_messages.content``.
|
||||
|
||||
After the ghost-thread fix moved persistence server-side, the assistant
|
||||
row is written by ``finalize_assistant_turn`` in the streaming finally
|
||||
block. The frontend's later ``appendMessage`` is now a no-op (recovers
|
||||
via the ``(thread_id, turn_id, role)`` partial unique index added in
|
||||
migration 141), which means the *server* is now responsible for
|
||||
producing the rich ``ContentPart[]`` shape the FE expects on history
|
||||
reload — text + reasoning + tool-call cards (with ``args``, ``argsText``,
|
||||
``result``, ``langchainToolCallId``) + thinking-step buckets +
|
||||
step-separators.
|
||||
|
||||
This module is the in-memory accumulator that mirrors the FE state for
|
||||
exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*``
|
||||
/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` /
|
||||
``mark_interrupted`` at the same call sites it yields the matching
|
||||
``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list
|
||||
stays in lockstep with what the FE's pipeline would have produced live.
|
||||
``snapshot()`` is then taken once in the ``finally`` block and persisted
|
||||
in a single UPDATE.
|
||||
|
||||
Pure synchronous state — no DB I/O, no async, no flush callbacks. The
|
||||
streaming code is responsible for driving lifecycle methods; this class
|
||||
is a thin projection helper.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``:
|
||||
# only text/reasoning/tool-call parts count as "meaningful". data-thinking-steps
|
||||
# and data-step-separator decorate the meaningful parts but never stand alone
|
||||
# in a successful turn.
|
||||
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
|
||||
|
||||
|
||||
class AssistantContentBuilder:
|
||||
"""Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
|
||||
|
||||
Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly
|
||||
matches the FE ``ContentPart`` union::
|
||||
|
||||
| { type: "text"; text: string }
|
||||
| { type: "reasoning"; text: string }
|
||||
| { type: "tool-call"; toolCallId: str; toolName: str;
|
||||
args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
|
||||
state?: "aborted" }
|
||||
| { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
|
||||
| { type: "data-step-separator"; data: { stepIndex: int } }
|
||||
|
||||
Order matches the wire order of the SSE events that drive the lifecycle
|
||||
methods, with two FE-mirrored exceptions:
|
||||
|
||||
1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the
|
||||
first time we see a ``data-thinking-step`` SSE event (the FE's
|
||||
``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent
|
||||
thinking-step updates mutate that singleton in place.
|
||||
2. ``data-step-separator`` is appended only when the message already has
|
||||
meaningful content and the previous part isn't itself a separator
|
||||
(so the FIRST step of a turn doesn't generate a leading divider).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """Start with an empty part list and no open text/reasoning block."""
    # Ordered ContentPart dicts, appended in SSE wire order.
    self.parts: list[dict[str, Any]] = []
    # -1 sentinels mean "no open block": the next delta opens a fresh
    # part. Mirrors ``ContentPartsState.currentTextPartIndex`` on the FE.
    self._current_text_idx: int = -1
    self._current_reasoning_idx: int = -1
    # Maps a tool call's ``ui_id`` — the synthetic ``call_<run_id>``
    # (legacy) or the LangChain ``tool_call.id`` (parity_v2), the same
    # key every ``tool-input-*`` / ``tool-output-*`` event carries —
    # to that card's index within ``parts``.
    self._tool_call_idx_by_ui_id: dict[str, int] = {}
    # Running concatenation of ``tool-input-delta`` chunks per ui_id,
    # reproducing the FE's ``appendToolInputDelta`` accumulator until
    # ``tool-input-available`` replaces it with pretty-printed JSON.
    self._args_text_by_ui_id: dict[str, str] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_text_start(self, text_id: str) -> None:
    """Begin a fresh text block.

    Symmetric to FE ``appendText``: opening text closes any active
    reasoning block so the renderer treats them as separate parts.
    No text part is materialised here — the first ``on_text_delta``
    lazily creates it, so an empty start/end pair leaves no trace
    (the FE pipeline has no explicit ``text-start`` handler at all).
    """
    # Unconditional reset: writing -1 over -1 is a no-op.
    self._current_reasoning_idx = -1
|
||||
|
||||
def on_text_delta(self, text_id: str, delta: str) -> None:
    """Fold a streamed text chunk into the open text part, or open one.

    Empty chunks are ignored. A text delta arriving while a reasoning
    block is open implicitly closes that block, mirroring the FE's
    ``appendText`` behaviour.
    """
    if not delta:
        return
    # Implicitly close any open reasoning block (idempotent when closed).
    self._current_reasoning_idx = -1
    idx = self._current_text_idx
    if 0 <= idx < len(self.parts) and self.parts[idx].get("type") == "text":
        # Active text part — append in place.
        self.parts[idx]["text"] += delta
    else:
        # No active part: open a new one and track its position.
        self.parts.append({"type": "text", "text": delta})
        self._current_text_idx = len(self.parts) - 1
|
||||
|
||||
def on_text_end(self, text_id: str) -> None:
    """Close the active text block; the next delta opens a fresh part.

    Mirrors the wire-level ``text-end`` boundary emitted before tool
    calls / reasoning / step boundaries. The FE closes implicitly via
    ``currentTextPartIndex = -1`` inside ``addToolCall`` /
    ``appendReasoning`` / ``addStepSeparator``; this helper does it
    explicitly so call sites need not maintain that invariant.
    """
    self._current_text_idx = -1
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Reasoning
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_reasoning_start(self, reasoning_id: str) -> None:
    """Begin a reasoning block; implicitly closes any open text part."""
    # Unconditional reset: writing -1 over -1 is a no-op.
    self._current_text_idx = -1
|
||||
|
||||
def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None:
    """Fold a reasoning chunk into the open reasoning part, or open one.

    Empty chunks are ignored. A reasoning delta arriving while a text
    block is open implicitly closes the text block.
    """
    if not delta:
        return
    # Implicitly close any open text block (idempotent when closed).
    self._current_text_idx = -1
    idx = self._current_reasoning_idx
    if 0 <= idx < len(self.parts) and self.parts[idx].get("type") == "reasoning":
        # Active reasoning part — append in place.
        self.parts[idx]["text"] += delta
    else:
        # No active part: open a new one and track its position.
        self.parts.append({"type": "reasoning", "text": delta})
        self._current_reasoning_idx = len(self.parts) - 1
|
||||
|
||||
def on_reasoning_end(self, reasoning_id: str) -> None:
    """Close the active reasoning block; the next delta opens a new part."""
    self._current_reasoning_idx = -1
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool calls
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_tool_input_start(
|
||||
self,
|
||||
ui_id: str,
|
||||
tool_name: str,
|
||||
langchain_tool_call_id: str | None,
|
||||
) -> None:
|
||||
"""Register a tool-call card. Args are filled in by later events."""
|
||||
if not ui_id:
|
||||
return
|
||||
# Skip duplicate registration: parity_v2 may emit
|
||||
# ``tool-input-start`` from both ``on_chat_model_stream``
|
||||
# (when tool_call_chunks register a name) and ``on_tool_start``
|
||||
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
|
||||
# we mirror that here.
|
||||
if ui_id in self._tool_call_idx_by_ui_id:
|
||||
if langchain_tool_call_id:
|
||||
idx = self._tool_call_idx_by_ui_id[ui_id]
|
||||
part = self.parts[idx]
|
||||
if not part.get("langchainToolCallId"):
|
||||
part["langchainToolCallId"] = langchain_tool_call_id
|
||||
return
|
||||
|
||||
part: dict[str, Any] = {
|
||||
"type": "tool-call",
|
||||
"toolCallId": ui_id,
|
||||
"toolName": tool_name,
|
||||
"args": {},
|
||||
}
|
||||
if langchain_tool_call_id:
|
||||
part["langchainToolCallId"] = langchain_tool_call_id
|
||||
self.parts.append(part)
|
||||
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
|
||||
|
||||
self._current_text_idx = -1
|
||||
self._current_reasoning_idx = -1
|
||||
|
||||
def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None:
    """Accumulate a streamed args chunk onto the card's ``argsText``.

    Mirrors FE ``appendToolInputDelta``: chunks for an unknown or
    unregistered ``ui_id`` are silently dropped — they have nowhere
    safe to land.
    """
    if not ui_id or not args_chunk:
        return
    idx = self._tool_call_idx_by_ui_id.get(ui_id)
    if idx is None or not (0 <= idx < len(self.parts)):
        return
    card = self.parts[idx]
    if card.get("type") != "tool-call":
        return
    combined = (card.get("argsText") or "") + args_chunk
    card["argsText"] = combined
    # Keep the side accumulator in lockstep with the card.
    self._args_text_by_ui_id[ui_id] = combined
|
||||
|
||||
def on_tool_input_available(
|
||||
self,
|
||||
ui_id: str,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
langchain_tool_call_id: str | None,
|
||||
) -> None:
|
||||
"""Finalize the tool-call card's input.
|
||||
|
||||
Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText``
|
||||
with ``json.dumps(input, indent=2)`` so the post-stream card renders
|
||||
pretty-printed JSON, sets the full ``args`` dict, and backfills
|
||||
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
|
||||
Also creates the card if no prior ``tool-input-start`` registered it
|
||||
(legacy parity_v2-OFF / late-registration paths).
|
||||
"""
|
||||
if not ui_id:
|
||||
return
|
||||
try:
|
||||
final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
# Defensive: ``args`` should already be JSON-safe (the
|
||||
# streaming layer sanitizes it before emitting), but if a
|
||||
# caller hands us a non-serializable value we still want
|
||||
# to record the call without breaking the snapshot.
|
||||
final_args_text = str(args)
|
||||
|
||||
idx = self._tool_call_idx_by_ui_id.get(ui_id)
|
||||
if idx is not None and 0 <= idx < len(self.parts):
|
||||
part = self.parts[idx]
|
||||
if part.get("type") == "tool-call":
|
||||
part["args"] = args or {}
|
||||
part["argsText"] = final_args_text
|
||||
if langchain_tool_call_id and not part.get("langchainToolCallId"):
|
||||
part["langchainToolCallId"] = langchain_tool_call_id
|
||||
return
|
||||
|
||||
# No prior tool-input-start: register the card now.
|
||||
new_part: dict[str, Any] = {
|
||||
"type": "tool-call",
|
||||
"toolCallId": ui_id,
|
||||
"toolName": tool_name,
|
||||
"args": args or {},
|
||||
"argsText": final_args_text,
|
||||
}
|
||||
if langchain_tool_call_id:
|
||||
new_part["langchainToolCallId"] = langchain_tool_call_id
|
||||
self.parts.append(new_part)
|
||||
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
|
||||
|
||||
self._current_text_idx = -1
|
||||
self._current_reasoning_idx = -1
|
||||
|
||||
def on_tool_output_available(
|
||||
self,
|
||||
ui_id: str,
|
||||
output: Any,
|
||||
langchain_tool_call_id: str | None,
|
||||
) -> None:
|
||||
"""Attach the tool's output (``result``) to the matching card.
|
||||
|
||||
Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId``
|
||||
only if not already set (a NULL late-arriving value never blows
|
||||
away an earlier known good one).
|
||||
"""
|
||||
if not ui_id:
|
||||
return
|
||||
idx = self._tool_call_idx_by_ui_id.get(ui_id)
|
||||
if idx is None or not (0 <= idx < len(self.parts)):
|
||||
return
|
||||
part = self.parts[idx]
|
||||
if part.get("type") != "tool-call":
|
||||
return
|
||||
part["result"] = output
|
||||
if langchain_tool_call_id and not part.get("langchainToolCallId"):
|
||||
part["langchainToolCallId"] = langchain_tool_call_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thinking steps & step separators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_thinking_step(
|
||||
self,
|
||||
step_id: str,
|
||||
title: str,
|
||||
status: str,
|
||||
items: list[str] | None,
|
||||
) -> None:
|
||||
"""Update / insert the singleton ``data-thinking-steps`` part.
|
||||
|
||||
Mirrors FE ``updateThinkingSteps``: maintain a single
|
||||
``data-thinking-steps`` part anchored at index 0, replacing or
|
||||
unshifting on first sight. Each ``on_thinking_step`` call
|
||||
replaces the entry in the steps list keyed by ``step_id`` (or
|
||||
appends if new).
|
||||
"""
|
||||
if not step_id:
|
||||
return
|
||||
|
||||
new_step = {
|
||||
"id": step_id,
|
||||
"title": title or "",
|
||||
"status": status or "in_progress",
|
||||
"items": list(items) if items else [],
|
||||
}
|
||||
|
||||
# Find existing data-thinking-steps part.
|
||||
existing_idx = -1
|
||||
for i, p in enumerate(self.parts):
|
||||
if p.get("type") == "data-thinking-steps":
|
||||
existing_idx = i
|
||||
break
|
||||
|
||||
if existing_idx >= 0:
|
||||
current_steps = self.parts[existing_idx].get("data", {}).get("steps") or []
|
||||
replaced = False
|
||||
for i, step in enumerate(current_steps):
|
||||
if step.get("id") == step_id:
|
||||
current_steps[i] = new_step
|
||||
replaced = True
|
||||
break
|
||||
if not replaced:
|
||||
current_steps.append(new_step)
|
||||
self.parts[existing_idx] = {
|
||||
"type": "data-thinking-steps",
|
||||
"data": {"steps": current_steps},
|
||||
}
|
||||
return
|
||||
|
||||
# First sight: unshift to position 0 (FE parity).
|
||||
self.parts.insert(
|
||||
0,
|
||||
{
|
||||
"type": "data-thinking-steps",
|
||||
"data": {"steps": [new_step]},
|
||||
},
|
||||
)
|
||||
# Bump tracked indices since we inserted at the head.
|
||||
if self._current_text_idx >= 0:
|
||||
self._current_text_idx += 1
|
||||
if self._current_reasoning_idx >= 0:
|
||||
self._current_reasoning_idx += 1
|
||||
for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()):
|
||||
self._tool_call_idx_by_ui_id[ui_id] = idx + 1
|
||||
|
||||
def on_step_separator(self) -> None:
    """Insert a ``data-step-separator`` between consecutive model steps.

    Mirrors FE ``addStepSeparator``: skipped unless the message already
    holds meaningful content, and never doubled after an existing
    separator. ``stepIndex`` is the count of separators already present
    in ``parts``. A separator also closes any open text/reasoning block.
    """
    if not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts):
        return
    if self.parts and self.parts[-1].get("type") == "data-step-separator":
        return
    separators_so_far = sum(
        p.get("type") == "data-step-separator" for p in self.parts
    )
    self.parts.append(
        {
            "type": "data-step-separator",
            "data": {"stepIndex": separators_so_far},
        }
    )
    self._current_text_idx = -1
    self._current_reasoning_idx = -1
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Interruption handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def mark_interrupted(self) -> None:
    """Close open text/reasoning blocks and flag unfinished tools aborted.

    Invoked from the streaming ``finally`` block just before
    ``snapshot()`` so the persisted JSONB shows a coherent end-state
    after a mid-turn client disconnect or fatal agent error:

    - open text/reasoning blocks merely lose their "active" marker —
      whatever content was streamed is left untouched, and
    - tool-call cards that never received a ``result`` get
      ``state="aborted"`` so the FE history loader renders them as
      "interrupted" instead of "still running".
    """
    self._current_text_idx = -1
    self._current_reasoning_idx = -1
    for card in self.parts:
        if card.get("type") == "tool-call" and "result" not in card:
            card["state"] = "aborted"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Snapshot & introspection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def snapshot(self) -> list[dict[str, Any]]:
|
||||
"""Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps.
|
||||
|
||||
Deep-copied so callers that finalize from the shielded ``finally``
|
||||
block can't accidentally mutate the persisted payload while the
|
||||
SQL UPDATE is in flight (the streaming layer doesn't touch the
|
||||
builder after this call, but defensive copies are cheap and cheap
|
||||
is what we want in a finally block).
|
||||
"""
|
||||
return copy.deepcopy(self.parts)
|
||||
|
||||
def is_empty(self) -> bool:
    """Report whether the turn captured no meaningful content.

    Decorator parts (``data-thinking-steps``, ``data-step-separator``)
    never count on their own — a turn interrupted after emitting only a
    thinking step is still "empty" for the status-marker fallback.
    """
    for part in self.parts:
        if part.get("type") in _MEANINGFUL_PART_TYPES:
            return False
    return True
|
||||
|
||||
def stats(self) -> dict[str, int]:
    """Return per-part-type counts plus a rough serialized byte size.

    Consumed by the streaming layer's perf logger so finalize latency
    can be correlated with payload size, and so a regression that
    quietly stops emitting tool-call parts (or starts emitting
    hundreds) surfaces in a [PERF] grep rather than only as a
    "history reload looks weird" bug report.

    ``bytes`` is the length of the JSON-serialised payload (what
    actually travels to PostgreSQL's JSONB column), computed with
    ``ensure_ascii=False`` to approximate the JSONB UTF-8 on-disk
    layout. Serialization failure — a non-serializable value slipping
    past the streaming layer's sanitization — is reported as
    ``bytes=-1`` rather than raised: perf logging must never be what
    breaks the streaming ``finally`` block.
    """
    kind_counts: dict[str, int] = {}
    completed = 0
    aborted = 0
    for part in self.parts:
        kind = part.get("type")
        kind_counts[kind] = kind_counts.get(kind, 0) + 1
        if kind == "tool-call":
            if part.get("state") == "aborted":
                aborted += 1
            elif "result" in part:
                completed += 1

    try:
        byte_size = len(json.dumps(self.parts, ensure_ascii=False, default=str))
    except (TypeError, ValueError):
        byte_size = -1

    return {
        "parts": len(self.parts),
        "bytes": byte_size,
        "text": kind_counts.get("text", 0),
        "reasoning": kind_counts.get("reasoning", 0),
        "tool_calls": kind_counts.get("tool-call", 0),
        "tool_calls_completed": completed,
        "tool_calls_aborted": aborted,
        "thinking_step_parts": kind_counts.get("data-thinking-steps", 0),
        "step_separators": kind_counts.get("data-step-separator", 0),
    }
|
||||
534
surfsense_backend/app/tasks/chat/persistence.py
Normal file
534
surfsense_backend/app/tasks/chat/persistence.py
Normal file
|
|
@ -0,0 +1,534 @@
|
|||
"""Server-side message persistence helpers for the streaming chat agent.
|
||||
|
||||
Historically the streaming task (``stream_new_chat``/``stream_resume_chat``)
|
||||
left ``new_chat_messages`` empty and relied on the frontend to round-trip
|
||||
``POST /threads/{id}/messages`` afterwards. That gave authenticated clients
|
||||
a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens
|
||||
without leaving an audit trail. These helpers move both writes (the user
|
||||
turn that triggered the stream and the assistant turn the stream produced)
|
||||
into the server itself, idempotent against the partial unique index
|
||||
``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do*
|
||||
keep posting via ``appendMessage`` simply hit the unique-index recovery
|
||||
path on the second writer instead of creating duplicates.
|
||||
|
||||
Assistant turn lifecycle
|
||||
------------------------
|
||||
The assistant side is split into two helpers so we can capture the row id
|
||||
*before* the stream produces any output:
|
||||
|
||||
* ``persist_assistant_shell`` runs immediately after ``persist_user_turn``
|
||||
and INSERTs an empty assistant row anchored to ``(thread_id, turn_id,
|
||||
ASSISTANT)``. Returns the row id so the streaming layer can correlate
|
||||
later writes (token_usage, AgentActionLog future-correlation) against
|
||||
a stable PK from the start of the turn.
|
||||
* ``finalize_assistant_turn`` runs from the streaming ``finally`` block.
|
||||
It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot
|
||||
produced server-side by ``AssistantContentBuilder`` and writes the
|
||||
``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against
|
||||
the ``uq_token_usage_message_id`` partial unique index from migration
|
||||
142, hard-eliminating any race against ``append_message``'s recovery
|
||||
branch.
|
||||
|
||||
Defensive contract
|
||||
------------------
|
||||
|
||||
* Every helper runs inside ``shielded_async_session()`` so ``session.close()``
|
||||
survives starlette's mid-stream cancel scope on client disconnect.
|
||||
* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON
|
||||
CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id,
|
||||
role)`` partial unique index. On conflict the insert silently no-ops at
|
||||
the DB level — no Python ``IntegrityError`` is constructed, which
|
||||
eliminates spurious debugger pauses and keeps logs clean. On conflict a
|
||||
follow-up ``SELECT`` resolves the existing row id so the streaming layer
|
||||
can correlate writes against a stable PK.
|
||||
* ``finalize_assistant_turn`` is best-effort: it never raises. The
|
||||
streaming ``finally`` block calls it from within
|
||||
``anyio.CancelScope(shield=True)`` and any raised exception there
|
||||
would mask the real error.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import text as sa_text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import (
|
||||
NewChatMessage,
|
||||
NewChatMessageRole,
|
||||
NewChatThread,
|
||||
TokenUsage,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.services.token_tracking_service import (
|
||||
TurnTokenAccumulator,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
# Empty initial assistant content. ``finalize_assistant_turn`` overwrites
# this in a single UPDATE at end-of-stream with the full ``ContentPart[]``
# snapshot produced by ``AssistantContentBuilder``. We persist a one-element
# list with an empty text part so a crash between shell-INSERT and finalize
# leaves the row in a FE-renderable shape (blank bubble) instead of
# blowing up the history loader.
# NOTE(review): this is a shared module-level list assigned directly as
# row content — confirm nothing mutates ``row.content`` in place, or the
# mutation would leak across turns.
_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}]

# Substituted content for genuinely empty turns (no text, no reasoning,
# no tool calls). The streaming layer flips to this when
# ``AssistantContentBuilder.is_empty()`` returns True so the persisted
# row is at least somewhat self-describing instead of an empty text
# bubble. The FE's ``ContentPart`` union doesn't include ``status``
# yet, so the history loader will silently drop this part and render
# a blank bubble (matches today's behaviour for empty turns); a follow-up
# FE PR adds the explicit "no response" rendering.
_STATUS_NO_RESPONSE: list[dict[str, Any]] = [
    {"type": "status", "text": "(no text response)"}
]
|
||||
|
||||
|
||||
def _build_user_content(
|
||||
user_query: str,
|
||||
user_image_data_urls: list[str] | None,
|
||||
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the persisted user-message ``content`` (assistant-ui v2 parts).
|
||||
|
||||
Mirrors the shape the existing frontend posts via
|
||||
``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``):
|
||||
|
||||
[{"type": "text", "text": "..."},
|
||||
{"type": "image", "image": "data:..."},
|
||||
{"type": "mentioned-documents", "documents": [{"id": int,
|
||||
"title": str, "document_type": str}, ...]}]
|
||||
|
||||
The companion reader is
|
||||
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
|
||||
which expects exactly this shape — keep them in sync.
|
||||
|
||||
``mentioned_documents``: optional list of ``{id, title, document_type}``
|
||||
dicts. When non-empty (and a ``mentioned-documents`` part is not already
|
||||
in some other input shape), a single ``{"type": "mentioned-documents",
|
||||
"documents": [...]}`` part is appended. Mirrors the FE injection at
|
||||
``page.tsx:281-286`` (``persistUserTurn``).
|
||||
"""
|
||||
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
|
||||
for url in user_image_data_urls or ():
|
||||
if isinstance(url, str) and url:
|
||||
parts.append({"type": "image", "image": url})
|
||||
if mentioned_documents:
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for doc in mentioned_documents:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
doc_id = doc.get("id")
|
||||
title = doc.get("title")
|
||||
document_type = doc.get("document_type")
|
||||
if doc_id is None or title is None or document_type is None:
|
||||
continue
|
||||
normalized.append(
|
||||
{
|
||||
"id": doc_id,
|
||||
"title": str(title),
|
||||
"document_type": str(document_type),
|
||||
}
|
||||
)
|
||||
if normalized:
|
||||
parts.append({"type": "mentioned-documents", "documents": normalized})
|
||||
return parts
|
||||
|
||||
|
||||
async def persist_user_turn(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
    user_query: str,
    user_image_data_urls: list[str] | None = None,
    mentioned_documents: list[dict[str, Any]] | None = None,
) -> int | None:
    """Persist the user-side row for a chat turn and return its ``id``.

    Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on the
    ``(thread_id, turn_id, role)`` partial unique index from migration 141
    (``WHERE turn_id IS NOT NULL``). On conflict the insert silently no-ops
    at the DB level — no Python ``IntegrityError`` is constructed, which
    eliminates the debugger pause that ``justMyCode=false`` + async greenlet
    interactions used to produce, and keeps production logs clean.

    Returns the ``id`` of the row that exists for this turn after the call:
    the freshly inserted ``id`` on the happy path, or the existing ``id``
    when a previous writer (legacy FE ``appendMessage`` racing the SSE
    stream, redelivered request, etc.) already wrote it. Returns ``None``
    only on genuine DB failure; the caller should yield a streaming error
    and abort the turn so we never produce a title/assistant row that
    isn't anchored to a persisted user message.

    Other constraint violations (FK, NOT NULL, etc.) still raise
    ``IntegrityError`` — only the ``(thread_id, turn_id, role)`` collision
    is silenced.
    """
    if not turn_id:
        # Defensive: turn_id is always populated by the streaming path
        # before this helper is called. If it isn't, we cannot be
        # idempotent against the unique index — refuse to write rather
        # than create a row the unique index can't dedupe.
        logger.error(
            "persist_user_turn called without a turn_id (chat_id=%s); skipping",
            chat_id,
        )
        return None

    t0 = time.perf_counter()
    # ``outcome`` is mutated as the control flow progresses so the perf
    # log in the ``finally`` always reports how the call ended.
    outcome = "failed"
    resolved_id: int | None = None
    try:
        async with shielded_async_session() as ws:
            # Re-attach the thread row so we can also bump updated_at
            # in the same write — keeps the sidebar ordering accurate
            # when a user fires off a turn but never reaches the
            # legacy appendMessage.
            thread = await ws.get(NewChatThread, chat_id)
            author_uuid: UUID | None = None
            if user_id:
                try:
                    author_uuid = UUID(user_id)
                except (TypeError, ValueError):
                    # Malformed user_id never blocks persistence — the
                    # row is simply written without an author.
                    logger.warning(
                        "persist_user_turn: invalid user_id=%r, persisting as anonymous",
                        user_id,
                    )

            content_payload = _build_user_content(
                user_query, user_image_data_urls, mentioned_documents
            )
            insert_stmt = (
                pg_insert(NewChatMessage)
                .values(
                    thread_id=chat_id,
                    role=NewChatMessageRole.USER,
                    content=content_payload,
                    author_id=author_uuid,
                    turn_id=turn_id,
                )
                .on_conflict_do_nothing(
                    index_elements=["thread_id", "turn_id", "role"],
                    index_where=sa_text("turn_id IS NOT NULL"),
                )
                .returning(NewChatMessage.id)
            )
            # RETURNING yields NULL (scalar None) when the conflict
            # clause suppressed the insert.
            inserted_id = (await ws.execute(insert_stmt)).scalar()

            if inserted_id is None:
                # Conflict on partial unique index — another writer
                # (legacy FE appendMessage, redelivered request, etc.)
                # already persisted this row. Look it up and reuse.
                lookup = await ws.execute(
                    select(NewChatMessage.id).where(
                        NewChatMessage.thread_id == chat_id,
                        NewChatMessage.turn_id == turn_id,
                        NewChatMessage.role == NewChatMessageRole.USER,
                    )
                )
                existing_id = lookup.scalars().first()
                if existing_id is None:
                    # Conflict reported but no row found — extremely
                    # unlikely (concurrent DELETE). Surface as failure.
                    logger.warning(
                        "persist_user_turn: conflict but no matching row "
                        "(chat_id=%s, turn_id=%s)",
                        chat_id,
                        turn_id,
                    )
                    outcome = "integrity_no_match"
                    return None
                resolved_id = int(existing_id)
                outcome = "race_recovered"
            else:
                resolved_id = int(inserted_id)
                outcome = "inserted"
                # Bump thread.updated_at only on a real insert — when
                # we recovered an existing row the prior writer
                # already touched the thread.
                if thread is not None:
                    thread.updated_at = datetime.now(UTC)

            await ws.commit()
            return resolved_id
    except Exception:
        # Genuine DB failure (connection loss, FK violation, ...):
        # logged with traceback and surfaced to the caller as None.
        logger.exception(
            "persist_user_turn failed (chat_id=%s, turn_id=%s)",
            chat_id,
            turn_id,
        )
        return None
    finally:
        _perf_log.info(
            "[persist_user_turn] outcome=%s chat_id=%s turn_id=%s "
            "message_id=%s query_len=%d images=%d mentioned_docs=%d "
            "in %.3fs",
            outcome,
            chat_id,
            turn_id,
            resolved_id,
            len(user_query or ""),
            len(user_image_data_urls or ()),
            len(mentioned_documents or ()),
            time.perf_counter() - t0,
        )
|
||||
|
||||
|
||||
async def persist_assistant_shell(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
) -> int | None:
    """Pre-write an empty assistant row for the turn and return its id.

    Inserts a placeholder ``new_chat_messages`` row (empty text content) so
    the streaming layer has a stable ``message_id`` to correlate against
    for the rest of the turn. ``finalize_assistant_turn`` overwrites the
    ``content`` field at end-of-stream with the rich ``ContentPart[]``
    snapshot produced by ``AssistantContentBuilder``.

    Returns the row id on success, ``None`` on a genuine DB failure (caller
    should abort the turn rather than stream into a void).

    Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial unique
    index from migration 141: if a row already exists (resume retry, racing
    legacy frontend, redelivered request, etc.) we look it up by
    ``(thread_id, turn_id, role)`` and return its existing id. The streaming
    layer is then free to UPDATE that row at finalize time.

    NOTE(review): ``user_id`` is accepted but never read — the shell row is
    always written with ``author_id=None``. Confirm whether assistant rows
    should ever carry an author before dropping the parameter.
    """
    if not turn_id:
        # Without a turn_id the unique index cannot dedupe this row;
        # refuse to write (mirrors persist_user_turn).
        logger.error(
            "persist_assistant_shell called without a turn_id (chat_id=%s); skipping",
            chat_id,
        )
        return None

    t0 = time.perf_counter()
    # ``outcome`` tracks how the call ended for the perf log in ``finally``.
    outcome = "failed"
    resolved_id: int | None = None
    try:
        async with shielded_async_session() as ws:
            insert_stmt = (
                pg_insert(NewChatMessage)
                .values(
                    thread_id=chat_id,
                    role=NewChatMessageRole.ASSISTANT,
                    content=_EMPTY_SHELL_CONTENT,
                    author_id=None,
                    turn_id=turn_id,
                )
                .on_conflict_do_nothing(
                    index_elements=["thread_id", "turn_id", "role"],
                    index_where=sa_text("turn_id IS NOT NULL"),
                )
                .returning(NewChatMessage.id)
            )
            # RETURNING yields scalar None when the conflict clause
            # suppressed the insert.
            inserted_id = (await ws.execute(insert_stmt)).scalar()

            if inserted_id is None:
                # Conflict — another writer (legacy FE appendMessage,
                # resume retry, redelivered request) wrote the
                # (thread_id, turn_id, ASSISTANT) row first. Look it up
                # so the streaming layer can UPDATE the same row at
                # finalize time.
                lookup = await ws.execute(
                    select(NewChatMessage.id).where(
                        NewChatMessage.thread_id == chat_id,
                        NewChatMessage.turn_id == turn_id,
                        NewChatMessage.role == NewChatMessageRole.ASSISTANT,
                    )
                )
                existing_id = lookup.scalars().first()
                if existing_id is None:
                    # Conflict reported but no row found — extremely
                    # unlikely (concurrent DELETE). Surface as failure.
                    logger.warning(
                        "persist_assistant_shell: conflict but no matching "
                        "(thread_id, turn_id, role) row found "
                        "(chat_id=%s, turn_id=%s)",
                        chat_id,
                        turn_id,
                    )
                    outcome = "integrity_no_match"
                    return None
                resolved_id = int(existing_id)
                outcome = "race_recovered"
            else:
                resolved_id = int(inserted_id)
                outcome = "inserted"

            await ws.commit()
            return resolved_id
    except Exception:
        # Genuine DB failure: logged with traceback, surfaced as None so
        # the caller aborts the turn.
        logger.exception(
            "persist_assistant_shell failed (chat_id=%s, turn_id=%s)",
            chat_id,
            turn_id,
        )
        return None
    finally:
        _perf_log.info(
            "[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s "
            "message_id=%s in %.3fs",
            outcome,
            chat_id,
            turn_id,
            resolved_id,
            time.perf_counter() - t0,
        )
|
||||
|
||||
|
||||
async def finalize_assistant_turn(
    *,
    message_id: int,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    turn_id: str,
    content: list[dict[str, Any]],
    accumulator: TurnTokenAccumulator | None,
) -> None:
    """Finalize the assistant row and write its token_usage.

    Two writes in a single shielded session:

    1. ``UPDATE new_chat_messages SET content = :c, updated_at = now()
       WHERE id = :id`` — overwrites the placeholder ``persist_assistant_shell``
       wrote with the full ``ContentPart[]`` snapshot produced server-side.
    2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT (message_id)
       WHERE message_id IS NOT NULL DO NOTHING`` — uses the partial unique
       index ``uq_token_usage_message_id`` from migration 142 to make the
       insert idempotent against ``append_message``'s recovery branch
       (which uses the same ON CONFLICT clause).

    Substitutes the status-marker payload when ``content`` is empty
    (pure tool-call turn that aborted before any output, or interrupt
    before any event arrived). The status marker is preferable to a
    blank text bubble because token accounting still runs and an ops
    dashboard can flag the row.

    Best-effort — never raises. The streaming ``finally`` calls this
    from within ``anyio.CancelScope(shield=True)``; any raised exception
    here would mask the real error that triggered the cleanup.
    """
    if not turn_id:
        # Same dedupe-anchor requirement as the persist_* helpers.
        logger.error(
            "finalize_assistant_turn called without turn_id "
            "(chat_id=%s, message_id=%s); skipping",
            chat_id,
            message_id,
        )
        return
    if not message_id:
        # Without the shell row's PK there is nothing to UPDATE.
        logger.error(
            "finalize_assistant_turn called without message_id "
            "(chat_id=%s, turn_id=%s); skipping",
            chat_id,
            turn_id,
        )
        return

    # Empty content (no text, reasoning, or tool calls) is replaced by
    # the module-level status marker so the persisted row is
    # self-describing; ``is_status_marker`` feeds the perf log only.
    payload: list[dict[str, Any]]
    is_status_marker = False
    if content:
        payload = content
    else:
        payload = _STATUS_NO_RESPONSE
        is_status_marker = True

    t0 = time.perf_counter()
    outcome = "failed"
    # Recorded up-front (before the try) so the perf log reports the
    # intent even when the session fails before reaching the insert.
    token_usage_attempted = bool(
        accumulator is not None and accumulator.calls and user_id
    )
    try:
        async with shielded_async_session() as ws:
            assistant_row = await ws.get(NewChatMessage, message_id)
            if assistant_row is None:
                # Row deleted between shell-INSERT and finalize
                # (thread deletion mid-stream). Nothing to update.
                logger.warning(
                    "finalize_assistant_turn: row not found "
                    "(chat_id=%s, message_id=%s, turn_id=%s); skipping",
                    chat_id,
                    message_id,
                    turn_id,
                )
                outcome = "row_missing"
                return

            assistant_row.content = payload
            assistant_row.updated_at = datetime.now(UTC)

            # Token usage. ``record_token_usage`` (used elsewhere) does
            # SELECT-then-INSERT in two statements which races with
            # ``append_message``. Switch to a single INSERT ... ON
            # CONFLICT DO NOTHING keyed on the migration-142 partial
            # unique index so the loser silently drops its write at
            # the DB level — exactly one row per ``message_id``,
            # regardless of which session committed first.
            if accumulator is not None and accumulator.calls and user_id:
                try:
                    user_uuid = UUID(user_id)
                except (TypeError, ValueError):
                    # A malformed user_id skips only the usage row —
                    # the content UPDATE above still commits.
                    logger.warning(
                        "finalize_assistant_turn: invalid user_id=%r, "
                        "skipping token_usage row",
                        user_id,
                    )
                else:
                    insert_stmt = (
                        pg_insert(TokenUsage)
                        .values(
                            usage_type="chat",
                            prompt_tokens=accumulator.total_prompt_tokens,
                            completion_tokens=accumulator.total_completion_tokens,
                            total_tokens=accumulator.grand_total,
                            cost_micros=accumulator.total_cost_micros,
                            model_breakdown=accumulator.per_message_summary(),
                            call_details={"calls": accumulator.serialized_calls()},
                            thread_id=chat_id,
                            message_id=message_id,
                            search_space_id=search_space_id,
                            user_id=user_uuid,
                        )
                        .on_conflict_do_nothing(
                            index_elements=["message_id"],
                            index_where=sa_text("message_id IS NOT NULL"),
                        )
                    )
                    await ws.execute(insert_stmt)

            await ws.commit()
            outcome = "ok"
    except Exception:
        # Best-effort contract: log with traceback, never re-raise —
        # raising here would mask the error that triggered cleanup.
        logger.exception(
            "finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)",
            chat_id,
            message_id,
            turn_id,
        )
    finally:
        _perf_log.info(
            "[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s "
            "turn_id=%s parts=%d status_marker=%s "
            "token_usage_attempted=%s in %.3fs",
            outcome,
            chat_id,
            message_id,
            turn_id,
            len(payload),
            is_status_marker,
            token_usage_attempted,
            time.perf_counter() - t0,
        )
|
||||
|
|
@ -25,7 +25,6 @@ from uuid import UUID
|
|||
|
||||
import anyio
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
|
@ -314,6 +313,19 @@ class StreamResult:
|
|||
verification_succeeded: bool = False
|
||||
commit_gate_passed: bool = True
|
||||
commit_gate_reason: str = ""
|
||||
# Pre-allocated assistant ``new_chat_messages.id`` for this turn,
|
||||
# captured by ``persist_assistant_shell`` right after the user row is
|
||||
# persisted. ``None`` for the legacy / anonymous code paths that don't
|
||||
# opt in to server-side ``ContentPart[]`` projection.
|
||||
assistant_message_id: int | None = None
|
||||
# In-memory mirror of the FE's assistant-ui ``ContentPartsState``,
|
||||
# populated by the lifecycle methods called from ``_stream_agent_events``
|
||||
# at each ``streaming_service.format_*`` yield site. Snapshot in the
|
||||
# streaming ``finally`` to produce the rich JSONB persisted by
|
||||
# ``finalize_assistant_turn``. ``repr=False`` keeps the
|
||||
# log-on-error path (``StreamResult`` is logged in some error
|
||||
# branches) from dumping a potentially-large parts list.
|
||||
content_builder: Any | None = field(default=None, repr=False)
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
|
|
@ -721,6 +733,7 @@ async def _stream_agent_events(
|
|||
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||
fallback_commit_thread_id: int | None = None,
|
||||
runtime_context: Any = None,
|
||||
content_builder: Any | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Shared async generator that streams and formats astream_events from the agent.
|
||||
|
||||
|
|
@ -737,6 +750,15 @@ async def _stream_agent_events(
|
|||
initial_step_id: If set, the helper inherits an already-active thinking step.
|
||||
initial_step_title: Title of the inherited thinking step.
|
||||
initial_step_items: Items of the inherited thinking step.
|
||||
content_builder: Optional ``AssistantContentBuilder``. When set, every
|
||||
``streaming_service.format_*`` yield site also drives the matching
|
||||
builder lifecycle method (``on_text_*``, ``on_reasoning_*``,
|
||||
``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the
|
||||
in-memory ``ContentPart[]`` projection stays in lockstep with what
|
||||
the FE renders live. Pure in-memory accumulation — no DB I/O —
|
||||
consumed by the streaming ``finally`` to produce the rich JSONB
|
||||
persisted via ``finalize_assistant_turn``. ``None`` (the default)
|
||||
is used by the anonymous / legacy code paths and is a no-op.
|
||||
|
||||
Yields:
|
||||
SSE-formatted strings for each event.
|
||||
|
|
@ -801,12 +823,46 @@ async def _stream_agent_events(
|
|||
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
|
||||
|
||||
def _emit_tool_output(call_id: str, output: Any) -> str:
|
||||
# Drive the builder before formatting the SSE so the in-memory
|
||||
# ContentPart[] mirror sees the result attached to the same
|
||||
# card the FE will render. Builder method is a no-op when
|
||||
# ``content_builder`` is None (anonymous / legacy paths).
|
||||
if content_builder is not None:
|
||||
content_builder.on_tool_output_available(
|
||||
call_id, output, current_lc_tool_call_id["value"]
|
||||
)
|
||||
return streaming_service.format_tool_output_available(
|
||||
call_id,
|
||||
output,
|
||||
langchain_tool_call_id=current_lc_tool_call_id["value"],
|
||||
)
|
||||
|
||||
def _emit_thinking_step(
|
||||
*,
|
||||
step_id: str,
|
||||
title: str,
|
||||
status: str = "in_progress",
|
||||
items: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Format a thinking-step SSE event and notify the builder.
|
||||
|
||||
Single helper used at every ``format_thinking_step`` yield site
|
||||
in this generator. Drives ``AssistantContentBuilder.on_thinking_step``
|
||||
first so the FE-mirror state lands the update before the SSE
|
||||
carrying the same data leaves the wire — order matches the FE
|
||||
pipeline (``processSharedStreamEvent`` updates state, then
|
||||
flushes). Builder call is a no-op when ``content_builder`` is
|
||||
None (anonymous / legacy paths).
|
||||
"""
|
||||
if content_builder is not None:
|
||||
content_builder.on_thinking_step(step_id, title, status, items)
|
||||
return streaming_service.format_thinking_step(
|
||||
step_id=step_id,
|
||||
title=title,
|
||||
status=status,
|
||||
items=items,
|
||||
)
|
||||
|
||||
def next_thinking_step_id() -> str:
|
||||
nonlocal thinking_step_counter
|
||||
thinking_step_counter += 1
|
||||
|
|
@ -816,7 +872,7 @@ async def _stream_agent_events(
|
|||
nonlocal last_active_step_id
|
||||
if last_active_step_id and last_active_step_id not in completed_step_ids:
|
||||
completed_step_ids.add(last_active_step_id)
|
||||
event = streaming_service.format_thinking_step(
|
||||
event = _emit_thinking_step(
|
||||
step_id=last_active_step_id,
|
||||
title=last_active_step_title,
|
||||
status="completed",
|
||||
|
|
@ -861,6 +917,8 @@ async def _stream_agent_events(
|
|||
if parity_v2 and reasoning_delta:
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_end(current_text_id)
|
||||
current_text_id = None
|
||||
if current_reasoning_id is None:
|
||||
completion_event = complete_current_step()
|
||||
|
|
@ -873,13 +931,21 @@ async def _stream_agent_events(
|
|||
just_finished_tool = False
|
||||
current_reasoning_id = streaming_service.generate_reasoning_id()
|
||||
yield streaming_service.format_reasoning_start(current_reasoning_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_reasoning_start(current_reasoning_id)
|
||||
yield streaming_service.format_reasoning_delta(
|
||||
current_reasoning_id, reasoning_delta
|
||||
)
|
||||
if content_builder is not None:
|
||||
content_builder.on_reasoning_delta(
|
||||
current_reasoning_id, reasoning_delta
|
||||
)
|
||||
|
||||
if text_delta:
|
||||
if current_reasoning_id is not None:
|
||||
yield streaming_service.format_reasoning_end(current_reasoning_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_reasoning_end(current_reasoning_id)
|
||||
current_reasoning_id = None
|
||||
if current_text_id is None:
|
||||
completion_event = complete_current_step()
|
||||
|
|
@ -892,8 +958,12 @@ async def _stream_agent_events(
|
|||
just_finished_tool = False
|
||||
current_text_id = streaming_service.generate_text_id()
|
||||
yield streaming_service.format_text_start(current_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_start(current_text_id)
|
||||
yield streaming_service.format_text_delta(current_text_id, text_delta)
|
||||
accumulated_text += text_delta
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_delta(current_text_id, text_delta)
|
||||
|
||||
# Live tool-call argument streaming. Runs AFTER text/reasoning
|
||||
# processing so chunks containing both stay in their natural
|
||||
|
|
@ -925,11 +995,17 @@ async def _stream_agent_events(
|
|||
# within the same stream window.
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_end(current_text_id)
|
||||
current_text_id = None
|
||||
if current_reasoning_id is not None:
|
||||
yield streaming_service.format_reasoning_end(
|
||||
current_reasoning_id
|
||||
)
|
||||
if content_builder is not None:
|
||||
content_builder.on_reasoning_end(
|
||||
current_reasoning_id
|
||||
)
|
||||
current_reasoning_id = None
|
||||
|
||||
index_to_meta[idx] = {
|
||||
|
|
@ -942,6 +1018,8 @@ async def _stream_agent_events(
|
|||
name,
|
||||
langchain_tool_call_id=lc_id,
|
||||
)
|
||||
if content_builder is not None:
|
||||
content_builder.on_tool_input_start(ui_id, name, lc_id)
|
||||
|
||||
# Emit args delta for any chunk at a registered
|
||||
# index (including idless continuations). Once an
|
||||
|
|
@ -957,6 +1035,10 @@ async def _stream_agent_events(
|
|||
yield streaming_service.format_tool_input_delta(
|
||||
meta["ui_id"], args_chunk
|
||||
)
|
||||
if content_builder is not None:
|
||||
content_builder.on_tool_input_delta(
|
||||
meta["ui_id"], args_chunk
|
||||
)
|
||||
else:
|
||||
pending_tool_call_chunks.append(tcc)
|
||||
|
||||
|
|
@ -974,6 +1056,8 @@ async def _stream_agent_events(
|
|||
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_end(current_text_id)
|
||||
current_text_id = None
|
||||
|
||||
if last_active_step_title != "Synthesizing response":
|
||||
|
|
@ -994,7 +1078,7 @@ async def _stream_agent_events(
|
|||
)
|
||||
last_active_step_title = "Listing files"
|
||||
last_active_step_items = [ls_path]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Listing files",
|
||||
status="in_progress",
|
||||
|
|
@ -1009,7 +1093,7 @@ async def _stream_agent_events(
|
|||
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
last_active_step_title = "Reading file"
|
||||
last_active_step_items = [display_fp]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Reading file",
|
||||
status="in_progress",
|
||||
|
|
@ -1024,7 +1108,7 @@ async def _stream_agent_events(
|
|||
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
last_active_step_title = "Writing file"
|
||||
last_active_step_items = [display_fp]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Writing file",
|
||||
status="in_progress",
|
||||
|
|
@ -1039,7 +1123,7 @@ async def _stream_agent_events(
|
|||
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||
last_active_step_title = "Editing file"
|
||||
last_active_step_items = [display_fp]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Editing file",
|
||||
status="in_progress",
|
||||
|
|
@ -1056,7 +1140,7 @@ async def _stream_agent_events(
|
|||
)
|
||||
last_active_step_title = "Searching files"
|
||||
last_active_step_items = [f"{pat} in {base_path}"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Searching files",
|
||||
status="in_progress",
|
||||
|
|
@ -1076,7 +1160,7 @@ async def _stream_agent_events(
|
|||
last_active_step_items = [
|
||||
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
|
||||
]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Searching content",
|
||||
status="in_progress",
|
||||
|
|
@ -1091,7 +1175,7 @@ async def _stream_agent_events(
|
|||
display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:]
|
||||
last_active_step_title = "Deleting file"
|
||||
last_active_step_items = [display_path] if display_path else []
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Deleting file",
|
||||
status="in_progress",
|
||||
|
|
@ -1108,7 +1192,7 @@ async def _stream_agent_events(
|
|||
)
|
||||
last_active_step_title = "Deleting folder"
|
||||
last_active_step_items = [display_path] if display_path else []
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Deleting folder",
|
||||
status="in_progress",
|
||||
|
|
@ -1125,7 +1209,7 @@ async def _stream_agent_events(
|
|||
)
|
||||
last_active_step_title = "Creating folder"
|
||||
last_active_step_items = [display_path] if display_path else []
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Creating folder",
|
||||
status="in_progress",
|
||||
|
|
@ -1148,7 +1232,7 @@ async def _stream_agent_events(
|
|||
last_active_step_items = (
|
||||
[f"{display_src} → {display_dst}"] if src or dst else []
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Moving file",
|
||||
status="in_progress",
|
||||
|
|
@ -1165,7 +1249,7 @@ async def _stream_agent_events(
|
|||
if todo_count
|
||||
else []
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Planning tasks",
|
||||
status="in_progress",
|
||||
|
|
@ -1180,7 +1264,7 @@ async def _stream_agent_events(
|
|||
display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "")
|
||||
last_active_step_title = "Saving document"
|
||||
last_active_step_items = [display_title]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Saving document",
|
||||
status="in_progress",
|
||||
|
|
@ -1196,7 +1280,7 @@ async def _stream_agent_events(
|
|||
last_active_step_items = [
|
||||
f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"
|
||||
]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Generating image",
|
||||
status="in_progress",
|
||||
|
|
@ -1212,7 +1296,7 @@ async def _stream_agent_events(
|
|||
last_active_step_items = [
|
||||
f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
|
||||
]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Scraping webpage",
|
||||
status="in_progress",
|
||||
|
|
@ -1235,7 +1319,7 @@ async def _stream_agent_events(
|
|||
f"Content: {content_len:,} characters",
|
||||
"Preparing audio generation...",
|
||||
]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Generating podcast",
|
||||
status="in_progress",
|
||||
|
|
@ -1256,7 +1340,7 @@ async def _stream_agent_events(
|
|||
f"Topic: {report_topic}",
|
||||
"Analyzing source content...",
|
||||
]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title=step_title,
|
||||
status="in_progress",
|
||||
|
|
@ -1271,7 +1355,7 @@ async def _stream_agent_events(
|
|||
display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
|
||||
last_active_step_title = "Running command"
|
||||
last_active_step_items = [f"$ {display_cmd}"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title="Running command",
|
||||
status="in_progress",
|
||||
|
|
@ -1288,7 +1372,7 @@ async def _stream_agent_events(
|
|||
tool_name.replace("_", " ").strip().capitalize() or tool_name
|
||||
)
|
||||
last_active_step_items = []
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=tool_step_id,
|
||||
title=last_active_step_title,
|
||||
status="in_progress",
|
||||
|
|
@ -1349,6 +1433,10 @@ async def _stream_agent_events(
|
|||
tool_name,
|
||||
langchain_tool_call_id=langchain_tool_call_id,
|
||||
)
|
||||
if content_builder is not None:
|
||||
content_builder.on_tool_input_start(
|
||||
tool_call_id, tool_name, langchain_tool_call_id
|
||||
)
|
||||
|
||||
if run_id:
|
||||
ui_tool_call_id_by_run[run_id] = tool_call_id
|
||||
|
|
@ -1371,6 +1459,13 @@ async def _stream_agent_events(
|
|||
_safe_input,
|
||||
langchain_tool_call_id=langchain_tool_call_id,
|
||||
)
|
||||
if content_builder is not None:
|
||||
content_builder.on_tool_input_available(
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
_safe_input,
|
||||
langchain_tool_call_id,
|
||||
)
|
||||
|
||||
elif event_type == "on_tool_end":
|
||||
active_tool_depth = max(0, active_tool_depth - 1)
|
||||
|
|
@ -1443,70 +1538,70 @@ async def _stream_agent_events(
|
|||
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
||||
|
||||
if tool_name == "read_file":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Reading file",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "write_file":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Writing file",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "edit_file":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Editing file",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "glob":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Searching files",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "grep":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Searching content",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "rm":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Deleting file",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "rmdir":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Deleting folder",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "mkdir":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Creating folder",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "move_file":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Moving file",
|
||||
status="completed",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "write_todos":
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Planning tasks",
|
||||
status="completed",
|
||||
|
|
@ -1523,7 +1618,7 @@ async def _stream_agent_events(
|
|||
*last_active_step_items,
|
||||
result_str[:80] if is_error else "Saved to knowledge base",
|
||||
]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Saving document",
|
||||
status="completed",
|
||||
|
|
@ -1542,7 +1637,7 @@ async def _stream_agent_events(
|
|||
else "Generation failed"
|
||||
)
|
||||
completed_items = [*last_active_step_items, f"Error: {error_msg}"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Generating image",
|
||||
status="completed",
|
||||
|
|
@ -1566,7 +1661,7 @@ async def _stream_agent_events(
|
|||
]
|
||||
else:
|
||||
completed_items = [*last_active_step_items, "Content extracted"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Scraping webpage",
|
||||
status="completed",
|
||||
|
|
@ -1612,7 +1707,7 @@ async def _stream_agent_events(
|
|||
]
|
||||
else:
|
||||
completed_items = last_active_step_items
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Generating podcast",
|
||||
status="completed",
|
||||
|
|
@ -1647,7 +1742,7 @@ async def _stream_agent_events(
|
|||
]
|
||||
else:
|
||||
completed_items = last_active_step_items
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Generating video presentation",
|
||||
status="completed",
|
||||
|
|
@ -1695,7 +1790,7 @@ async def _stream_agent_events(
|
|||
else:
|
||||
completed_items = last_active_step_items
|
||||
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title=step_title,
|
||||
status="completed",
|
||||
|
|
@ -1721,7 +1816,7 @@ async def _stream_agent_events(
|
|||
]
|
||||
else:
|
||||
completed_items = [*last_active_step_items, "Finished"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Running command",
|
||||
status="completed",
|
||||
|
|
@ -1761,7 +1856,7 @@ async def _stream_agent_events(
|
|||
completed_items.append(f"(+{len(file_names) - 4} more)")
|
||||
else:
|
||||
completed_items = ["No files found"]
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title="Listing files",
|
||||
status="completed",
|
||||
|
|
@ -1773,7 +1868,7 @@ async def _stream_agent_events(
|
|||
fallback_title = (
|
||||
tool_name.replace("_", " ").strip().capitalize() or tool_name
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=original_step_id,
|
||||
title=fallback_title,
|
||||
status="completed",
|
||||
|
|
@ -2113,7 +2208,7 @@ async def _stream_agent_events(
|
|||
# Phase transitions: replace everything after topic
|
||||
last_active_step_items = [*topic_items, message]
|
||||
|
||||
yield streaming_service.format_thinking_step(
|
||||
yield _emit_thinking_step(
|
||||
step_id=last_active_step_id,
|
||||
title=last_active_step_title,
|
||||
status="in_progress",
|
||||
|
|
@ -2155,10 +2250,14 @@ async def _stream_agent_events(
|
|||
elif event_type in ("on_chain_end", "on_agent_end"):
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_end(current_text_id)
|
||||
current_text_id = None
|
||||
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_end(current_text_id)
|
||||
|
||||
completion_event = complete_current_step()
|
||||
if completion_event:
|
||||
|
|
@ -2243,8 +2342,14 @@ async def _stream_agent_events(
|
|||
)
|
||||
gate_text_id = streaming_service.generate_text_id()
|
||||
yield streaming_service.format_text_start(gate_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_start(gate_text_id)
|
||||
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_delta(gate_text_id, gate_notice)
|
||||
yield streaming_service.format_text_end(gate_text_id)
|
||||
if content_builder is not None:
|
||||
content_builder.on_text_end(gate_text_id)
|
||||
yield streaming_service.format_terminal_info(gate_notice, "error")
|
||||
accumulated_text = gate_notice
|
||||
else:
|
||||
|
|
@ -2270,6 +2375,7 @@ async def stream_new_chat(
|
|||
llm_config_id: int = -1,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
needs_history_bootstrap: bool = False,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
|
|
@ -2949,6 +3055,96 @@ async def stream_new_chat(
|
|||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
# Persist the user-side row for this turn before any expensive
|
||||
# work runs. Closes the "ghost-thread" abuse vector
|
||||
# (authenticated client hits POST /new_chat then never calls
|
||||
# /messages — empty new_chat_messages, free LLM completion).
|
||||
# Idempotent against the unique index in migration 141 so the
|
||||
# legacy frontend appendMessage call is a no-op on the second
|
||||
# writer. Hard failure aborts the turn so we never produce a
|
||||
# title or assistant row that isn't anchored to a persisted
|
||||
# user message.
|
||||
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||
from app.tasks.chat.persistence import (
|
||||
persist_assistant_shell,
|
||||
persist_user_turn,
|
||||
)
|
||||
|
||||
user_message_id = await persist_user_turn(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
user_query=user_query,
|
||||
user_image_data_urls=user_image_data_urls,
|
||||
mentioned_documents=mentioned_documents,
|
||||
)
|
||||
if user_message_id is None:
|
||||
yield _emit_stream_error(
|
||||
message=(
|
||||
"We couldn't save your message. Please try again in a moment."
|
||||
),
|
||||
error_kind="server_error",
|
||||
error_code="MESSAGE_PERSIST_FAILED",
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Emit canonical user message id BEFORE any LLM streaming so the
|
||||
# FE can rename its optimistic ``msg-user-XXX`` placeholder to
|
||||
# ``msg-{user_message_id}`` and unlock features gated on a real
|
||||
# DB id (comments, edit-from-this-message). See B4 in
|
||||
# ``sse-based_message_id_handshake`` plan.
|
||||
yield streaming_service.format_data(
|
||||
"user-message-id",
|
||||
{"message_id": user_message_id, "turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
# Pre-write the assistant row for this turn so we have a stable
|
||||
# ``message_id`` to anchor mid-stream metadata (token_usage,
|
||||
# future agent_action_log.message_id correlation) and a
|
||||
# write-once UPDATE target at finalize time. Idempotent against
|
||||
# the (thread_id, turn_id, ASSISTANT) partial unique index from
|
||||
# migration 141 — if the legacy frontend appendMessage races
|
||||
# this, we recover the existing row's id.
|
||||
assistant_message_id = await persist_assistant_shell(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
if assistant_message_id is None:
|
||||
# Genuine DB failure — abort the turn rather than stream
|
||||
# into a void. The user row is already persisted so the
|
||||
# legacy "ghost-thread" gate isn't reopened.
|
||||
yield _emit_stream_error(
|
||||
message=(
|
||||
"We couldn't initialize the assistant message. Please try again."
|
||||
),
|
||||
error_kind="server_error",
|
||||
error_code="MESSAGE_PERSIST_FAILED",
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Emit canonical assistant message id BEFORE any LLM streaming
|
||||
# so the FE can rename its optimistic ``msg-assistant-XXX``
|
||||
# placeholder to ``msg-{assistant_message_id}`` and bind
|
||||
# ``tokenUsageStore`` / ``pendingInterrupt`` to the real id
|
||||
# immediately. See B4 in ``sse-based_message_id_handshake``
|
||||
# plan.
|
||||
yield streaming_service.format_data(
|
||||
"assistant-message-id",
|
||||
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
stream_result.assistant_message_id = assistant_message_id
|
||||
stream_result.content_builder = AssistantContentBuilder()
|
||||
|
||||
# Initial thinking step - analyzing the request
|
||||
if mentioned_surfsense_docs:
|
||||
initial_title = "Analyzing referenced content"
|
||||
|
|
@ -2981,6 +3177,15 @@ async def stream_new_chat(
|
|||
initial_items = [f"{action_verb}: {' '.join(processing_parts)}"]
|
||||
initial_step_id = "thinking-1"
|
||||
|
||||
# Drive the builder for this initial thinking step too — the
|
||||
# ``_emit_thinking_step`` helper lives inside ``_stream_agent_events``
|
||||
# so it isn't in scope here, but the FE folds this step into
|
||||
# the same singleton ``data-thinking-steps`` part as everything
|
||||
# the agent stream emits later. Mirror that fold server-side.
|
||||
if stream_result.content_builder is not None:
|
||||
stream_result.content_builder.on_thinking_step(
|
||||
initial_step_id, initial_title, "in_progress", initial_items
|
||||
)
|
||||
yield streaming_service.format_thinking_step(
|
||||
step_id=initial_step_id,
|
||||
title=initial_title,
|
||||
|
|
@ -2997,16 +3202,34 @@ async def stream_new_chat(
|
|||
# Check if this is the first assistant response so we can generate
|
||||
# a title in parallel with the agent stream (better UX than waiting
|
||||
# until after the full response).
|
||||
assistant_count_result = await session.execute(
|
||||
select(func.count(NewChatMessage.id)).filter(
|
||||
# Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because
|
||||
# this is now a hot path executed on every turn, and COUNT scales
|
||||
# with thread length (server-side persistence can grow rows
|
||||
# quickly under power users).
|
||||
#
|
||||
# IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already
|
||||
# inserted THIS turn's assistant row. We must therefore exclude
|
||||
# it from the probe — otherwise the gate fires on every turn
|
||||
# except the very first, and title generation never runs for new
|
||||
# threads. Excluding by primary key (``id != assistant_message_id``)
|
||||
# is bulletproof regardless of ``turn_id`` shape (legacy NULLs,
|
||||
# resume turns, etc.).
|
||||
first_assistant_probe = await session.execute(
|
||||
select(NewChatMessage.id)
|
||||
.filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
NewChatMessage.role == "assistant",
|
||||
NewChatMessage.id != assistant_message_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
is_first_response = (assistant_count_result.scalar() or 0) == 0
|
||||
is_first_response = first_assistant_probe.scalars().first() is None
|
||||
|
||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
|
||||
if is_first_response:
|
||||
# Gate title generation on a persisted user message so a stream
|
||||
# that fails before persistence (we abort above) can never leave
|
||||
# behind a thread with a generated title and no anchoring rows.
|
||||
if is_first_response and user_message_id is not None:
|
||||
|
||||
async def _generate_title() -> tuple[str | None, dict | None]:
|
||||
"""Generate a short title via litellm.acompletion.
|
||||
|
|
@ -3138,6 +3361,7 @@ async def stream_new_chat(
|
|||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
runtime_context=runtime_context,
|
||||
content_builder=stream_result.content_builder,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
|
|
@ -3493,6 +3717,81 @@ async def stream_new_chat(
|
|||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
# Server-side assistant-message + token_usage finalization.
|
||||
# Runs after the main session has been closed (uses its own
|
||||
# shielded session) so we don't fight the same DB connection.
|
||||
# Idempotent against the legacy frontend appendMessage:
|
||||
# * the assistant row was already INSERTed by
|
||||
# ``persist_assistant_shell`` above, so this just UPDATEs
|
||||
# it with the rich ContentPart[] from the builder.
|
||||
# * token_usage uses INSERT ... ON CONFLICT DO NOTHING
|
||||
# against migration 142's partial unique index, so a
|
||||
# racing append_message recovery branch can never
|
||||
# double-write.
|
||||
# ``mark_interrupted`` closes any open text/reasoning blocks
|
||||
# and flips running tool-calls (no result) to state=aborted
|
||||
# so the persisted JSONB reflects a coherent end-state even
|
||||
# on client disconnect.
|
||||
# Never raises (best-effort, logs only).
|
||||
if (
|
||||
stream_result
|
||||
and stream_result.turn_id
|
||||
and stream_result.assistant_message_id
|
||||
):
|
||||
from app.tasks.chat.persistence import finalize_assistant_turn
|
||||
|
||||
builder_stats: dict[str, int] | None = None
|
||||
if stream_result.content_builder is not None:
|
||||
stream_result.content_builder.mark_interrupted()
|
||||
# Snapshot stats BEFORE deepcopy in ``snapshot()`` so
|
||||
# the perf log records the actual finalised payload
|
||||
# (post-mark_interrupted), not the live-mutating
|
||||
# builder state.
|
||||
builder_stats = stream_result.content_builder.stats()
|
||||
content_payload = stream_result.content_builder.snapshot()
|
||||
else:
|
||||
# Defensive fallback — we always set the builder
|
||||
# alongside ``assistant_message_id`` above, so this
|
||||
# branch only fires if a future refactor ever
|
||||
# decouples them. Persist whatever accumulated
|
||||
# text we captured so the row at least renders.
|
||||
content_payload = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": stream_result.accumulated_text or "",
|
||||
}
|
||||
]
|
||||
|
||||
if builder_stats is not None:
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] finalize_payload chat_id=%s "
|
||||
"message_id=%s parts=%d bytes=%d text=%d "
|
||||
"reasoning=%d tool_calls=%d "
|
||||
"tool_calls_completed=%d tool_calls_aborted=%d "
|
||||
"thinking_step_parts=%d step_separators=%d",
|
||||
chat_id,
|
||||
stream_result.assistant_message_id,
|
||||
builder_stats["parts"],
|
||||
builder_stats["bytes"],
|
||||
builder_stats["text"],
|
||||
builder_stats["reasoning"],
|
||||
builder_stats["tool_calls"],
|
||||
builder_stats["tool_calls_completed"],
|
||||
builder_stats["tool_calls_aborted"],
|
||||
builder_stats["thinking_step_parts"],
|
||||
builder_stats["step_separators"],
|
||||
)
|
||||
|
||||
await finalize_assistant_turn(
|
||||
message_id=stream_result.assistant_message_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
content=content_payload,
|
||||
accumulator=accumulator,
|
||||
)
|
||||
|
||||
# Persist any sandbox-produced files to local storage so they
|
||||
# remain downloadable after the Daytona sandbox auto-deletes.
|
||||
if stream_result and stream_result.sandbox_files:
|
||||
|
|
@ -3937,6 +4236,50 @@ async def stream_resume_chat(
|
|||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
# Pre-write a fresh assistant row for this resume turn. The
|
||||
# original (interrupted) ``stream_new_chat`` invocation already
|
||||
# persisted its own assistant row anchored to a different
|
||||
# ``turn_id``; resume allocates a new ``turn_id`` (above) so we
|
||||
# need a separate row keyed on the same ``(thread_id, turn_id,
|
||||
# ASSISTANT)`` invariant. Idempotent against migration 141's
|
||||
# partial unique index — recovers existing id on retry.
|
||||
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||
from app.tasks.chat.persistence import persist_assistant_shell
|
||||
|
||||
assistant_message_id = await persist_assistant_shell(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
if assistant_message_id is None:
|
||||
yield _emit_stream_error(
|
||||
message=(
|
||||
"We couldn't initialize the assistant message. Please try again."
|
||||
),
|
||||
error_kind="server_error",
|
||||
error_code="MESSAGE_PERSIST_FAILED",
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Emit canonical assistant message id BEFORE any LLM streaming
|
||||
# so the FE can rename ``pendingInterrupt.assistantMsgId`` to
|
||||
# ``msg-{assistant_message_id}`` immediately. Resume does NOT
|
||||
# emit ``data-user-message-id`` because the user row is from
|
||||
# the original interrupted turn (different ``turn_id``) and is
|
||||
# never re-persisted here. See B5 in the
|
||||
# ``sse-based_message_id_handshake`` plan.
|
||||
yield streaming_service.format_data(
|
||||
"assistant-message-id",
|
||||
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
|
||||
)
|
||||
|
||||
stream_result.assistant_message_id = assistant_message_id
|
||||
stream_result.content_builder = AssistantContentBuilder()
|
||||
|
||||
# Resume path doesn't carry new ``mentioned_document_ids`` —
|
||||
# those are seeded in the original turn. We still pass a
|
||||
# context so future middleware extensions (Phase 2) can rely on
|
||||
|
|
@ -3968,6 +4311,7 @@ async def stream_resume_chat(
|
|||
),
|
||||
fallback_commit_thread_id=chat_id,
|
||||
runtime_context=runtime_context,
|
||||
content_builder=stream_result.content_builder,
|
||||
):
|
||||
if not _first_event_logged:
|
||||
_perf_log.info(
|
||||
|
|
@ -4219,6 +4563,64 @@ async def stream_resume_chat(
|
|||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
# Server-side assistant-message + token_usage finalization for
|
||||
# the resume flow. The original user message was persisted by
|
||||
# the original (interrupted) ``stream_new_chat`` invocation;
|
||||
# the resume's own ``persist_assistant_shell`` write lives at
|
||||
# the new ``turn_id`` above. This finalize updates that row
|
||||
# with the rich ContentPart[] from the builder and writes
|
||||
# token_usage idempotently via migration 142's partial
|
||||
# unique index. Best-effort, never raises.
|
||||
if (
|
||||
stream_result
|
||||
and stream_result.turn_id
|
||||
and stream_result.assistant_message_id
|
||||
):
|
||||
from app.tasks.chat.persistence import finalize_assistant_turn
|
||||
|
||||
builder_stats: dict[str, int] | None = None
|
||||
if stream_result.content_builder is not None:
|
||||
stream_result.content_builder.mark_interrupted()
|
||||
builder_stats = stream_result.content_builder.stats()
|
||||
content_payload = stream_result.content_builder.snapshot()
|
||||
else:
|
||||
content_payload = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": stream_result.accumulated_text or "",
|
||||
}
|
||||
]
|
||||
|
||||
if builder_stats is not None:
|
||||
_perf_log.info(
|
||||
"[stream_resume] finalize_payload chat_id=%s "
|
||||
"message_id=%s parts=%d bytes=%d text=%d "
|
||||
"reasoning=%d tool_calls=%d "
|
||||
"tool_calls_completed=%d tool_calls_aborted=%d "
|
||||
"thinking_step_parts=%d step_separators=%d",
|
||||
chat_id,
|
||||
stream_result.assistant_message_id,
|
||||
builder_stats["parts"],
|
||||
builder_stats["bytes"],
|
||||
builder_stats["text"],
|
||||
builder_stats["reasoning"],
|
||||
builder_stats["tool_calls"],
|
||||
builder_stats["tool_calls_completed"],
|
||||
builder_stats["tool_calls_aborted"],
|
||||
builder_stats["thinking_step_parts"],
|
||||
builder_stats["step_separators"],
|
||||
)
|
||||
|
||||
await finalize_assistant_turn(
|
||||
message_id=stream_result.assistant_message_id,
|
||||
chat_id=chat_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
content=content_payload,
|
||||
accumulator=accumulator,
|
||||
)
|
||||
|
||||
agent = llm = connector_service = None
|
||||
stream_result = None
|
||||
session = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue