mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 22:02:39 +02:00
Merge pull request #1341 from MODSetter/dev
feat: moved chat persistance to Server Side
This commit is contained in:
commit
743eff42cd
19 changed files with 4515 additions and 390 deletions
31
.vscode/launch.json
vendored
31
.vscode/launch.json
vendored
|
|
@ -26,7 +26,16 @@
|
||||||
"pythonArgs": [
|
"pythonArgs": [
|
||||||
"run",
|
"run",
|
||||||
"python"
|
"python"
|
||||||
]
|
],
|
||||||
|
// Mute LangGraph/Pydantic checkpoint serializer warnings
|
||||||
|
// (UserWarnings emitted from pydantic/main.py when the
|
||||||
|
// runtime snapshots a SurfSenseContextSchema into a field
|
||||||
|
// typed `None`) so the debugger's "Raised Exceptions"
|
||||||
|
// breakpoint doesn't pause on a known-harmless event.
|
||||||
|
// Production logs are unaffected.
|
||||||
|
"env": {
|
||||||
|
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Backend: FastAPI (No Reload)",
|
"name": "Backend: FastAPI (No Reload)",
|
||||||
|
|
@ -40,7 +49,10 @@
|
||||||
"pythonArgs": [
|
"pythonArgs": [
|
||||||
"run",
|
"run",
|
||||||
"python"
|
"python"
|
||||||
]
|
],
|
||||||
|
"env": {
|
||||||
|
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Backend: FastAPI (main.py)",
|
"name": "Backend: FastAPI (main.py)",
|
||||||
|
|
@ -54,7 +66,10 @@
|
||||||
"pythonArgs": [
|
"pythonArgs": [
|
||||||
"run",
|
"run",
|
||||||
"python"
|
"python"
|
||||||
]
|
],
|
||||||
|
"env": {
|
||||||
|
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Frontend: Next.js",
|
"name": "Frontend: Next.js",
|
||||||
|
|
@ -104,7 +119,10 @@
|
||||||
"pythonArgs": [
|
"pythonArgs": [
|
||||||
"run",
|
"run",
|
||||||
"python"
|
"python"
|
||||||
]
|
],
|
||||||
|
"env": {
|
||||||
|
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Celery: Beat Scheduler",
|
"name": "Celery: Beat Scheduler",
|
||||||
|
|
@ -124,7 +142,10 @@
|
||||||
"pythonArgs": [
|
"pythonArgs": [
|
||||||
"run",
|
"run",
|
||||||
"python"
|
"python"
|
||||||
]
|
],
|
||||||
|
"env": {
|
||||||
|
"PYTHONWARNINGS": "ignore::UserWarning:pydantic.main"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"compounds": [
|
"compounds": [
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
"""141_unique_chat_message_turn_role
|
||||||
|
|
||||||
|
Revision ID: 141
|
||||||
|
Revises: 140
|
||||||
|
Create Date: 2026-05-04
|
||||||
|
|
||||||
|
Add a partial unique index on ``new_chat_messages(thread_id, turn_id, role)``
|
||||||
|
where ``turn_id IS NOT NULL``.
|
||||||
|
|
||||||
|
Why
|
||||||
|
---
|
||||||
|
The streaming chat path (`stream_new_chat` / `stream_resume_chat`) is being
|
||||||
|
moved to write its own ``new_chat_messages`` rows server-side instead of
|
||||||
|
relying on the frontend's later ``POST /threads/{id}/messages`` call. This
|
||||||
|
closes the "ghost-thread" abuse vector where authenticated callers got free
|
||||||
|
LLM completions while ``new_chat_messages`` stayed empty.
|
||||||
|
|
||||||
|
For server-side and legacy frontend writes to coexist we need an idempotency
|
||||||
|
key. The natural triple is ``(thread_id, turn_id, role)``: the server issues
|
||||||
|
exactly one ``turn_id`` per turn, and a turn produces at most one user
|
||||||
|
message and one assistant message. Whichever side wins the race writes the
|
||||||
|
row; the loser hits ``IntegrityError`` and recovers gracefully.
|
||||||
|
|
||||||
|
Partial — ``WHERE turn_id IS NOT NULL`` — so:
|
||||||
|
|
||||||
|
* Legacy rows that predate the ``turn_id`` column (migration 136) keep
|
||||||
|
co-existing without de-dup.
|
||||||
|
* Clone / snapshot inserts in
|
||||||
|
``app/services/public_chat_service.py`` that build ``NewChatMessage``
|
||||||
|
without ``turn_id`` are unaffected (multiple snapshot copies of the same
|
||||||
|
user/assistant pair are intentional).
|
||||||
|
|
||||||
|
This index coexists with the existing single-column ``ix_new_chat_messages_turn_id``
|
||||||
|
from migration 136 — no collision.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "141"
|
||||||
|
down_revision: str | None = "140"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
INDEX_NAME = "uq_new_chat_messages_thread_turn_role"
|
||||||
|
TABLE_NAME = "new_chat_messages"
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_index(
|
||||||
|
INDEX_NAME,
|
||||||
|
TABLE_NAME,
|
||||||
|
["thread_id", "turn_id", "role"],
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=sa.text("turn_id IS NOT NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(INDEX_NAME, table_name=TABLE_NAME)
|
||||||
|
|
@ -0,0 +1,134 @@
|
||||||
|
"""142_token_usage_message_id_unique
|
||||||
|
|
||||||
|
Revision ID: 142
|
||||||
|
Revises: 141
|
||||||
|
Create Date: 2026-05-04
|
||||||
|
|
||||||
|
Add a partial unique index on ``token_usage(message_id)`` where
|
||||||
|
``message_id IS NOT NULL``.
|
||||||
|
|
||||||
|
Why
|
||||||
|
---
|
||||||
|
Two writers can race on the same assistant turn's ``token_usage`` row:
|
||||||
|
|
||||||
|
* ``finalize_assistant_turn`` (server-side, called from the streaming
|
||||||
|
finally block in ``stream_new_chat`` / ``stream_resume_chat``)
|
||||||
|
* ``append_message``'s recovery branch in
|
||||||
|
``app/routes/new_chat_routes.py`` (legacy frontend round-trip)
|
||||||
|
|
||||||
|
Both currently use ``SELECT ... THEN INSERT`` in separate sessions, so a
|
||||||
|
micro-second-aligned race could observe "no row" on each side and double
|
||||||
|
INSERT, producing duplicate ``token_usage`` rows for the same
|
||||||
|
``message_id``.
|
||||||
|
|
||||||
|
A partial unique index on ``message_id`` (``WHERE message_id IS NOT NULL``)
|
||||||
|
turns both writes into ``INSERT ... ON CONFLICT (message_id) DO NOTHING``
|
||||||
|
no-ops for the loser, hard-eliminating the race at the DB level. Partial
|
||||||
|
because non-chat usage rows (indexing, image generation, podcasts) keep
|
||||||
|
``message_id`` NULL — they're per-event, no de-dup needed.
|
||||||
|
|
||||||
|
Pre-flight
|
||||||
|
----------
|
||||||
|
Today's schema only has a non-unique index on ``message_id`` so a
|
||||||
|
duplicate population could already exist from any past race. We:
|
||||||
|
|
||||||
|
* Detect duplicate ``message_id`` groups (``HAVING COUNT(*) > 1``).
|
||||||
|
* If the group count is at or below ``DUPLICATE_ABORT_THRESHOLD`` (50)
|
||||||
|
we dedupe by deleting all but the smallest ``id`` per group.
|
||||||
|
* If the count exceeds the threshold we abort with a descriptive
|
||||||
|
error rather than silently mutate prod data — operator must
|
||||||
|
investigate before retrying.
|
||||||
|
|
||||||
|
Concurrency
|
||||||
|
-----------
|
||||||
|
``CREATE INDEX CONCURRENTLY`` is required on this hot table to avoid
|
||||||
|
stalling production writes during deploy (a regular ``CREATE INDEX``
|
||||||
|
holds an ACCESS EXCLUSIVE lock for the duration of the build, which
|
||||||
|
would block ``token_usage`` INSERTs for every active streaming chat).
|
||||||
|
The trade-off is a slower migration (CONCURRENTLY scans the table
|
||||||
|
twice) and the ``CREATE`` statement cannot run inside alembic's default
|
||||||
|
transaction wrapper — ``autocommit_block()`` handles that.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "142"
|
||||||
|
down_revision: str | None = "141"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
INDEX_NAME = "uq_token_usage_message_id"
|
||||||
|
TABLE_NAME = "token_usage"
|
||||||
|
|
||||||
|
# Refuse to silently mutate prod data if the duplicate population is
|
||||||
|
# unexpectedly large — operator should investigate the upstream cause
|
||||||
|
# before retrying. 50 is comfortably above any plausible duplicate
|
||||||
|
# count from the existing race window (the race is microseconds wide).
|
||||||
|
DUPLICATE_ABORT_THRESHOLD = 50
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
dup_groups = conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT message_id, COUNT(*) AS n "
|
||||||
|
"FROM token_usage "
|
||||||
|
"WHERE message_id IS NOT NULL "
|
||||||
|
"GROUP BY message_id "
|
||||||
|
"HAVING COUNT(*) > 1"
|
||||||
|
)
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
if len(dup_groups) > DUPLICATE_ABORT_THRESHOLD:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"token_usage has {len(dup_groups)} duplicate message_id groups "
|
||||||
|
f"(threshold={DUPLICATE_ABORT_THRESHOLD}). "
|
||||||
|
"Resolve the duplicates manually before re-running this migration."
|
||||||
|
)
|
||||||
|
|
||||||
|
if dup_groups:
|
||||||
|
# Delete all but the smallest-id row per duplicate group. The
|
||||||
|
# smallest id is by definition the earliest insert, so we keep
|
||||||
|
# the row most likely to reflect the actual stream's first
|
||||||
|
# successful write.
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"""
|
||||||
|
DELETE FROM token_usage
|
||||||
|
WHERE id IN (
|
||||||
|
SELECT id FROM (
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
row_number() OVER (
|
||||||
|
PARTITION BY message_id ORDER BY id ASC
|
||||||
|
) AS rn
|
||||||
|
FROM token_usage
|
||||||
|
WHERE message_id IS NOT NULL
|
||||||
|
) ranked
|
||||||
|
WHERE rn > 1
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# CREATE INDEX CONCURRENTLY cannot run inside a transaction. Drop
|
||||||
|
# alembic's auto-transaction for this op only.
|
||||||
|
with op.get_context().autocommit_block():
|
||||||
|
op.execute(
|
||||||
|
f"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS {INDEX_NAME} "
|
||||||
|
f"ON {TABLE_NAME} (message_id) "
|
||||||
|
"WHERE message_id IS NOT NULL"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
with op.get_context().autocommit_block():
|
||||||
|
op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {INDEX_NAME}")
|
||||||
|
|
@ -675,6 +675,23 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
||||||
|
|
||||||
__tablename__ = "new_chat_messages"
|
__tablename__ = "new_chat_messages"
|
||||||
|
|
||||||
|
# Partial unique index on (thread_id, turn_id, role) where turn_id IS NOT NULL.
|
||||||
|
# Mirrors alembic migration 141. Lets the streaming agent and the
|
||||||
|
# legacy frontend appendMessage call coexist idempotently — the second
|
||||||
|
# writer trips the unique and recovers without creating a duplicate row.
|
||||||
|
# Partial so legacy NULL turn_id rows and clone/snapshot inserts in
|
||||||
|
# app/services/public_chat_service.py (which omit turn_id) are unaffected.
|
||||||
|
__table_args__ = (
|
||||||
|
Index(
|
||||||
|
"uq_new_chat_messages_thread_turn_role",
|
||||||
|
"thread_id",
|
||||||
|
"turn_id",
|
||||||
|
"role",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=text("turn_id IS NOT NULL"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
|
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
|
||||||
# Content stored as JSONB to support rich content (text, tool calls, etc.)
|
# Content stored as JSONB to support rich content (text, tool calls, etc.)
|
||||||
content = Column(JSONB, nullable=False)
|
content = Column(JSONB, nullable=False)
|
||||||
|
|
@ -728,6 +745,22 @@ class TokenUsage(BaseModel, TimestampMixin):
|
||||||
|
|
||||||
__tablename__ = "token_usage"
|
__tablename__ = "token_usage"
|
||||||
|
|
||||||
|
# Partial unique index on (message_id) where message_id IS NOT NULL.
|
||||||
|
# Mirrors alembic migration 142. Lets the streaming agent's
|
||||||
|
# ``finalize_assistant_turn`` and the legacy frontend ``append_message``
|
||||||
|
# recovery branch both use ``INSERT ... ON CONFLICT DO NOTHING`` without
|
||||||
|
# racing on a SELECT-then-INSERT window. Partial so non-chat usage rows
|
||||||
|
# (indexing, image generation, podcasts) — which keep ``message_id`` NULL
|
||||||
|
# because there is no per-message anchor — are unaffected.
|
||||||
|
__table_args__ = (
|
||||||
|
Index(
|
||||||
|
"uq_token_usage_message_id",
|
||||||
|
"message_id",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=text("message_id IS NOT NULL"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
prompt_tokens = Column(Integer, nullable=False, default=0)
|
prompt_tokens = Column(Integer, nullable=False, default=0)
|
||||||
completion_tokens = Column(Integer, nullable=False, default=0)
|
completion_tokens = Column(Integer, nullable=False, default=0)
|
||||||
total_tokens = Column(Integer, nullable=False, default=0)
|
total_tokens = Column(Integer, nullable=False, default=0)
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,8 @@ from datetime import UTC, datetime
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func, or_, text as sa_text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
@ -44,6 +45,7 @@ from app.db import (
|
||||||
NewChatThread,
|
NewChatThread,
|
||||||
Permission,
|
Permission,
|
||||||
SearchSpace,
|
SearchSpace,
|
||||||
|
TokenUsage,
|
||||||
User,
|
User,
|
||||||
get_async_session,
|
get_async_session,
|
||||||
shielded_async_session,
|
shielded_async_session,
|
||||||
|
|
@ -69,9 +71,9 @@ from app.schemas.new_chat import (
|
||||||
TokenUsageSummary,
|
TokenUsageSummary,
|
||||||
TurnStatusResponse,
|
TurnStatusResponse,
|
||||||
)
|
)
|
||||||
from app.services.token_tracking_service import record_token_usage
|
|
||||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
from app.utils.rbac import check_permission
|
from app.utils.rbac import check_permission
|
||||||
from app.utils.user_message_multimodal import (
|
from app.utils.user_message_multimodal import (
|
||||||
split_langchain_human_content,
|
split_langchain_human_content,
|
||||||
|
|
@ -79,6 +81,7 @@ from app.utils.user_message_multimodal import (
|
||||||
)
|
)
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
_background_tasks: set[asyncio.Task] = set()
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||||
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||||
|
|
@ -1287,6 +1290,24 @@ async def append_message(
|
||||||
user: User = Depends(current_active_user),
|
user: User = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
.. deprecated:: 2026-05
|
||||||
|
Replaced by the **SSE-based message ID handshake**. The streaming
|
||||||
|
generator (`stream_new_chat` / `stream_resume_chat`) now persists
|
||||||
|
both the user and assistant rows server-side via
|
||||||
|
``persist_user_turn`` / ``persist_assistant_shell`` and emits
|
||||||
|
``data-user-message-id`` / ``data-assistant-message-id`` SSE events
|
||||||
|
so the frontend can rename its optimistic IDs in real time. The
|
||||||
|
new FE bundle no longer calls this route.
|
||||||
|
|
||||||
|
This handler is retained as a **silent no-op for legacy / cached
|
||||||
|
FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING``
|
||||||
|
pattern means a stale bundle hitting this route after the SSE
|
||||||
|
handshake already wrote the row simply returns the existing row
|
||||||
|
(200 OK) without raising or duplicating data. After a 2-week soak
|
||||||
|
(target: ``[persist_user_turn] outcome=race_recovered`` rate ~0)
|
||||||
|
this entire route — and the FE ``appendMessage`` function — is
|
||||||
|
earmarked for removal.
|
||||||
|
|
||||||
Append a message to a thread.
|
Append a message to a thread.
|
||||||
This is used by ThreadHistoryAdapter.append() to persist messages.
|
This is used by ThreadHistoryAdapter.append() to persist messages.
|
||||||
|
|
||||||
|
|
@ -1297,6 +1318,22 @@ async def append_message(
|
||||||
Requires CHATS_UPDATE permission.
|
Requires CHATS_UPDATE permission.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Capture ``user.id`` as a primitive UUID up front. The
|
||||||
|
# ``current_active_user`` dependency hands us a ``User`` ORM
|
||||||
|
# row bound to ``session``; if the outer ``except
|
||||||
|
# IntegrityError`` block below ever fires (an unexpected
|
||||||
|
# constraint like a foreign key violation — the common
|
||||||
|
# ``(thread_id, turn_id, role)`` race is now handled silently
|
||||||
|
# by ``ON CONFLICT DO NOTHING`` so it never raises) it calls
|
||||||
|
# ``session.rollback()``, which expires every attached ORM
|
||||||
|
# row including this user. Any later ``user.id`` access would
|
||||||
|
# then trigger a lazy PK reload — which on async SQLAlchemy
|
||||||
|
# fails with ``MissingGreenlet`` because the reload happens
|
||||||
|
# outside the awaitable greenlet boundary. Reading ``id``
|
||||||
|
# once here pins the value as a plain UUID so all downstream
|
||||||
|
# uses (TokenUsage insert, response build) are immune.
|
||||||
|
user_uuid = user.id
|
||||||
|
|
||||||
# Parse raw body - extract only role and content, ignoring extra fields
|
# Parse raw body - extract only role and content, ignoring extra fields
|
||||||
raw_body = await request.json()
|
raw_body = await request.json()
|
||||||
role = raw_body.get("role")
|
role = raw_body.get("role")
|
||||||
|
|
@ -1351,42 +1388,166 @@ async def append_message(
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
db_message = NewChatMessage(
|
# Update thread's updated_at timestamp (always — both insert
|
||||||
thread_id=thread_id,
|
# and recovery paths represent thread activity).
|
||||||
role=message_role,
|
|
||||||
content=content,
|
|
||||||
author_id=user.id,
|
|
||||||
turn_id=turn_id_value,
|
|
||||||
)
|
|
||||||
session.add(db_message)
|
|
||||||
|
|
||||||
# Update thread's updated_at timestamp
|
|
||||||
thread.updated_at = datetime.now(UTC)
|
thread.updated_at = datetime.now(UTC)
|
||||||
|
|
||||||
# flush assigns the PK/defaults without a round-trip SELECT
|
# Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING``
|
||||||
await session.flush()
|
# keyed on the ``(thread_id, turn_id, role)`` partial unique
|
||||||
|
# index from migration 141 (``WHERE turn_id IS NOT NULL``).
|
||||||
|
#
|
||||||
|
# Why ON CONFLICT instead of ``session.add() + flush() + except
|
||||||
|
# IntegrityError``:
|
||||||
|
# 1. The conflict between this legacy FE ``appendMessage``
|
||||||
|
# round-trip and the server-side
|
||||||
|
# ``finalize_assistant_turn`` writer is a NORMAL,
|
||||||
|
# *expected* race — every assistant turn fires it. Using
|
||||||
|
# catch-and-recover means asyncpg raises
|
||||||
|
# ``UniqueViolationError`` -> SQLAlchemy wraps it as
|
||||||
|
# ``IntegrityError`` -> our handler catches and recovers.
|
||||||
|
# Functionally fine, but every ``raise`` event lights up
|
||||||
|
# VS Code's debugger (debugpy's ``justMyCode=false`` mode
|
||||||
|
# loses track of the catch frame across SQLAlchemy's
|
||||||
|
# async greenlet boundary, so even ``Raised Exceptions``
|
||||||
|
# being unchecked doesn't reliably suppress the pause).
|
||||||
|
# ON CONFLICT pushes the conflict resolution into Postgres
|
||||||
|
# where no Python exception is constructed at all.
|
||||||
|
# 2. No ``session.rollback()`` -> no expiring of attached
|
||||||
|
# ORM rows -> no risk of ``MissingGreenlet`` from
|
||||||
|
# lazy-loading expired user/thread state later in the
|
||||||
|
# handler.
|
||||||
|
# 3. Cleaner production logs (no SQLAlchemy ``IntegrityError``
|
||||||
|
# tracebacks emitted by uvicorn's logger between the
|
||||||
|
# ``raise`` and our ``except``).
|
||||||
|
#
|
||||||
|
# When ``turn_id_value`` is ``None`` the partial index doesn't
|
||||||
|
# apply and the INSERT proceeds normally. Other constraint
|
||||||
|
# violations (FK, NOT NULL, etc.) still raise ``IntegrityError``
|
||||||
|
# and are caught by the outer ``except IntegrityError`` block
|
||||||
|
# to preserve the legacy 400 behavior.
|
||||||
|
#
|
||||||
|
# Note on ``content``: when we recover the existing row, we
|
||||||
|
# intentionally discard the FE's ``content`` payload from
|
||||||
|
# ``raw_body`` and return the row's existing ``content``. The
|
||||||
|
# streaming task is now the *authoritative writer* for
|
||||||
|
# assistant ``ContentPart[]`` shape (mid-stream
|
||||||
|
# ``AssistantContentBuilder`` -> ``finalize_assistant_turn``)
|
||||||
|
# so the FE's later ``appendMessage`` is just a stale snapshot
|
||||||
|
# of the same data — keeping the server-built rich content
|
||||||
|
# (with full tool-call args / argsText / langchainToolCallId)
|
||||||
|
# is correct, not lossy.
|
||||||
|
insert_stmt = (
|
||||||
|
pg_insert(NewChatMessage)
|
||||||
|
.values(
|
||||||
|
thread_id=thread_id,
|
||||||
|
role=message_role,
|
||||||
|
content=content,
|
||||||
|
author_id=user_uuid,
|
||||||
|
turn_id=turn_id_value,
|
||||||
|
)
|
||||||
|
.on_conflict_do_nothing(
|
||||||
|
index_elements=["thread_id", "turn_id", "role"],
|
||||||
|
index_where=sa_text("turn_id IS NOT NULL"),
|
||||||
|
)
|
||||||
|
.returning(NewChatMessage.id)
|
||||||
|
)
|
||||||
|
inserted_id = (await session.execute(insert_stmt)).scalar()
|
||||||
|
|
||||||
|
if inserted_id is None:
|
||||||
|
# Conflict on partial unique index — server-side stream
|
||||||
|
# already wrote this row. Look it up and reuse it.
|
||||||
|
if turn_id_value is None:
|
||||||
|
# Defensive: ON CONFLICT only fires for ``turn_id IS
|
||||||
|
# NOT NULL`` rows, so this branch should be
|
||||||
|
# unreachable. Preserve the legacy 400 just in case
|
||||||
|
# Postgres ever surprises us.
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Database constraint violation. Please check your input data.",
|
||||||
|
) from None
|
||||||
|
lookup = await session.execute(
|
||||||
|
select(NewChatMessage).filter(
|
||||||
|
NewChatMessage.thread_id == thread_id,
|
||||||
|
NewChatMessage.turn_id == turn_id_value,
|
||||||
|
NewChatMessage.role == message_role,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing_message = lookup.scalars().first()
|
||||||
|
if existing_message is None:
|
||||||
|
# Conflict reported but the row vanished between
|
||||||
|
# INSERT and SELECT — extremely unlikely (would
|
||||||
|
# require a concurrent DELETE within the same
|
||||||
|
# transaction visibility), but preserve safe
|
||||||
|
# behavior.
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Database constraint violation. Please check your input data.",
|
||||||
|
) from None
|
||||||
|
db_message = existing_message
|
||||||
|
# Perf signal: counts how often the legacy FE round-trip
|
||||||
|
# races the server-side ``finalize_assistant_turn``. A
|
||||||
|
# rising rate after the rework is OK (it's exactly the
|
||||||
|
# ghost-thread fix's recovery path firing); a sudden drop
|
||||||
|
# to zero would mean the FE isn't posting appendMessage
|
||||||
|
# at all (different bug).
|
||||||
|
_perf_log.info(
|
||||||
|
"[append_message] outcome=recovered_via_unique_index "
|
||||||
|
"thread_id=%s turn_id=%s role=%s message_id=%s",
|
||||||
|
thread_id,
|
||||||
|
turn_id_value,
|
||||||
|
message_role.value,
|
||||||
|
db_message.id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# INSERT succeeded — load the full ORM row so the
|
||||||
|
# response can include server-side-defaulted columns
|
||||||
|
# (``created_at``, etc.) and the relationship surface
|
||||||
|
# stays consistent with the recovery path.
|
||||||
|
inserted_row = await session.get(NewChatMessage, inserted_id)
|
||||||
|
if inserted_row is None:
|
||||||
|
# Should be impossible: we just inserted it in this
|
||||||
|
# same transaction. Fail loud if it happens.
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Inserted message could not be loaded.",
|
||||||
|
) from None
|
||||||
|
db_message = inserted_row
|
||||||
|
|
||||||
# Persist token usage if provided (for assistant messages).
|
# Persist token usage if provided (for assistant messages).
|
||||||
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
# ``cost_micros`` is the provider USD cost reported by LiteLLM,
|
||||||
# forwarded by the FE through the appendMessage round-trip so
|
# forwarded by the FE through the appendMessage round-trip so
|
||||||
# the historical TokenUsage row matches the credit debit applied
|
# the historical TokenUsage row matches the credit debit applied
|
||||||
# at finalize time.
|
# at finalize time.
|
||||||
|
#
|
||||||
|
# De-dup: ``finalize_assistant_turn`` may also race to write a
|
||||||
|
# token_usage row for this same ``message_id`` (cross-session,
|
||||||
|
# cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed
|
||||||
|
# on the ``uq_token_usage_message_id`` partial unique index
|
||||||
|
# (migration 142). The loser silently drops its insert; exactly
|
||||||
|
# one row results regardless of which writer commits first.
|
||||||
token_usage_data = raw_body.get("token_usage")
|
token_usage_data = raw_body.get("token_usage")
|
||||||
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
|
||||||
await record_token_usage(
|
insert_stmt = (
|
||||||
session,
|
pg_insert(TokenUsage)
|
||||||
usage_type="chat",
|
.values(
|
||||||
search_space_id=thread.search_space_id,
|
usage_type="chat",
|
||||||
user_id=user.id,
|
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
||||||
prompt_tokens=token_usage_data.get("prompt_tokens", 0),
|
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
||||||
completion_tokens=token_usage_data.get("completion_tokens", 0),
|
total_tokens=token_usage_data.get("total_tokens", 0),
|
||||||
total_tokens=token_usage_data.get("total_tokens", 0),
|
cost_micros=token_usage_data.get("cost_micros", 0),
|
||||||
cost_micros=token_usage_data.get("cost_micros", 0),
|
model_breakdown=token_usage_data.get("usage"),
|
||||||
model_breakdown=token_usage_data.get("usage"),
|
call_details=token_usage_data.get("call_details"),
|
||||||
call_details=token_usage_data.get("call_details"),
|
thread_id=thread_id,
|
||||||
thread_id=thread_id,
|
message_id=db_message.id,
|
||||||
message_id=db_message.id,
|
search_space_id=thread.search_space_id,
|
||||||
|
user_id=user_uuid,
|
||||||
|
)
|
||||||
|
.on_conflict_do_nothing(
|
||||||
|
index_elements=["message_id"],
|
||||||
|
index_where=sa_text("message_id IS NOT NULL"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
await session.execute(insert_stmt)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
@ -1406,6 +1567,9 @@ async def append_message(
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
|
# Any IntegrityError that escaped the inline handler above
|
||||||
|
# comes from a *different* constraint (foreign key, etc.) —
|
||||||
|
# preserve the legacy 400 path.
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
|
|
@ -1599,6 +1763,12 @@ async def handle_new_chat(
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mentioned_documents_payload = (
|
||||||
|
[doc.model_dump() for doc in request.mentioned_documents]
|
||||||
|
if request.mentioned_documents
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_new_chat(
|
stream_new_chat(
|
||||||
user_query=request.user_query,
|
user_query=request.user_query,
|
||||||
|
|
@ -1608,6 +1778,7 @@ async def handle_new_chat(
|
||||||
llm_config_id=llm_config_id,
|
llm_config_id=llm_config_id,
|
||||||
mentioned_document_ids=request.mentioned_document_ids,
|
mentioned_document_ids=request.mentioned_document_ids,
|
||||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||||
|
mentioned_documents=mentioned_documents_payload,
|
||||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||||
thread_visibility=thread.visibility,
|
thread_visibility=thread.visibility,
|
||||||
current_user_display_name=user.display_name or "A team member",
|
current_user_display_name=user.display_name or "A team member",
|
||||||
|
|
@ -2078,6 +2249,11 @@ async def regenerate_response(
|
||||||
"data": revert_results,
|
"data": revert_results,
|
||||||
}
|
}
|
||||||
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
|
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
|
||||||
|
mentioned_documents_payload = (
|
||||||
|
[doc.model_dump() for doc in request.mentioned_documents]
|
||||||
|
if request.mentioned_documents
|
||||||
|
else None
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
async for chunk in stream_new_chat(
|
async for chunk in stream_new_chat(
|
||||||
user_query=str(user_query_to_use),
|
user_query=str(user_query_to_use),
|
||||||
|
|
@ -2087,6 +2263,7 @@ async def regenerate_response(
|
||||||
llm_config_id=llm_config_id,
|
llm_config_id=llm_config_id,
|
||||||
mentioned_document_ids=request.mentioned_document_ids,
|
mentioned_document_ids=request.mentioned_document_ids,
|
||||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||||
|
mentioned_documents=mentioned_documents_payload,
|
||||||
checkpoint_id=target_checkpoint_id,
|
checkpoint_id=target_checkpoint_id,
|
||||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||||
thread_visibility=thread.visibility,
|
thread_visibility=thread.visibility,
|
||||||
|
|
|
||||||
|
|
@ -200,6 +200,21 @@ class NewChatUserImagePart(BaseModel):
|
||||||
return to_data_url(self.media_type, self.data)
|
return to_data_url(self.media_type, self.data)
|
||||||
|
|
||||||
|
|
||||||
|
class MentionedDocumentInfo(BaseModel):
|
||||||
|
"""Display metadata for a single ``@``-mentioned document.
|
||||||
|
|
||||||
|
The full triple ``{id, title, document_type}`` is forwarded by the
|
||||||
|
frontend mention chip so the server can embed it in the persisted
|
||||||
|
user message ``ContentPart[]`` (single ``mentioned-documents`` part).
|
||||||
|
The history loader then renders the chips on reload without an extra
|
||||||
|
fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
title: str = Field(..., min_length=1, max_length=500)
|
||||||
|
document_type: str = Field(..., min_length=1, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
class NewChatRequest(BaseModel):
|
class NewChatRequest(BaseModel):
|
||||||
"""Request schema for the deep agent chat endpoint."""
|
"""Request schema for the deep agent chat endpoint."""
|
||||||
|
|
||||||
|
|
@ -213,6 +228,17 @@ class NewChatRequest(BaseModel):
|
||||||
mentioned_surfsense_doc_ids: list[int] | None = (
|
mentioned_surfsense_doc_ids: list[int] | None = (
|
||||||
None # Optional SurfSense documentation IDs mentioned with @ in the chat
|
None # Optional SurfSense documentation IDs mentioned with @ in the chat
|
||||||
)
|
)
|
||||||
|
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Display metadata (id, title, document_type) for every "
|
||||||
|
"@-mentioned document. Persisted as a ``mentioned-documents`` "
|
||||||
|
"ContentPart on the user message so reload renders chips "
|
||||||
|
"without an extra fetch. Optional and additive — when None "
|
||||||
|
"the user message is persisted without a mentioned-documents "
|
||||||
|
"part."
|
||||||
|
),
|
||||||
|
)
|
||||||
disabled_tools: list[str] | None = (
|
disabled_tools: list[str] | None = (
|
||||||
None # Optional list of tool names the user has disabled from the UI
|
None # Optional list of tool names the user has disabled from the UI
|
||||||
)
|
)
|
||||||
|
|
@ -264,6 +290,16 @@ class RegenerateRequest(BaseModel):
|
||||||
)
|
)
|
||||||
mentioned_document_ids: list[int] | None = None
|
mentioned_document_ids: list[int] | None = None
|
||||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||||
|
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Display metadata (id, title, document_type) for every "
|
||||||
|
"@-mentioned document on the edited user turn. Only used "
|
||||||
|
"when ``user_query`` is non-None (edit). Persisted as a "
|
||||||
|
"``mentioned-documents`` ContentPart on the new user "
|
||||||
|
"message. None means no chip metadata."
|
||||||
|
),
|
||||||
|
)
|
||||||
disabled_tools: list[str] | None = None
|
disabled_tools: list[str] | None = None
|
||||||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||||
client_platform: Literal["web", "desktop"] = "web"
|
client_platform: Literal["web", "desktop"] = "web"
|
||||||
|
|
@ -334,6 +370,16 @@ class ResumeRequest(BaseModel):
|
||||||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||||
client_platform: Literal["web", "desktop"] = "web"
|
client_platform: Literal["web", "desktop"] = "web"
|
||||||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||||
|
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Display metadata forwarded for symmetry with /new_chat and "
|
||||||
|
"/regenerate. Resume reuses the original interrupted user "
|
||||||
|
"turn so the server does not write a new user message. "
|
||||||
|
"Currently unused but accepted to keep request bodies "
|
||||||
|
"uniform across the three streaming entrypoints."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CancelActiveTurnResponse(BaseModel):
|
class CancelActiveTurnResponse(BaseModel):
|
||||||
|
|
|
||||||
515
surfsense_backend/app/tasks/chat/content_builder.py
Normal file
515
surfsense_backend/app/tasks/chat/content_builder.py
Normal file
|
|
@ -0,0 +1,515 @@
|
||||||
|
"""Server-side mirror of the frontend's assistant-ui ``ContentPart`` projection.
|
||||||
|
|
||||||
|
Background
|
||||||
|
----------
|
||||||
|
The streaming chat task in ``stream_new_chat`` / ``stream_resume_chat`` yields
|
||||||
|
SSE events that the frontend folds into a ``ContentPartsState`` (see
|
||||||
|
``surfsense_web/lib/chat/streaming-state.ts`` and the matching pipeline in
|
||||||
|
``stream-pipeline.ts``). When a turn ends, the frontend calls
|
||||||
|
``buildContentForPersistence(...)`` and round-trips that ``ContentPart[]``
|
||||||
|
JSONB to ``POST /threads/{id}/messages``, which is what was historically
|
||||||
|
written to ``new_chat_messages.content``.
|
||||||
|
|
||||||
|
After the ghost-thread fix moved persistence server-side, the assistant
|
||||||
|
row is written by ``finalize_assistant_turn`` in the streaming finally
|
||||||
|
block. The frontend's later ``appendMessage`` is now a no-op (recovers
|
||||||
|
via the ``(thread_id, turn_id, role)`` partial unique index added in
|
||||||
|
migration 141), which means the *server* is now responsible for
|
||||||
|
producing the rich ``ContentPart[]`` shape the FE expects on history
|
||||||
|
reload — text + reasoning + tool-call cards (with ``args``, ``argsText``,
|
||||||
|
``result``, ``langchainToolCallId``) + thinking-step buckets +
|
||||||
|
step-separators.
|
||||||
|
|
||||||
|
This module is the in-memory accumulator that mirrors the FE state for
|
||||||
|
exactly that purpose. The streaming code calls ``on_text_*`` / ``on_reasoning_*``
|
||||||
|
/ ``on_tool_*`` / ``on_thinking_step`` / ``on_step_separator`` /
|
||||||
|
``mark_interrupted`` at the same call sites it yields the matching
|
||||||
|
``streaming_service.format_*`` SSE event, so the in-memory ``parts`` list
|
||||||
|
stays in lockstep with what the FE's pipeline would have produced live.
|
||||||
|
``snapshot()`` is then taken once in the ``finally`` block and persisted
|
||||||
|
in a single UPDATE.
|
||||||
|
|
||||||
|
Pure synchronous state — no DB I/O, no async, no flush callbacks. The
|
||||||
|
streaming code is responsible for driving lifecycle methods; this class
|
||||||
|
is a thin projection helper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Mirrors the FE's filter in ``buildContentForPersistence`` / ``buildContentForUI``:
|
||||||
|
# only text/reasoning/tool-call parts count as "meaningful". data-thinking-steps
|
||||||
|
# and data-step-separator decorate the meaningful parts but never stand alone
|
||||||
|
# in a successful turn.
|
||||||
|
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantContentBuilder:
|
||||||
|
"""Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
|
||||||
|
|
||||||
|
Output shape (deep copy of ``self.parts`` via ``snapshot()``) strictly
|
||||||
|
matches the FE ``ContentPart`` union::
|
||||||
|
|
||||||
|
| { type: "text"; text: string }
|
||||||
|
| { type: "reasoning"; text: string }
|
||||||
|
| { type: "tool-call"; toolCallId: str; toolName: str;
|
||||||
|
args: dict; result?: any; argsText?: str; langchainToolCallId?: str;
|
||||||
|
state?: "aborted" }
|
||||||
|
| { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] } }
|
||||||
|
| { type: "data-step-separator"; data: { stepIndex: int } }
|
||||||
|
|
||||||
|
Order matches the wire order of the SSE events that drive the lifecycle
|
||||||
|
methods, with two FE-mirrored exceptions:
|
||||||
|
|
||||||
|
1. ``data-thinking-steps`` is a *singleton* and pinned at index 0 the
|
||||||
|
first time we see a ``data-thinking-step`` SSE event (the FE's
|
||||||
|
``updateThinkingSteps`` does ``unshift`` on first sight). Subsequent
|
||||||
|
thinking-step updates mutate that singleton in place.
|
||||||
|
2. ``data-step-separator`` is appended only when the message already has
|
||||||
|
meaningful content and the previous part isn't itself a separator
|
||||||
|
(so the FIRST step of a turn doesn't generate a leading divider).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.parts: list[dict[str, Any]] = []
|
||||||
|
# Index of the active text/reasoning part within ``parts`` while
|
||||||
|
# streaming is open; -1 means "no active part" and the next delta
|
||||||
|
# opens a fresh one. Mirrors ``ContentPartsState.currentTextPartIndex``.
|
||||||
|
self._current_text_idx: int = -1
|
||||||
|
self._current_reasoning_idx: int = -1
|
||||||
|
# ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
|
||||||
|
# synthetic ``call_<run_id>`` (legacy) or the LangChain
|
||||||
|
# ``tool_call.id`` (parity_v2) — same key the streaming layer
|
||||||
|
# threads through every ``tool-input-*`` / ``tool-output-*`` event.
|
||||||
|
self._tool_call_idx_by_ui_id: dict[str, int] = {}
|
||||||
|
# Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
|
||||||
|
# so we can reproduce the FE's ``appendToolInputDelta`` behaviour
|
||||||
|
# before ``tool-input-available`` overwrites it with the
|
||||||
|
# pretty-printed final JSON.
|
||||||
|
self._args_text_by_ui_id: dict[str, str] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Text
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def on_text_start(self, text_id: str) -> None:
|
||||||
|
"""Begin a fresh text block.
|
||||||
|
|
||||||
|
Symmetric to FE ``appendText``: opening text closes any active
|
||||||
|
reasoning so the renderer treats them as separate parts. The
|
||||||
|
actual text part isn't materialised here — it's lazily created
|
||||||
|
on the first ``on_text_delta`` so an empty start/end pair
|
||||||
|
leaves no trace. Matches the FE pipeline which has no explicit
|
||||||
|
``text-start`` handler at all.
|
||||||
|
"""
|
||||||
|
if self._current_reasoning_idx >= 0:
|
||||||
|
self._current_reasoning_idx = -1
|
||||||
|
|
||||||
|
def on_text_delta(self, text_id: str, delta: str) -> None:
|
||||||
|
if not delta:
|
||||||
|
return
|
||||||
|
if self._current_reasoning_idx >= 0:
|
||||||
|
# FE behaviour: a text delta after reasoning implicitly
|
||||||
|
# closes the reasoning block (see ``appendText`` lines
|
||||||
|
# 178-180).
|
||||||
|
self._current_reasoning_idx = -1
|
||||||
|
if (
|
||||||
|
self._current_text_idx >= 0
|
||||||
|
and 0 <= self._current_text_idx < len(self.parts)
|
||||||
|
and self.parts[self._current_text_idx].get("type") == "text"
|
||||||
|
):
|
||||||
|
self.parts[self._current_text_idx]["text"] += delta
|
||||||
|
return
|
||||||
|
self.parts.append({"type": "text", "text": delta})
|
||||||
|
self._current_text_idx = len(self.parts) - 1
|
||||||
|
|
||||||
|
def on_text_end(self, text_id: str) -> None:
|
||||||
|
"""Close the active text block.
|
||||||
|
|
||||||
|
Mirrors the wire-level ``text-end`` boundary the streaming layer
|
||||||
|
emits before tool calls / reasoning / step boundaries. The FE
|
||||||
|
pipeline implicitly closes via ``currentTextPartIndex = -1``
|
||||||
|
in ``addToolCall`` / ``appendReasoning`` / ``addStepSeparator``;
|
||||||
|
our helper does the same explicitly so callers don't have to
|
||||||
|
maintain that invariant per call site.
|
||||||
|
"""
|
||||||
|
self._current_text_idx = -1
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Reasoning
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def on_reasoning_start(self, reasoning_id: str) -> None:
|
||||||
|
if self._current_text_idx >= 0:
|
||||||
|
self._current_text_idx = -1
|
||||||
|
|
||||||
|
def on_reasoning_delta(self, reasoning_id: str, delta: str) -> None:
|
||||||
|
if not delta:
|
||||||
|
return
|
||||||
|
if self._current_text_idx >= 0:
|
||||||
|
self._current_text_idx = -1
|
||||||
|
if (
|
||||||
|
self._current_reasoning_idx >= 0
|
||||||
|
and 0 <= self._current_reasoning_idx < len(self.parts)
|
||||||
|
and self.parts[self._current_reasoning_idx].get("type") == "reasoning"
|
||||||
|
):
|
||||||
|
self.parts[self._current_reasoning_idx]["text"] += delta
|
||||||
|
return
|
||||||
|
self.parts.append({"type": "reasoning", "text": delta})
|
||||||
|
self._current_reasoning_idx = len(self.parts) - 1
|
||||||
|
|
||||||
|
def on_reasoning_end(self, reasoning_id: str) -> None:
|
||||||
|
self._current_reasoning_idx = -1
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Tool calls
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def on_tool_input_start(
|
||||||
|
self,
|
||||||
|
ui_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
langchain_tool_call_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Register a tool-call card. Args are filled in by later events."""
|
||||||
|
if not ui_id:
|
||||||
|
return
|
||||||
|
# Skip duplicate registration: parity_v2 may emit
|
||||||
|
# ``tool-input-start`` from both ``on_chat_model_stream``
|
||||||
|
# (when tool_call_chunks register a name) and ``on_tool_start``
|
||||||
|
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
|
||||||
|
# we mirror that here.
|
||||||
|
if ui_id in self._tool_call_idx_by_ui_id:
|
||||||
|
if langchain_tool_call_id:
|
||||||
|
idx = self._tool_call_idx_by_ui_id[ui_id]
|
||||||
|
part = self.parts[idx]
|
||||||
|
if not part.get("langchainToolCallId"):
|
||||||
|
part["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
return
|
||||||
|
|
||||||
|
part: dict[str, Any] = {
|
||||||
|
"type": "tool-call",
|
||||||
|
"toolCallId": ui_id,
|
||||||
|
"toolName": tool_name,
|
||||||
|
"args": {},
|
||||||
|
}
|
||||||
|
if langchain_tool_call_id:
|
||||||
|
part["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
self.parts.append(part)
|
||||||
|
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
|
||||||
|
|
||||||
|
self._current_text_idx = -1
|
||||||
|
self._current_reasoning_idx = -1
|
||||||
|
|
||||||
|
def on_tool_input_delta(self, ui_id: str, args_chunk: str) -> None:
|
||||||
|
"""Append a streamed args-delta chunk to the matching card's argsText.
|
||||||
|
|
||||||
|
Mirrors FE ``appendToolInputDelta``: no-ops when no card has been
|
||||||
|
registered yet for the given ``ui_id`` — the deltas have nowhere
|
||||||
|
safe to land.
|
||||||
|
"""
|
||||||
|
if not ui_id or not args_chunk:
|
||||||
|
return
|
||||||
|
idx = self._tool_call_idx_by_ui_id.get(ui_id)
|
||||||
|
if idx is None:
|
||||||
|
return
|
||||||
|
if not (0 <= idx < len(self.parts)):
|
||||||
|
return
|
||||||
|
part = self.parts[idx]
|
||||||
|
if part.get("type") != "tool-call":
|
||||||
|
return
|
||||||
|
new_text = (part.get("argsText") or "") + args_chunk
|
||||||
|
part["argsText"] = new_text
|
||||||
|
self._args_text_by_ui_id[ui_id] = new_text
|
||||||
|
|
||||||
|
def on_tool_input_available(
|
||||||
|
self,
|
||||||
|
ui_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
args: dict[str, Any],
|
||||||
|
langchain_tool_call_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Finalize the tool-call card's input.
|
||||||
|
|
||||||
|
Mirrors FE ``stream-pipeline.ts`` lines 127-153: replaces ``argsText``
|
||||||
|
with ``json.dumps(input, indent=2)`` so the post-stream card renders
|
||||||
|
pretty-printed JSON, sets the full ``args`` dict, and backfills
|
||||||
|
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
|
||||||
|
Also creates the card if no prior ``tool-input-start`` registered it
|
||||||
|
(legacy parity_v2-OFF / late-registration paths).
|
||||||
|
"""
|
||||||
|
if not ui_id:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
final_args_text = json.dumps(args or {}, indent=2, ensure_ascii=False)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
# Defensive: ``args`` should already be JSON-safe (the
|
||||||
|
# streaming layer sanitizes it before emitting), but if a
|
||||||
|
# caller hands us a non-serializable value we still want
|
||||||
|
# to record the call without breaking the snapshot.
|
||||||
|
final_args_text = str(args)
|
||||||
|
|
||||||
|
idx = self._tool_call_idx_by_ui_id.get(ui_id)
|
||||||
|
if idx is not None and 0 <= idx < len(self.parts):
|
||||||
|
part = self.parts[idx]
|
||||||
|
if part.get("type") == "tool-call":
|
||||||
|
part["args"] = args or {}
|
||||||
|
part["argsText"] = final_args_text
|
||||||
|
if langchain_tool_call_id and not part.get("langchainToolCallId"):
|
||||||
|
part["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
return
|
||||||
|
|
||||||
|
# No prior tool-input-start: register the card now.
|
||||||
|
new_part: dict[str, Any] = {
|
||||||
|
"type": "tool-call",
|
||||||
|
"toolCallId": ui_id,
|
||||||
|
"toolName": tool_name,
|
||||||
|
"args": args or {},
|
||||||
|
"argsText": final_args_text,
|
||||||
|
}
|
||||||
|
if langchain_tool_call_id:
|
||||||
|
new_part["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
self.parts.append(new_part)
|
||||||
|
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
|
||||||
|
|
||||||
|
self._current_text_idx = -1
|
||||||
|
self._current_reasoning_idx = -1
|
||||||
|
|
||||||
|
def on_tool_output_available(
|
||||||
|
self,
|
||||||
|
ui_id: str,
|
||||||
|
output: Any,
|
||||||
|
langchain_tool_call_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Attach the tool's output (``result``) to the matching card.
|
||||||
|
|
||||||
|
Mirrors FE ``updateToolCall``: backfill ``langchainToolCallId``
|
||||||
|
only if not already set (a NULL late-arriving value never blows
|
||||||
|
away an earlier known good one).
|
||||||
|
"""
|
||||||
|
if not ui_id:
|
||||||
|
return
|
||||||
|
idx = self._tool_call_idx_by_ui_id.get(ui_id)
|
||||||
|
if idx is None or not (0 <= idx < len(self.parts)):
|
||||||
|
return
|
||||||
|
part = self.parts[idx]
|
||||||
|
if part.get("type") != "tool-call":
|
||||||
|
return
|
||||||
|
part["result"] = output
|
||||||
|
if langchain_tool_call_id and not part.get("langchainToolCallId"):
|
||||||
|
part["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Thinking steps & step separators
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def on_thinking_step(
|
||||||
|
self,
|
||||||
|
step_id: str,
|
||||||
|
title: str,
|
||||||
|
status: str,
|
||||||
|
items: list[str] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Update / insert the singleton ``data-thinking-steps`` part.
|
||||||
|
|
||||||
|
Mirrors FE ``updateThinkingSteps``: maintain a single
|
||||||
|
``data-thinking-steps`` part anchored at index 0, replacing or
|
||||||
|
unshifting on first sight. Each ``on_thinking_step`` call
|
||||||
|
replaces the entry in the steps list keyed by ``step_id`` (or
|
||||||
|
appends if new).
|
||||||
|
"""
|
||||||
|
if not step_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_step = {
|
||||||
|
"id": step_id,
|
||||||
|
"title": title or "",
|
||||||
|
"status": status or "in_progress",
|
||||||
|
"items": list(items) if items else [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Find existing data-thinking-steps part.
|
||||||
|
existing_idx = -1
|
||||||
|
for i, p in enumerate(self.parts):
|
||||||
|
if p.get("type") == "data-thinking-steps":
|
||||||
|
existing_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if existing_idx >= 0:
|
||||||
|
current_steps = self.parts[existing_idx].get("data", {}).get("steps") or []
|
||||||
|
replaced = False
|
||||||
|
for i, step in enumerate(current_steps):
|
||||||
|
if step.get("id") == step_id:
|
||||||
|
current_steps[i] = new_step
|
||||||
|
replaced = True
|
||||||
|
break
|
||||||
|
if not replaced:
|
||||||
|
current_steps.append(new_step)
|
||||||
|
self.parts[existing_idx] = {
|
||||||
|
"type": "data-thinking-steps",
|
||||||
|
"data": {"steps": current_steps},
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
# First sight: unshift to position 0 (FE parity).
|
||||||
|
self.parts.insert(
|
||||||
|
0,
|
||||||
|
{
|
||||||
|
"type": "data-thinking-steps",
|
||||||
|
"data": {"steps": [new_step]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Bump tracked indices since we inserted at the head.
|
||||||
|
if self._current_text_idx >= 0:
|
||||||
|
self._current_text_idx += 1
|
||||||
|
if self._current_reasoning_idx >= 0:
|
||||||
|
self._current_reasoning_idx += 1
|
||||||
|
for ui_id, idx in list(self._tool_call_idx_by_ui_id.items()):
|
||||||
|
self._tool_call_idx_by_ui_id[ui_id] = idx + 1
|
||||||
|
|
||||||
|
def on_step_separator(self) -> None:
|
||||||
|
"""Append a ``data-step-separator`` between consecutive model steps.
|
||||||
|
|
||||||
|
Mirrors FE ``addStepSeparator``: only emit when the message
|
||||||
|
already has meaningful content AND the previous part isn't
|
||||||
|
itself a separator. ``stepIndex`` is the running count of
|
||||||
|
separators already in ``parts``.
|
||||||
|
"""
|
||||||
|
has_content = any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
|
||||||
|
if not has_content:
|
||||||
|
return
|
||||||
|
if self.parts and self.parts[-1].get("type") == "data-step-separator":
|
||||||
|
return
|
||||||
|
step_index = sum(
|
||||||
|
1 for p in self.parts if p.get("type") == "data-step-separator"
|
||||||
|
)
|
||||||
|
self.parts.append(
|
||||||
|
{
|
||||||
|
"type": "data-step-separator",
|
||||||
|
"data": {"stepIndex": step_index},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._current_text_idx = -1
|
||||||
|
self._current_reasoning_idx = -1
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Interruption handling
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def mark_interrupted(self) -> None:
|
||||||
|
"""Close any open text/reasoning and flip running tools to aborted.
|
||||||
|
|
||||||
|
Called from the streaming ``finally`` block before ``snapshot()`` so
|
||||||
|
the persisted JSONB reflects a coherent end-state even when the
|
||||||
|
client disconnected mid-turn or the agent hit a fatal error.
|
||||||
|
|
||||||
|
- Active text/reasoning blocks: simply lose their "active"
|
||||||
|
marker (no synthetic content appended). Whatever was streamed
|
||||||
|
stays as-is.
|
||||||
|
- Tool-call parts that never received a ``result`` get
|
||||||
|
``state="aborted"`` so the FE history loader can render them
|
||||||
|
as "interrupted" rather than "still running".
|
||||||
|
"""
|
||||||
|
self._current_text_idx = -1
|
||||||
|
self._current_reasoning_idx = -1
|
||||||
|
for part in self.parts:
|
||||||
|
if part.get("type") != "tool-call":
|
||||||
|
continue
|
||||||
|
if "result" in part:
|
||||||
|
continue
|
||||||
|
part["state"] = "aborted"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Snapshot & introspection
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def snapshot(self) -> list[dict[str, Any]]:
|
||||||
|
"""Return a deep copy of ``parts`` ready for SQL UPDATE / json.dumps.
|
||||||
|
|
||||||
|
Deep-copied so callers that finalize from the shielded ``finally``
|
||||||
|
block can't accidentally mutate the persisted payload while the
|
||||||
|
SQL UPDATE is in flight (the streaming layer doesn't touch the
|
||||||
|
builder after this call, but defensive copies are cheap and cheap
|
||||||
|
is what we want in a finally block).
|
||||||
|
"""
|
||||||
|
return copy.deepcopy(self.parts)
|
||||||
|
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
"""True if no meaningful content was captured.
|
||||||
|
|
||||||
|
``data-thinking-steps`` and ``data-step-separator`` decorate
|
||||||
|
meaningful content but don't count on their own — a turn that
|
||||||
|
only emitted a thinking step before being interrupted should
|
||||||
|
still be treated as empty for the status-marker fallback.
|
||||||
|
"""
|
||||||
|
return not any(p.get("type") in _MEANINGFUL_PART_TYPES for p in self.parts)
|
||||||
|
|
||||||
|
def stats(self) -> dict[str, int]:
|
||||||
|
"""Return counts of each part-type plus rough byte size.
|
||||||
|
|
||||||
|
Used by the streaming layer's perf logger so an ops dashboard
|
||||||
|
can correlate finalize latency with payload size, and so a
|
||||||
|
regression that quietly stops emitting tool-call parts (or
|
||||||
|
starts emitting hundreds) shows up in [PERF] grep rather than
|
||||||
|
only as a "history reload looks weird" bug report.
|
||||||
|
|
||||||
|
``bytes`` is the JSON-serialised payload length — what actually
|
||||||
|
crosses the wire to PostgreSQL's JSONB column. We compute it
|
||||||
|
with ``ensure_ascii=False`` to match the JSONB encoder's UTF-8
|
||||||
|
on-disk layout closely enough for back-of-the-envelope sizing.
|
||||||
|
Reasoning/text/tool-call/thinking-step/step-separator counts are
|
||||||
|
independent so any one can spike without the others.
|
||||||
|
|
||||||
|
Defensive: ``json.dumps`` failure (a non-serializable value
|
||||||
|
slipped past the streaming layer's sanitization) is reported as
|
||||||
|
``bytes=-1`` rather than raised — perf logging must not be the
|
||||||
|
thing that breaks the streaming finally block.
|
||||||
|
"""
|
||||||
|
text_blocks = 0
|
||||||
|
reasoning_blocks = 0
|
||||||
|
tool_calls = 0
|
||||||
|
tool_calls_completed = 0
|
||||||
|
tool_calls_aborted = 0
|
||||||
|
thinking_step_parts = 0
|
||||||
|
step_separators = 0
|
||||||
|
|
||||||
|
for part in self.parts:
|
||||||
|
kind = part.get("type")
|
||||||
|
if kind == "text":
|
||||||
|
text_blocks += 1
|
||||||
|
elif kind == "reasoning":
|
||||||
|
reasoning_blocks += 1
|
||||||
|
elif kind == "tool-call":
|
||||||
|
tool_calls += 1
|
||||||
|
if part.get("state") == "aborted":
|
||||||
|
tool_calls_aborted += 1
|
||||||
|
elif "result" in part:
|
||||||
|
tool_calls_completed += 1
|
||||||
|
elif kind == "data-thinking-steps":
|
||||||
|
thinking_step_parts += 1
|
||||||
|
elif kind == "data-step-separator":
|
||||||
|
step_separators += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
byte_size = len(json.dumps(self.parts, ensure_ascii=False, default=str))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
byte_size = -1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"parts": len(self.parts),
|
||||||
|
"bytes": byte_size,
|
||||||
|
"text": text_blocks,
|
||||||
|
"reasoning": reasoning_blocks,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"tool_calls_completed": tool_calls_completed,
|
||||||
|
"tool_calls_aborted": tool_calls_aborted,
|
||||||
|
"thinking_step_parts": thinking_step_parts,
|
||||||
|
"step_separators": step_separators,
|
||||||
|
}
|
||||||
534
surfsense_backend/app/tasks/chat/persistence.py
Normal file
534
surfsense_backend/app/tasks/chat/persistence.py
Normal file
|
|
@ -0,0 +1,534 @@
|
||||||
|
"""Server-side message persistence helpers for the streaming chat agent.
|
||||||
|
|
||||||
|
Historically the streaming task (``stream_new_chat``/``stream_resume_chat``)
|
||||||
|
left ``new_chat_messages`` empty and relied on the frontend to round-trip
|
||||||
|
``POST /threads/{id}/messages`` afterwards. That gave authenticated clients
|
||||||
|
a "ghost-thread" abuse vector: skip the round-trip and burn LLM tokens
|
||||||
|
without leaving an audit trail. These helpers move both writes (the user
|
||||||
|
turn that triggered the stream and the assistant turn the stream produced)
|
||||||
|
into the server itself, idempotent against the partial unique index
|
||||||
|
``uq_new_chat_messages_thread_turn_role`` so legacy frontends that *do*
|
||||||
|
keep posting via ``appendMessage`` simply hit the unique-index recovery
|
||||||
|
path on the second writer instead of creating duplicates.
|
||||||
|
|
||||||
|
Assistant turn lifecycle
|
||||||
|
------------------------
|
||||||
|
The assistant side is split into two helpers so we can capture the row id
|
||||||
|
*before* the stream produces any output:
|
||||||
|
|
||||||
|
* ``persist_assistant_shell`` runs immediately after ``persist_user_turn``
|
||||||
|
and INSERTs an empty assistant row anchored to ``(thread_id, turn_id,
|
||||||
|
ASSISTANT)``. Returns the row id so the streaming layer can correlate
|
||||||
|
later writes (token_usage, AgentActionLog future-correlation) against
|
||||||
|
a stable PK from the start of the turn.
|
||||||
|
* ``finalize_assistant_turn`` runs from the streaming ``finally`` block.
|
||||||
|
It UPDATEs the row's ``content`` to the rich ``ContentPart[]`` snapshot
|
||||||
|
produced server-side by ``AssistantContentBuilder`` and writes the
|
||||||
|
``token_usage`` row using ``INSERT ... ON CONFLICT DO NOTHING`` against
|
||||||
|
the ``uq_token_usage_message_id`` partial unique index from migration
|
||||||
|
142, hard-eliminating any race against ``append_message``'s recovery
|
||||||
|
branch.
|
||||||
|
|
||||||
|
Defensive contract
|
||||||
|
------------------
|
||||||
|
|
||||||
|
* Every helper runs inside ``shielded_async_session()`` so ``session.close()``
|
||||||
|
survives starlette's mid-stream cancel scope on client disconnect.
|
||||||
|
* ``persist_user_turn`` and ``persist_assistant_shell`` use ``INSERT ... ON
|
||||||
|
CONFLICT DO NOTHING ... RETURNING id`` keyed on the ``(thread_id, turn_id,
|
||||||
|
role)`` partial unique index. On conflict the insert silently no-ops at
|
||||||
|
the DB level — no Python ``IntegrityError`` is constructed, which
|
||||||
|
eliminates spurious debugger pauses and keeps logs clean. On conflict a
|
||||||
|
follow-up ``SELECT`` resolves the existing row id so the streaming layer
|
||||||
|
can correlate writes against a stable PK.
|
||||||
|
* ``finalize_assistant_turn`` is best-effort: it never raises. The
|
||||||
|
streaming ``finally`` block calls it from within
|
||||||
|
``anyio.CancelScope(shield=True)`` and any raised exception there
|
||||||
|
would mask the real error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import text as sa_text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.db import (
|
||||||
|
NewChatMessage,
|
||||||
|
NewChatMessageRole,
|
||||||
|
NewChatThread,
|
||||||
|
TokenUsage,
|
||||||
|
shielded_async_session,
|
||||||
|
)
|
||||||
|
from app.services.token_tracking_service import (
|
||||||
|
TurnTokenAccumulator,
|
||||||
|
)
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# Empty initial assistant content. ``finalize_assistant_turn`` overwrites
# this in a single UPDATE at end-of-stream with the full ``ContentPart[]``
# snapshot produced by ``AssistantContentBuilder``. We persist a one-element
# list with an empty text part so a crash between shell-INSERT and finalize
# leaves the row in a FE-renderable shape (blank bubble) instead of
# blowing up the history loader.
# NOTE: module-level mutable list shared by every ``persist_assistant_shell``
# insert — treat as read-only; never mutate in place.
_EMPTY_SHELL_CONTENT: list[dict[str, Any]] = [{"type": "text", "text": ""}]

# Substituted content for genuinely empty turns (no text, no reasoning,
# no tool calls). The streaming layer flips to this when
# ``AssistantContentBuilder.is_empty()`` returns True so the persisted
# row is at least somewhat self-describing instead of an empty text
# bubble. The FE's ``ContentPart`` union doesn't include ``status``
# yet, so the history loader will silently drop this part and render
# a blank bubble (matches today's behaviour for empty turns); a follow-up
# FE PR adds the explicit "no response" rendering.
# NOTE: also a shared module-level list (``finalize_assistant_turn``
# assigns it by reference) — treat as read-only.
_STATUS_NO_RESPONSE: list[dict[str, Any]] = [
    {"type": "status", "text": "(no text response)"}
]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_user_content(
|
||||||
|
user_query: str,
|
||||||
|
user_image_data_urls: list[str] | None,
|
||||||
|
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Build the persisted user-message ``content`` (assistant-ui v2 parts).
|
||||||
|
|
||||||
|
Mirrors the shape the existing frontend posts via
|
||||||
|
``appendMessage`` (see ``surfsense_web/.../new-chat/[[...chat_id]]/page.tsx``):
|
||||||
|
|
||||||
|
[{"type": "text", "text": "..."},
|
||||||
|
{"type": "image", "image": "data:..."},
|
||||||
|
{"type": "mentioned-documents", "documents": [{"id": int,
|
||||||
|
"title": str, "document_type": str}, ...]}]
|
||||||
|
|
||||||
|
The companion reader is
|
||||||
|
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
|
||||||
|
which expects exactly this shape — keep them in sync.
|
||||||
|
|
||||||
|
``mentioned_documents``: optional list of ``{id, title, document_type}``
|
||||||
|
dicts. When non-empty (and a ``mentioned-documents`` part is not already
|
||||||
|
in some other input shape), a single ``{"type": "mentioned-documents",
|
||||||
|
"documents": [...]}`` part is appended. Mirrors the FE injection at
|
||||||
|
``page.tsx:281-286`` (``persistUserTurn``).
|
||||||
|
"""
|
||||||
|
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
|
||||||
|
for url in user_image_data_urls or ():
|
||||||
|
if isinstance(url, str) and url:
|
||||||
|
parts.append({"type": "image", "image": url})
|
||||||
|
if mentioned_documents:
|
||||||
|
normalized: list[dict[str, Any]] = []
|
||||||
|
for doc in mentioned_documents:
|
||||||
|
if not isinstance(doc, dict):
|
||||||
|
continue
|
||||||
|
doc_id = doc.get("id")
|
||||||
|
title = doc.get("title")
|
||||||
|
document_type = doc.get("document_type")
|
||||||
|
if doc_id is None or title is None or document_type is None:
|
||||||
|
continue
|
||||||
|
normalized.append(
|
||||||
|
{
|
||||||
|
"id": doc_id,
|
||||||
|
"title": str(title),
|
||||||
|
"document_type": str(document_type),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if normalized:
|
||||||
|
parts.append({"type": "mentioned-documents", "documents": normalized})
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
async def persist_user_turn(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
    user_query: str,
    user_image_data_urls: list[str] | None = None,
    mentioned_documents: list[dict[str, Any]] | None = None,
) -> int | None:
    """Persist the user-side row for a chat turn and return its ``id``.

    Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING id`` keyed on the
    ``(thread_id, turn_id, role)`` partial unique index from migration 141
    (``WHERE turn_id IS NOT NULL``). On conflict the insert silently no-ops
    at the DB level — no Python ``IntegrityError`` is constructed, which
    eliminates the debugger pause that ``justMyCode=false`` + async greenlet
    interactions used to produce, and keeps production logs clean.

    Returns the ``id`` of the row that exists for this turn after the call:
    the freshly inserted ``id`` on the happy path, or the existing ``id``
    when a previous writer (legacy FE ``appendMessage`` racing the SSE
    stream, redelivered request, etc.) already wrote it. Returns ``None``
    only on genuine DB failure; the caller should yield a streaming error
    and abort the turn so we never produce a title/assistant row that
    isn't anchored to a persisted user message.

    Other constraint violations (FK, NOT NULL, etc.) still raise
    ``IntegrityError`` — only the ``(thread_id, turn_id, role)`` collision
    is silenced.

    ``user_id`` is expected to be a UUID string; anything unparseable is
    downgraded to an anonymous write (``author_id=None``) with a warning.
    """
    if not turn_id:
        # Defensive: turn_id is always populated by the streaming path
        # before this helper is called. If it isn't, we cannot be
        # idempotent against the unique index — refuse to write rather
        # than create a row the unique index can't dedupe.
        logger.error(
            "persist_user_turn called without a turn_id (chat_id=%s); skipping",
            chat_id,
        )
        return None

    t0 = time.perf_counter()
    outcome = "failed"  # overwritten on every successful path; read by the perf log
    resolved_id: int | None = None
    try:
        async with shielded_async_session() as ws:
            # Re-attach the thread row so we can also bump updated_at
            # in the same write — keeps the sidebar ordering accurate
            # when a user fires off a turn but never reaches the
            # legacy appendMessage.
            thread = await ws.get(NewChatThread, chat_id)
            author_uuid: UUID | None = None
            if user_id:
                try:
                    author_uuid = UUID(user_id)
                except (TypeError, ValueError):
                    logger.warning(
                        "persist_user_turn: invalid user_id=%r, persisting as anonymous",
                        user_id,
                    )

            content_payload = _build_user_content(
                user_query, user_image_data_urls, mentioned_documents
            )
            insert_stmt = (
                pg_insert(NewChatMessage)
                .values(
                    thread_id=chat_id,
                    role=NewChatMessageRole.USER,
                    content=content_payload,
                    author_id=author_uuid,
                    turn_id=turn_id,
                )
                .on_conflict_do_nothing(
                    index_elements=["thread_id", "turn_id", "role"],
                    index_where=sa_text("turn_id IS NOT NULL"),
                )
                .returning(NewChatMessage.id)
            )
            # RETURNING yields no row when ON CONFLICT suppressed the
            # insert, so ``scalar()`` is None exactly on conflict.
            inserted_id = (await ws.execute(insert_stmt)).scalar()

            if inserted_id is None:
                # Conflict on partial unique index — another writer
                # (legacy FE appendMessage, redelivered request, etc.)
                # already persisted this row. Look it up and reuse.
                lookup = await ws.execute(
                    select(NewChatMessage.id).where(
                        NewChatMessage.thread_id == chat_id,
                        NewChatMessage.turn_id == turn_id,
                        NewChatMessage.role == NewChatMessageRole.USER,
                    )
                )
                existing_id = lookup.scalars().first()
                if existing_id is None:
                    # Conflict reported but no row found — extremely
                    # unlikely (concurrent DELETE). Surface as failure.
                    logger.warning(
                        "persist_user_turn: conflict but no matching row "
                        "(chat_id=%s, turn_id=%s)",
                        chat_id,
                        turn_id,
                    )
                    outcome = "integrity_no_match"
                    return None
                resolved_id = int(existing_id)
                outcome = "race_recovered"
            else:
                resolved_id = int(inserted_id)
                outcome = "inserted"
                # Bump thread.updated_at only on a real insert — when
                # we recovered an existing row the prior writer
                # already touched the thread.
                if thread is not None:
                    thread.updated_at = datetime.now(UTC)

            await ws.commit()
            return resolved_id
    except Exception:
        # Catch-all boundary: log with full traceback and signal failure
        # via None so the streaming caller can abort the turn cleanly.
        logger.exception(
            "persist_user_turn failed (chat_id=%s, turn_id=%s)",
            chat_id,
            turn_id,
        )
        return None
    finally:
        # Always emit one perf line per call, success or failure.
        _perf_log.info(
            "[persist_user_turn] outcome=%s chat_id=%s turn_id=%s "
            "message_id=%s query_len=%d images=%d mentioned_docs=%d "
            "in %.3fs",
            outcome,
            chat_id,
            turn_id,
            resolved_id,
            len(user_query or ""),
            len(user_image_data_urls or ()),
            len(mentioned_documents or ()),
            time.perf_counter() - t0,
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def persist_assistant_shell(
    *,
    chat_id: int,
    user_id: str | None,
    turn_id: str,
) -> int | None:
    """Pre-write an empty assistant row for the turn and return its id.

    Inserts a placeholder ``new_chat_messages`` row (empty text content) so
    the streaming layer has a stable ``message_id`` to correlate against
    for the rest of the turn. ``finalize_assistant_turn`` overwrites the
    ``content`` field at end-of-stream with the rich ``ContentPart[]``
    snapshot produced by ``AssistantContentBuilder``.

    Returns the row id on success, ``None`` on a genuine DB failure (caller
    should abort the turn rather than stream into a void).

    Idempotent against the ``(thread_id, turn_id, ASSISTANT)`` partial unique
    index from migration 141: if a row already exists (resume retry, racing
    legacy frontend, redelivered request, etc.) we look it up by
    ``(thread_id, turn_id, role)`` and return its existing id. The streaming
    layer is then free to UPDATE that row at finalize time.

    NOTE: ``user_id`` is accepted for signature symmetry with
    ``persist_user_turn`` but is currently unused — the assistant row is
    always written with ``author_id=None``.
    """
    if not turn_id:
        # Without a turn_id the unique index cannot dedupe this row —
        # refuse to write rather than create an un-correlatable shell.
        logger.error(
            "persist_assistant_shell called without a turn_id (chat_id=%s); skipping",
            chat_id,
        )
        return None

    t0 = time.perf_counter()
    outcome = "failed"  # overwritten on every successful path; read by the perf log
    resolved_id: int | None = None
    try:
        async with shielded_async_session() as ws:
            insert_stmt = (
                pg_insert(NewChatMessage)
                .values(
                    thread_id=chat_id,
                    role=NewChatMessageRole.ASSISTANT,
                    content=_EMPTY_SHELL_CONTENT,
                    author_id=None,
                    turn_id=turn_id,
                )
                .on_conflict_do_nothing(
                    index_elements=["thread_id", "turn_id", "role"],
                    index_where=sa_text("turn_id IS NOT NULL"),
                )
                .returning(NewChatMessage.id)
            )
            # RETURNING yields no row when ON CONFLICT suppressed the
            # insert, so ``scalar()`` is None exactly on conflict.
            inserted_id = (await ws.execute(insert_stmt)).scalar()

            if inserted_id is None:
                # Conflict — another writer (legacy FE appendMessage,
                # resume retry, redelivered request) wrote the
                # (thread_id, turn_id, ASSISTANT) row first. Look it up
                # so the streaming layer can UPDATE the same row at
                # finalize time.
                lookup = await ws.execute(
                    select(NewChatMessage.id).where(
                        NewChatMessage.thread_id == chat_id,
                        NewChatMessage.turn_id == turn_id,
                        NewChatMessage.role == NewChatMessageRole.ASSISTANT,
                    )
                )
                existing_id = lookup.scalars().first()
                if existing_id is None:
                    # Conflict reported but no matching row — extremely
                    # unlikely (concurrent DELETE). Surface as failure.
                    logger.warning(
                        "persist_assistant_shell: conflict but no matching "
                        "(thread_id, turn_id, role) row found "
                        "(chat_id=%s, turn_id=%s)",
                        chat_id,
                        turn_id,
                    )
                    outcome = "integrity_no_match"
                    return None
                resolved_id = int(existing_id)
                outcome = "race_recovered"
            else:
                resolved_id = int(inserted_id)
                outcome = "inserted"

            await ws.commit()
            return resolved_id
    except Exception:
        # Catch-all boundary: log with traceback, signal failure via None.
        logger.exception(
            "persist_assistant_shell failed (chat_id=%s, turn_id=%s)",
            chat_id,
            turn_id,
        )
        return None
    finally:
        # Always emit one perf line per call, success or failure.
        _perf_log.info(
            "[persist_assistant_shell] outcome=%s chat_id=%s turn_id=%s "
            "message_id=%s in %.3fs",
            outcome,
            chat_id,
            turn_id,
            resolved_id,
            time.perf_counter() - t0,
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def finalize_assistant_turn(
    *,
    message_id: int,
    chat_id: int,
    search_space_id: int,
    user_id: str | None,
    turn_id: str,
    content: list[dict[str, Any]],
    accumulator: TurnTokenAccumulator | None,
) -> None:
    """Finalize the assistant row and write its token_usage.

    Two writes in a single shielded session:

    1. ``UPDATE new_chat_messages SET content = :c, updated_at = now()
       WHERE id = :id`` — overwrites the placeholder ``persist_assistant_shell``
       wrote with the full ``ContentPart[]`` snapshot produced server-side.
    2. ``INSERT INTO token_usage (...) VALUES (...) ON CONFLICT (message_id)
       WHERE message_id IS NOT NULL DO NOTHING`` — uses the partial unique
       index ``uq_token_usage_message_id`` from migration 142 to make the
       insert idempotent against ``append_message``'s recovery branch
       (which uses the same ON CONFLICT clause).

    Substitutes the status-marker payload when ``content`` is empty
    (pure tool-call turn that aborted before any output, or interrupt
    before any event arrived). The status marker is preferable to a
    blank text bubble because token accounting still runs and an ops
    dashboard can flag the row.

    Best-effort — never raises. The streaming ``finally`` calls this
    from within ``anyio.CancelScope(shield=True)``; any raised exception
    here would mask the real error that triggered the cleanup.
    """
    if not turn_id:
        logger.error(
            "finalize_assistant_turn called without turn_id "
            "(chat_id=%s, message_id=%s); skipping",
            chat_id,
            message_id,
        )
        return
    if not message_id:
        logger.error(
            "finalize_assistant_turn called without message_id "
            "(chat_id=%s, turn_id=%s); skipping",
            chat_id,
            turn_id,
        )
        return

    # Choose the persisted payload. NOTE: on the empty-content branch
    # ``payload`` aliases the shared module-level ``_STATUS_NO_RESPONSE``
    # list — it is treated as read-only everywhere downstream.
    payload: list[dict[str, Any]]
    is_status_marker = False
    if content:
        payload = content
    else:
        payload = _STATUS_NO_RESPONSE
        is_status_marker = True

    t0 = time.perf_counter()
    outcome = "failed"  # overwritten on success paths; read by the perf log
    # Logged flag only — computed up front, so it stays True even if the
    # later user_id parse fails and the insert is skipped with a warning.
    token_usage_attempted = bool(
        accumulator is not None and accumulator.calls and user_id
    )
    try:
        async with shielded_async_session() as ws:
            assistant_row = await ws.get(NewChatMessage, message_id)
            if assistant_row is None:
                # Shell row vanished (e.g. concurrent thread delete).
                # Best-effort contract: log and bail, never raise.
                logger.warning(
                    "finalize_assistant_turn: row not found "
                    "(chat_id=%s, message_id=%s, turn_id=%s); skipping",
                    chat_id,
                    message_id,
                    turn_id,
                )
                outcome = "row_missing"
                return

            assistant_row.content = payload
            assistant_row.updated_at = datetime.now(UTC)

            # Token usage. ``record_token_usage`` (used elsewhere) does
            # SELECT-then-INSERT in two statements which races with
            # ``append_message``. Switch to a single INSERT ... ON
            # CONFLICT DO NOTHING keyed on the migration-142 partial
            # unique index so the loser silently drops its write at
            # the DB level — exactly one row per ``message_id``,
            # regardless of which session committed first.
            if accumulator is not None and accumulator.calls and user_id:
                try:
                    user_uuid = UUID(user_id)
                except (TypeError, ValueError):
                    logger.warning(
                        "finalize_assistant_turn: invalid user_id=%r, "
                        "skipping token_usage row",
                        user_id,
                    )
                else:
                    insert_stmt = (
                        pg_insert(TokenUsage)
                        .values(
                            usage_type="chat",
                            prompt_tokens=accumulator.total_prompt_tokens,
                            completion_tokens=accumulator.total_completion_tokens,
                            total_tokens=accumulator.grand_total,
                            cost_micros=accumulator.total_cost_micros,
                            model_breakdown=accumulator.per_message_summary(),
                            call_details={"calls": accumulator.serialized_calls()},
                            thread_id=chat_id,
                            message_id=message_id,
                            search_space_id=search_space_id,
                            user_id=user_uuid,
                        )
                        .on_conflict_do_nothing(
                            index_elements=["message_id"],
                            index_where=sa_text("message_id IS NOT NULL"),
                        )
                    )
                    await ws.execute(insert_stmt)

            # One commit covers both the content UPDATE and the
            # token_usage INSERT — they land (or fail) together.
            await ws.commit()
            outcome = "ok"
    except Exception:
        # Best-effort: swallow after logging — raising here would mask
        # the real error inside the caller's shielded ``finally``.
        logger.exception(
            "finalize_assistant_turn failed (chat_id=%s, message_id=%s, turn_id=%s)",
            chat_id,
            message_id,
            turn_id,
        )
    finally:
        # Always emit one perf line per call, success or failure.
        _perf_log.info(
            "[finalize_assistant_turn] outcome=%s chat_id=%s message_id=%s "
            "turn_id=%s parts=%d status_marker=%s "
            "token_usage_attempted=%s in %.3fs",
            outcome,
            chat_id,
            message_id,
            turn_id,
            len(payload),
            is_status_marker,
            token_usage_attempted,
            time.perf_counter() - t0,
        )
|
||||||
|
|
@ -25,7 +25,6 @@ from uuid import UUID
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from sqlalchemy import func
|
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
|
@ -314,6 +313,19 @@ class StreamResult:
|
||||||
verification_succeeded: bool = False
|
verification_succeeded: bool = False
|
||||||
commit_gate_passed: bool = True
|
commit_gate_passed: bool = True
|
||||||
commit_gate_reason: str = ""
|
commit_gate_reason: str = ""
|
||||||
|
# Pre-allocated assistant ``new_chat_messages.id`` for this turn,
|
||||||
|
# captured by ``persist_assistant_shell`` right after the user row is
|
||||||
|
# persisted. ``None`` for the legacy / anonymous code paths that don't
|
||||||
|
# opt in to server-side ``ContentPart[]`` projection.
|
||||||
|
assistant_message_id: int | None = None
|
||||||
|
# In-memory mirror of the FE's assistant-ui ``ContentPartsState``,
|
||||||
|
# populated by the lifecycle methods called from ``_stream_agent_events``
|
||||||
|
# at each ``streaming_service.format_*`` yield site. Snapshot in the
|
||||||
|
# streaming ``finally`` to produce the rich JSONB persisted by
|
||||||
|
# ``finalize_assistant_turn``. ``repr=False`` keeps the
|
||||||
|
# log-on-error path (``StreamResult`` is logged in some error
|
||||||
|
# branches) from dumping a potentially-large parts list.
|
||||||
|
content_builder: Any | None = field(default=None, repr=False)
|
||||||
|
|
||||||
|
|
||||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||||
|
|
@ -721,6 +733,7 @@ async def _stream_agent_events(
|
||||||
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
fallback_commit_thread_id: int | None = None,
|
fallback_commit_thread_id: int | None = None,
|
||||||
runtime_context: Any = None,
|
runtime_context: Any = None,
|
||||||
|
content_builder: Any | None = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Shared async generator that streams and formats astream_events from the agent.
|
"""Shared async generator that streams and formats astream_events from the agent.
|
||||||
|
|
||||||
|
|
@ -737,6 +750,15 @@ async def _stream_agent_events(
|
||||||
initial_step_id: If set, the helper inherits an already-active thinking step.
|
initial_step_id: If set, the helper inherits an already-active thinking step.
|
||||||
initial_step_title: Title of the inherited thinking step.
|
initial_step_title: Title of the inherited thinking step.
|
||||||
initial_step_items: Items of the inherited thinking step.
|
initial_step_items: Items of the inherited thinking step.
|
||||||
|
content_builder: Optional ``AssistantContentBuilder``. When set, every
|
||||||
|
``streaming_service.format_*`` yield site also drives the matching
|
||||||
|
builder lifecycle method (``on_text_*``, ``on_reasoning_*``,
|
||||||
|
``on_tool_*``, ``on_thinking_step``, ``on_step_separator``) so the
|
||||||
|
in-memory ``ContentPart[]`` projection stays in lockstep with what
|
||||||
|
the FE renders live. Pure in-memory accumulation — no DB I/O —
|
||||||
|
consumed by the streaming ``finally`` to produce the rich JSONB
|
||||||
|
persisted via ``finalize_assistant_turn``. ``None`` (the default)
|
||||||
|
is used by the anonymous / legacy code paths and is a no-op.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
SSE-formatted strings for each event.
|
SSE-formatted strings for each event.
|
||||||
|
|
@ -801,12 +823,46 @@ async def _stream_agent_events(
|
||||||
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
|
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
|
||||||
|
|
||||||
def _emit_tool_output(call_id: str, output: Any) -> str:
|
def _emit_tool_output(call_id: str, output: Any) -> str:
|
||||||
|
# Drive the builder before formatting the SSE so the in-memory
|
||||||
|
# ContentPart[] mirror sees the result attached to the same
|
||||||
|
# card the FE will render. Builder method is a no-op when
|
||||||
|
# ``content_builder`` is None (anonymous / legacy paths).
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_tool_output_available(
|
||||||
|
call_id, output, current_lc_tool_call_id["value"]
|
||||||
|
)
|
||||||
return streaming_service.format_tool_output_available(
|
return streaming_service.format_tool_output_available(
|
||||||
call_id,
|
call_id,
|
||||||
output,
|
output,
|
||||||
langchain_tool_call_id=current_lc_tool_call_id["value"],
|
langchain_tool_call_id=current_lc_tool_call_id["value"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _emit_thinking_step(
|
||||||
|
*,
|
||||||
|
step_id: str,
|
||||||
|
title: str,
|
||||||
|
status: str = "in_progress",
|
||||||
|
items: list[str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Format a thinking-step SSE event and notify the builder.
|
||||||
|
|
||||||
|
Single helper used at every ``format_thinking_step`` yield site
|
||||||
|
in this generator. Drives ``AssistantContentBuilder.on_thinking_step``
|
||||||
|
first so the FE-mirror state lands the update before the SSE
|
||||||
|
carrying the same data leaves the wire — order matches the FE
|
||||||
|
pipeline (``processSharedStreamEvent`` updates state, then
|
||||||
|
flushes). Builder call is a no-op when ``content_builder`` is
|
||||||
|
None (anonymous / legacy paths).
|
||||||
|
"""
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_thinking_step(step_id, title, status, items)
|
||||||
|
return streaming_service.format_thinking_step(
|
||||||
|
step_id=step_id,
|
||||||
|
title=title,
|
||||||
|
status=status,
|
||||||
|
items=items,
|
||||||
|
)
|
||||||
|
|
||||||
def next_thinking_step_id() -> str:
|
def next_thinking_step_id() -> str:
|
||||||
nonlocal thinking_step_counter
|
nonlocal thinking_step_counter
|
||||||
thinking_step_counter += 1
|
thinking_step_counter += 1
|
||||||
|
|
@ -816,7 +872,7 @@ async def _stream_agent_events(
|
||||||
nonlocal last_active_step_id
|
nonlocal last_active_step_id
|
||||||
if last_active_step_id and last_active_step_id not in completed_step_ids:
|
if last_active_step_id and last_active_step_id not in completed_step_ids:
|
||||||
completed_step_ids.add(last_active_step_id)
|
completed_step_ids.add(last_active_step_id)
|
||||||
event = streaming_service.format_thinking_step(
|
event = _emit_thinking_step(
|
||||||
step_id=last_active_step_id,
|
step_id=last_active_step_id,
|
||||||
title=last_active_step_title,
|
title=last_active_step_title,
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -861,6 +917,8 @@ async def _stream_agent_events(
|
||||||
if parity_v2 and reasoning_delta:
|
if parity_v2 and reasoning_delta:
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(current_text_id)
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_end(current_text_id)
|
||||||
current_text_id = None
|
current_text_id = None
|
||||||
if current_reasoning_id is None:
|
if current_reasoning_id is None:
|
||||||
completion_event = complete_current_step()
|
completion_event = complete_current_step()
|
||||||
|
|
@ -873,13 +931,21 @@ async def _stream_agent_events(
|
||||||
just_finished_tool = False
|
just_finished_tool = False
|
||||||
current_reasoning_id = streaming_service.generate_reasoning_id()
|
current_reasoning_id = streaming_service.generate_reasoning_id()
|
||||||
yield streaming_service.format_reasoning_start(current_reasoning_id)
|
yield streaming_service.format_reasoning_start(current_reasoning_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_reasoning_start(current_reasoning_id)
|
||||||
yield streaming_service.format_reasoning_delta(
|
yield streaming_service.format_reasoning_delta(
|
||||||
current_reasoning_id, reasoning_delta
|
current_reasoning_id, reasoning_delta
|
||||||
)
|
)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_reasoning_delta(
|
||||||
|
current_reasoning_id, reasoning_delta
|
||||||
|
)
|
||||||
|
|
||||||
if text_delta:
|
if text_delta:
|
||||||
if current_reasoning_id is not None:
|
if current_reasoning_id is not None:
|
||||||
yield streaming_service.format_reasoning_end(current_reasoning_id)
|
yield streaming_service.format_reasoning_end(current_reasoning_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_reasoning_end(current_reasoning_id)
|
||||||
current_reasoning_id = None
|
current_reasoning_id = None
|
||||||
if current_text_id is None:
|
if current_text_id is None:
|
||||||
completion_event = complete_current_step()
|
completion_event = complete_current_step()
|
||||||
|
|
@ -892,8 +958,12 @@ async def _stream_agent_events(
|
||||||
just_finished_tool = False
|
just_finished_tool = False
|
||||||
current_text_id = streaming_service.generate_text_id()
|
current_text_id = streaming_service.generate_text_id()
|
||||||
yield streaming_service.format_text_start(current_text_id)
|
yield streaming_service.format_text_start(current_text_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_start(current_text_id)
|
||||||
yield streaming_service.format_text_delta(current_text_id, text_delta)
|
yield streaming_service.format_text_delta(current_text_id, text_delta)
|
||||||
accumulated_text += text_delta
|
accumulated_text += text_delta
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_delta(current_text_id, text_delta)
|
||||||
|
|
||||||
# Live tool-call argument streaming. Runs AFTER text/reasoning
|
# Live tool-call argument streaming. Runs AFTER text/reasoning
|
||||||
# processing so chunks containing both stay in their natural
|
# processing so chunks containing both stay in their natural
|
||||||
|
|
@ -925,11 +995,17 @@ async def _stream_agent_events(
|
||||||
# within the same stream window.
|
# within the same stream window.
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(current_text_id)
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_end(current_text_id)
|
||||||
current_text_id = None
|
current_text_id = None
|
||||||
if current_reasoning_id is not None:
|
if current_reasoning_id is not None:
|
||||||
yield streaming_service.format_reasoning_end(
|
yield streaming_service.format_reasoning_end(
|
||||||
current_reasoning_id
|
current_reasoning_id
|
||||||
)
|
)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_reasoning_end(
|
||||||
|
current_reasoning_id
|
||||||
|
)
|
||||||
current_reasoning_id = None
|
current_reasoning_id = None
|
||||||
|
|
||||||
index_to_meta[idx] = {
|
index_to_meta[idx] = {
|
||||||
|
|
@ -942,6 +1018,8 @@ async def _stream_agent_events(
|
||||||
name,
|
name,
|
||||||
langchain_tool_call_id=lc_id,
|
langchain_tool_call_id=lc_id,
|
||||||
)
|
)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_tool_input_start(ui_id, name, lc_id)
|
||||||
|
|
||||||
# Emit args delta for any chunk at a registered
|
# Emit args delta for any chunk at a registered
|
||||||
# index (including idless continuations). Once an
|
# index (including idless continuations). Once an
|
||||||
|
|
@ -957,6 +1035,10 @@ async def _stream_agent_events(
|
||||||
yield streaming_service.format_tool_input_delta(
|
yield streaming_service.format_tool_input_delta(
|
||||||
meta["ui_id"], args_chunk
|
meta["ui_id"], args_chunk
|
||||||
)
|
)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_tool_input_delta(
|
||||||
|
meta["ui_id"], args_chunk
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
pending_tool_call_chunks.append(tcc)
|
pending_tool_call_chunks.append(tcc)
|
||||||
|
|
||||||
|
|
@ -974,6 +1056,8 @@ async def _stream_agent_events(
|
||||||
|
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(current_text_id)
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_end(current_text_id)
|
||||||
current_text_id = None
|
current_text_id = None
|
||||||
|
|
||||||
if last_active_step_title != "Synthesizing response":
|
if last_active_step_title != "Synthesizing response":
|
||||||
|
|
@ -994,7 +1078,7 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
last_active_step_title = "Listing files"
|
last_active_step_title = "Listing files"
|
||||||
last_active_step_items = [ls_path]
|
last_active_step_items = [ls_path]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Listing files",
|
title="Listing files",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1009,7 +1093,7 @@ async def _stream_agent_events(
|
||||||
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||||
last_active_step_title = "Reading file"
|
last_active_step_title = "Reading file"
|
||||||
last_active_step_items = [display_fp]
|
last_active_step_items = [display_fp]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Reading file",
|
title="Reading file",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1024,7 +1108,7 @@ async def _stream_agent_events(
|
||||||
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||||
last_active_step_title = "Writing file"
|
last_active_step_title = "Writing file"
|
||||||
last_active_step_items = [display_fp]
|
last_active_step_items = [display_fp]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Writing file",
|
title="Writing file",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1039,7 +1123,7 @@ async def _stream_agent_events(
|
||||||
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||||
last_active_step_title = "Editing file"
|
last_active_step_title = "Editing file"
|
||||||
last_active_step_items = [display_fp]
|
last_active_step_items = [display_fp]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Editing file",
|
title="Editing file",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1056,7 +1140,7 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
last_active_step_title = "Searching files"
|
last_active_step_title = "Searching files"
|
||||||
last_active_step_items = [f"{pat} in {base_path}"]
|
last_active_step_items = [f"{pat} in {base_path}"]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Searching files",
|
title="Searching files",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1076,7 +1160,7 @@ async def _stream_agent_events(
|
||||||
last_active_step_items = [
|
last_active_step_items = [
|
||||||
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
|
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
|
||||||
]
|
]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Searching content",
|
title="Searching content",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1091,7 +1175,7 @@ async def _stream_agent_events(
|
||||||
display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:]
|
display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:]
|
||||||
last_active_step_title = "Deleting file"
|
last_active_step_title = "Deleting file"
|
||||||
last_active_step_items = [display_path] if display_path else []
|
last_active_step_items = [display_path] if display_path else []
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Deleting file",
|
title="Deleting file",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1108,7 +1192,7 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
last_active_step_title = "Deleting folder"
|
last_active_step_title = "Deleting folder"
|
||||||
last_active_step_items = [display_path] if display_path else []
|
last_active_step_items = [display_path] if display_path else []
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Deleting folder",
|
title="Deleting folder",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1125,7 +1209,7 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
last_active_step_title = "Creating folder"
|
last_active_step_title = "Creating folder"
|
||||||
last_active_step_items = [display_path] if display_path else []
|
last_active_step_items = [display_path] if display_path else []
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Creating folder",
|
title="Creating folder",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1148,7 +1232,7 @@ async def _stream_agent_events(
|
||||||
last_active_step_items = (
|
last_active_step_items = (
|
||||||
[f"{display_src} → {display_dst}"] if src or dst else []
|
[f"{display_src} → {display_dst}"] if src or dst else []
|
||||||
)
|
)
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Moving file",
|
title="Moving file",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1165,7 +1249,7 @@ async def _stream_agent_events(
|
||||||
if todo_count
|
if todo_count
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Planning tasks",
|
title="Planning tasks",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1180,7 +1264,7 @@ async def _stream_agent_events(
|
||||||
display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "")
|
display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "")
|
||||||
last_active_step_title = "Saving document"
|
last_active_step_title = "Saving document"
|
||||||
last_active_step_items = [display_title]
|
last_active_step_items = [display_title]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Saving document",
|
title="Saving document",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1196,7 +1280,7 @@ async def _stream_agent_events(
|
||||||
last_active_step_items = [
|
last_active_step_items = [
|
||||||
f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"
|
f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}"
|
||||||
]
|
]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Generating image",
|
title="Generating image",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1212,7 +1296,7 @@ async def _stream_agent_events(
|
||||||
last_active_step_items = [
|
last_active_step_items = [
|
||||||
f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
|
f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
|
||||||
]
|
]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Scraping webpage",
|
title="Scraping webpage",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1235,7 +1319,7 @@ async def _stream_agent_events(
|
||||||
f"Content: {content_len:,} characters",
|
f"Content: {content_len:,} characters",
|
||||||
"Preparing audio generation...",
|
"Preparing audio generation...",
|
||||||
]
|
]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Generating podcast",
|
title="Generating podcast",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1256,7 +1340,7 @@ async def _stream_agent_events(
|
||||||
f"Topic: {report_topic}",
|
f"Topic: {report_topic}",
|
||||||
"Analyzing source content...",
|
"Analyzing source content...",
|
||||||
]
|
]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title=step_title,
|
title=step_title,
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1271,7 +1355,7 @@ async def _stream_agent_events(
|
||||||
display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
|
display_cmd = cmd[:80] + ("…" if len(cmd) > 80 else "")
|
||||||
last_active_step_title = "Running command"
|
last_active_step_title = "Running command"
|
||||||
last_active_step_items = [f"$ {display_cmd}"]
|
last_active_step_items = [f"$ {display_cmd}"]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Running command",
|
title="Running command",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1288,7 +1372,7 @@ async def _stream_agent_events(
|
||||||
tool_name.replace("_", " ").strip().capitalize() or tool_name
|
tool_name.replace("_", " ").strip().capitalize() or tool_name
|
||||||
)
|
)
|
||||||
last_active_step_items = []
|
last_active_step_items = []
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title=last_active_step_title,
|
title=last_active_step_title,
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -1349,6 +1433,10 @@ async def _stream_agent_events(
|
||||||
tool_name,
|
tool_name,
|
||||||
langchain_tool_call_id=langchain_tool_call_id,
|
langchain_tool_call_id=langchain_tool_call_id,
|
||||||
)
|
)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_tool_input_start(
|
||||||
|
tool_call_id, tool_name, langchain_tool_call_id
|
||||||
|
)
|
||||||
|
|
||||||
if run_id:
|
if run_id:
|
||||||
ui_tool_call_id_by_run[run_id] = tool_call_id
|
ui_tool_call_id_by_run[run_id] = tool_call_id
|
||||||
|
|
@ -1371,6 +1459,13 @@ async def _stream_agent_events(
|
||||||
_safe_input,
|
_safe_input,
|
||||||
langchain_tool_call_id=langchain_tool_call_id,
|
langchain_tool_call_id=langchain_tool_call_id,
|
||||||
)
|
)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_tool_input_available(
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
_safe_input,
|
||||||
|
langchain_tool_call_id,
|
||||||
|
)
|
||||||
|
|
||||||
elif event_type == "on_tool_end":
|
elif event_type == "on_tool_end":
|
||||||
active_tool_depth = max(0, active_tool_depth - 1)
|
active_tool_depth = max(0, active_tool_depth - 1)
|
||||||
|
|
@ -1443,70 +1538,70 @@ async def _stream_agent_events(
|
||||||
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
||||||
|
|
||||||
if tool_name == "read_file":
|
if tool_name == "read_file":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Reading file",
|
title="Reading file",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "write_file":
|
elif tool_name == "write_file":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Writing file",
|
title="Writing file",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "edit_file":
|
elif tool_name == "edit_file":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Editing file",
|
title="Editing file",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "glob":
|
elif tool_name == "glob":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Searching files",
|
title="Searching files",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "grep":
|
elif tool_name == "grep":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Searching content",
|
title="Searching content",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "rm":
|
elif tool_name == "rm":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Deleting file",
|
title="Deleting file",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "rmdir":
|
elif tool_name == "rmdir":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Deleting folder",
|
title="Deleting folder",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "mkdir":
|
elif tool_name == "mkdir":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Creating folder",
|
title="Creating folder",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "move_file":
|
elif tool_name == "move_file":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Moving file",
|
title="Moving file",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
elif tool_name == "write_todos":
|
elif tool_name == "write_todos":
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Planning tasks",
|
title="Planning tasks",
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1523,7 +1618,7 @@ async def _stream_agent_events(
|
||||||
*last_active_step_items,
|
*last_active_step_items,
|
||||||
result_str[:80] if is_error else "Saved to knowledge base",
|
result_str[:80] if is_error else "Saved to knowledge base",
|
||||||
]
|
]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Saving document",
|
title="Saving document",
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1542,7 +1637,7 @@ async def _stream_agent_events(
|
||||||
else "Generation failed"
|
else "Generation failed"
|
||||||
)
|
)
|
||||||
completed_items = [*last_active_step_items, f"Error: {error_msg}"]
|
completed_items = [*last_active_step_items, f"Error: {error_msg}"]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Generating image",
|
title="Generating image",
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1566,7 +1661,7 @@ async def _stream_agent_events(
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
completed_items = [*last_active_step_items, "Content extracted"]
|
completed_items = [*last_active_step_items, "Content extracted"]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Scraping webpage",
|
title="Scraping webpage",
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1612,7 +1707,7 @@ async def _stream_agent_events(
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
completed_items = last_active_step_items
|
completed_items = last_active_step_items
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Generating podcast",
|
title="Generating podcast",
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1647,7 +1742,7 @@ async def _stream_agent_events(
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
completed_items = last_active_step_items
|
completed_items = last_active_step_items
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Generating video presentation",
|
title="Generating video presentation",
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1695,7 +1790,7 @@ async def _stream_agent_events(
|
||||||
else:
|
else:
|
||||||
completed_items = last_active_step_items
|
completed_items = last_active_step_items
|
||||||
|
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title=step_title,
|
title=step_title,
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1721,7 +1816,7 @@ async def _stream_agent_events(
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
completed_items = [*last_active_step_items, "Finished"]
|
completed_items = [*last_active_step_items, "Finished"]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Running command",
|
title="Running command",
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1761,7 +1856,7 @@ async def _stream_agent_events(
|
||||||
completed_items.append(f"(+{len(file_names) - 4} more)")
|
completed_items.append(f"(+{len(file_names) - 4} more)")
|
||||||
else:
|
else:
|
||||||
completed_items = ["No files found"]
|
completed_items = ["No files found"]
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Listing files",
|
title="Listing files",
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -1773,7 +1868,7 @@ async def _stream_agent_events(
|
||||||
fallback_title = (
|
fallback_title = (
|
||||||
tool_name.replace("_", " ").strip().capitalize() or tool_name
|
tool_name.replace("_", " ").strip().capitalize() or tool_name
|
||||||
)
|
)
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title=fallback_title,
|
title=fallback_title,
|
||||||
status="completed",
|
status="completed",
|
||||||
|
|
@ -2113,7 +2208,7 @@ async def _stream_agent_events(
|
||||||
# Phase transitions: replace everything after topic
|
# Phase transitions: replace everything after topic
|
||||||
last_active_step_items = [*topic_items, message]
|
last_active_step_items = [*topic_items, message]
|
||||||
|
|
||||||
yield streaming_service.format_thinking_step(
|
yield _emit_thinking_step(
|
||||||
step_id=last_active_step_id,
|
step_id=last_active_step_id,
|
||||||
title=last_active_step_title,
|
title=last_active_step_title,
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
|
|
@ -2155,10 +2250,14 @@ async def _stream_agent_events(
|
||||||
elif event_type in ("on_chain_end", "on_agent_end"):
|
elif event_type in ("on_chain_end", "on_agent_end"):
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(current_text_id)
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_end(current_text_id)
|
||||||
current_text_id = None
|
current_text_id = None
|
||||||
|
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(current_text_id)
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_end(current_text_id)
|
||||||
|
|
||||||
completion_event = complete_current_step()
|
completion_event = complete_current_step()
|
||||||
if completion_event:
|
if completion_event:
|
||||||
|
|
@ -2243,8 +2342,14 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
gate_text_id = streaming_service.generate_text_id()
|
gate_text_id = streaming_service.generate_text_id()
|
||||||
yield streaming_service.format_text_start(gate_text_id)
|
yield streaming_service.format_text_start(gate_text_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_start(gate_text_id)
|
||||||
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
|
yield streaming_service.format_text_delta(gate_text_id, gate_notice)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_delta(gate_text_id, gate_notice)
|
||||||
yield streaming_service.format_text_end(gate_text_id)
|
yield streaming_service.format_text_end(gate_text_id)
|
||||||
|
if content_builder is not None:
|
||||||
|
content_builder.on_text_end(gate_text_id)
|
||||||
yield streaming_service.format_terminal_info(gate_notice, "error")
|
yield streaming_service.format_terminal_info(gate_notice, "error")
|
||||||
accumulated_text = gate_notice
|
accumulated_text = gate_notice
|
||||||
else:
|
else:
|
||||||
|
|
@ -2270,6 +2375,7 @@ async def stream_new_chat(
|
||||||
llm_config_id: int = -1,
|
llm_config_id: int = -1,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||||
|
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||||
checkpoint_id: str | None = None,
|
checkpoint_id: str | None = None,
|
||||||
needs_history_bootstrap: bool = False,
|
needs_history_bootstrap: bool = False,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
|
|
@ -2949,6 +3055,96 @@ async def stream_new_chat(
|
||||||
)
|
)
|
||||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
|
# Persist the user-side row for this turn before any expensive
|
||||||
|
# work runs. Closes the "ghost-thread" abuse vector
|
||||||
|
# (authenticated client hits POST /new_chat then never calls
|
||||||
|
# /messages — empty new_chat_messages, free LLM completion).
|
||||||
|
# Idempotent against the unique index in migration 141 so the
|
||||||
|
# legacy frontend appendMessage call is a no-op on the second
|
||||||
|
# writer. Hard failure aborts the turn so we never produce a
|
||||||
|
# title or assistant row that isn't anchored to a persisted
|
||||||
|
# user message.
|
||||||
|
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||||
|
from app.tasks.chat.persistence import (
|
||||||
|
persist_assistant_shell,
|
||||||
|
persist_user_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_message_id = await persist_user_turn(
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_id=user_id,
|
||||||
|
turn_id=stream_result.turn_id,
|
||||||
|
user_query=user_query,
|
||||||
|
user_image_data_urls=user_image_data_urls,
|
||||||
|
mentioned_documents=mentioned_documents,
|
||||||
|
)
|
||||||
|
if user_message_id is None:
|
||||||
|
yield _emit_stream_error(
|
||||||
|
message=(
|
||||||
|
"We couldn't save your message. Please try again in a moment."
|
||||||
|
),
|
||||||
|
error_kind="server_error",
|
||||||
|
error_code="MESSAGE_PERSIST_FAILED",
|
||||||
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
|
yield streaming_service.format_finish_step()
|
||||||
|
yield streaming_service.format_finish()
|
||||||
|
yield streaming_service.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Emit canonical user message id BEFORE any LLM streaming so the
|
||||||
|
# FE can rename its optimistic ``msg-user-XXX`` placeholder to
|
||||||
|
# ``msg-{user_message_id}`` and unlock features gated on a real
|
||||||
|
# DB id (comments, edit-from-this-message). See B4 in
|
||||||
|
# ``sse-based_message_id_handshake`` plan.
|
||||||
|
yield streaming_service.format_data(
|
||||||
|
"user-message-id",
|
||||||
|
{"message_id": user_message_id, "turn_id": stream_result.turn_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pre-write the assistant row for this turn so we have a stable
|
||||||
|
# ``message_id`` to anchor mid-stream metadata (token_usage,
|
||||||
|
# future agent_action_log.message_id correlation) and a
|
||||||
|
# write-once UPDATE target at finalize time. Idempotent against
|
||||||
|
# the (thread_id, turn_id, ASSISTANT) partial unique index from
|
||||||
|
# migration 141 — if the legacy frontend appendMessage races
|
||||||
|
# this, we recover the existing row's id.
|
||||||
|
assistant_message_id = await persist_assistant_shell(
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_id=user_id,
|
||||||
|
turn_id=stream_result.turn_id,
|
||||||
|
)
|
||||||
|
if assistant_message_id is None:
|
||||||
|
# Genuine DB failure — abort the turn rather than stream
|
||||||
|
# into a void. The user row is already persisted so the
|
||||||
|
# legacy "ghost-thread" gate isn't reopened.
|
||||||
|
yield _emit_stream_error(
|
||||||
|
message=(
|
||||||
|
"We couldn't initialize the assistant message. Please try again."
|
||||||
|
),
|
||||||
|
error_kind="server_error",
|
||||||
|
error_code="MESSAGE_PERSIST_FAILED",
|
||||||
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
|
yield streaming_service.format_finish_step()
|
||||||
|
yield streaming_service.format_finish()
|
||||||
|
yield streaming_service.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Emit canonical assistant message id BEFORE any LLM streaming
|
||||||
|
# so the FE can rename its optimistic ``msg-assistant-XXX``
|
||||||
|
# placeholder to ``msg-{assistant_message_id}`` and bind
|
||||||
|
# ``tokenUsageStore`` / ``pendingInterrupt`` to the real id
|
||||||
|
# immediately. See B4 in ``sse-based_message_id_handshake``
|
||||||
|
# plan.
|
||||||
|
yield streaming_service.format_data(
|
||||||
|
"assistant-message-id",
|
||||||
|
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
stream_result.assistant_message_id = assistant_message_id
|
||||||
|
stream_result.content_builder = AssistantContentBuilder()
|
||||||
|
|
||||||
# Initial thinking step - analyzing the request
|
# Initial thinking step - analyzing the request
|
||||||
if mentioned_surfsense_docs:
|
if mentioned_surfsense_docs:
|
||||||
initial_title = "Analyzing referenced content"
|
initial_title = "Analyzing referenced content"
|
||||||
|
|
@ -2981,6 +3177,15 @@ async def stream_new_chat(
|
||||||
initial_items = [f"{action_verb}: {' '.join(processing_parts)}"]
|
initial_items = [f"{action_verb}: {' '.join(processing_parts)}"]
|
||||||
initial_step_id = "thinking-1"
|
initial_step_id = "thinking-1"
|
||||||
|
|
||||||
|
# Drive the builder for this initial thinking step too — the
|
||||||
|
# ``_emit_thinking_step`` helper lives inside ``_stream_agent_events``
|
||||||
|
# so it isn't in scope here, but the FE folds this step into
|
||||||
|
# the same singleton ``data-thinking-steps`` part as everything
|
||||||
|
# the agent stream emits later. Mirror that fold server-side.
|
||||||
|
if stream_result.content_builder is not None:
|
||||||
|
stream_result.content_builder.on_thinking_step(
|
||||||
|
initial_step_id, initial_title, "in_progress", initial_items
|
||||||
|
)
|
||||||
yield streaming_service.format_thinking_step(
|
yield streaming_service.format_thinking_step(
|
||||||
step_id=initial_step_id,
|
step_id=initial_step_id,
|
||||||
title=initial_title,
|
title=initial_title,
|
||||||
|
|
@ -2997,16 +3202,34 @@ async def stream_new_chat(
|
||||||
# Check if this is the first assistant response so we can generate
|
# Check if this is the first assistant response so we can generate
|
||||||
# a title in parallel with the agent stream (better UX than waiting
|
# a title in parallel with the agent stream (better UX than waiting
|
||||||
# until after the full response).
|
# until after the full response).
|
||||||
assistant_count_result = await session.execute(
|
# Use a LIMIT 1 EXISTS-style probe rather than COUNT(*) because
|
||||||
select(func.count(NewChatMessage.id)).filter(
|
# this is now a hot path executed on every turn, and COUNT scales
|
||||||
|
# with thread length (server-side persistence can grow rows
|
||||||
|
# quickly under power users).
|
||||||
|
#
|
||||||
|
# IMPORTANT: ``persist_assistant_shell`` above (line ~3112) already
|
||||||
|
# inserted THIS turn's assistant row. We must therefore exclude
|
||||||
|
# it from the probe — otherwise the gate fires on every turn
|
||||||
|
# except the very first, and title generation never runs for new
|
||||||
|
# threads. Excluding by primary key (``id != assistant_message_id``)
|
||||||
|
# is bulletproof regardless of ``turn_id`` shape (legacy NULLs,
|
||||||
|
# resume turns, etc.).
|
||||||
|
first_assistant_probe = await session.execute(
|
||||||
|
select(NewChatMessage.id)
|
||||||
|
.filter(
|
||||||
NewChatMessage.thread_id == chat_id,
|
NewChatMessage.thread_id == chat_id,
|
||||||
NewChatMessage.role == "assistant",
|
NewChatMessage.role == "assistant",
|
||||||
|
NewChatMessage.id != assistant_message_id,
|
||||||
)
|
)
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
is_first_response = (assistant_count_result.scalar() or 0) == 0
|
is_first_response = first_assistant_probe.scalars().first() is None
|
||||||
|
|
||||||
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
|
title_task: asyncio.Task[tuple[str | None, dict | None]] | None = None
|
||||||
if is_first_response:
|
# Gate title generation on a persisted user message so a stream
|
||||||
|
# that fails before persistence (we abort above) can never leave
|
||||||
|
# behind a thread with a generated title and no anchoring rows.
|
||||||
|
if is_first_response and user_message_id is not None:
|
||||||
|
|
||||||
async def _generate_title() -> tuple[str | None, dict | None]:
|
async def _generate_title() -> tuple[str | None, dict | None]:
|
||||||
"""Generate a short title via litellm.acompletion.
|
"""Generate a short title via litellm.acompletion.
|
||||||
|
|
@ -3138,6 +3361,7 @@ async def stream_new_chat(
|
||||||
),
|
),
|
||||||
fallback_commit_thread_id=chat_id,
|
fallback_commit_thread_id=chat_id,
|
||||||
runtime_context=runtime_context,
|
runtime_context=runtime_context,
|
||||||
|
content_builder=stream_result.content_builder,
|
||||||
):
|
):
|
||||||
if not _first_event_logged:
|
if not _first_event_logged:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
@ -3493,6 +3717,81 @@ async def stream_new_chat(
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
# Server-side assistant-message + token_usage finalization.
|
||||||
|
# Runs after the main session has been closed (uses its own
|
||||||
|
# shielded session) so we don't fight the same DB connection.
|
||||||
|
# Idempotent against the legacy frontend appendMessage:
|
||||||
|
# * the assistant row was already INSERTed by
|
||||||
|
# ``persist_assistant_shell`` above, so this just UPDATEs
|
||||||
|
# it with the rich ContentPart[] from the builder.
|
||||||
|
# * token_usage uses INSERT ... ON CONFLICT DO NOTHING
|
||||||
|
# against migration 142's partial unique index, so a
|
||||||
|
# racing append_message recovery branch can never
|
||||||
|
# double-write.
|
||||||
|
# ``mark_interrupted`` closes any open text/reasoning blocks
|
||||||
|
# and flips running tool-calls (no result) to state=aborted
|
||||||
|
# so the persisted JSONB reflects a coherent end-state even
|
||||||
|
# on client disconnect.
|
||||||
|
# Never raises (best-effort, logs only).
|
||||||
|
if (
|
||||||
|
stream_result
|
||||||
|
and stream_result.turn_id
|
||||||
|
and stream_result.assistant_message_id
|
||||||
|
):
|
||||||
|
from app.tasks.chat.persistence import finalize_assistant_turn
|
||||||
|
|
||||||
|
builder_stats: dict[str, int] | None = None
|
||||||
|
if stream_result.content_builder is not None:
|
||||||
|
stream_result.content_builder.mark_interrupted()
|
||||||
|
# Snapshot stats BEFORE deepcopy in ``snapshot()`` so
|
||||||
|
# the perf log records the actual finalised payload
|
||||||
|
# (post-mark_interrupted), not the live-mutating
|
||||||
|
# builder state.
|
||||||
|
builder_stats = stream_result.content_builder.stats()
|
||||||
|
content_payload = stream_result.content_builder.snapshot()
|
||||||
|
else:
|
||||||
|
# Defensive fallback — we always set the builder
|
||||||
|
# alongside ``assistant_message_id`` above, so this
|
||||||
|
# branch only fires if a future refactor ever
|
||||||
|
# decouples them. Persist whatever accumulated
|
||||||
|
# text we captured so the row at least renders.
|
||||||
|
content_payload = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": stream_result.accumulated_text or "",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
if builder_stats is not None:
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] finalize_payload chat_id=%s "
|
||||||
|
"message_id=%s parts=%d bytes=%d text=%d "
|
||||||
|
"reasoning=%d tool_calls=%d "
|
||||||
|
"tool_calls_completed=%d tool_calls_aborted=%d "
|
||||||
|
"thinking_step_parts=%d step_separators=%d",
|
||||||
|
chat_id,
|
||||||
|
stream_result.assistant_message_id,
|
||||||
|
builder_stats["parts"],
|
||||||
|
builder_stats["bytes"],
|
||||||
|
builder_stats["text"],
|
||||||
|
builder_stats["reasoning"],
|
||||||
|
builder_stats["tool_calls"],
|
||||||
|
builder_stats["tool_calls_completed"],
|
||||||
|
builder_stats["tool_calls_aborted"],
|
||||||
|
builder_stats["thinking_step_parts"],
|
||||||
|
builder_stats["step_separators"],
|
||||||
|
)
|
||||||
|
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=stream_result.assistant_message_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
turn_id=stream_result.turn_id,
|
||||||
|
content=content_payload,
|
||||||
|
accumulator=accumulator,
|
||||||
|
)
|
||||||
|
|
||||||
# Persist any sandbox-produced files to local storage so they
|
# Persist any sandbox-produced files to local storage so they
|
||||||
# remain downloadable after the Daytona sandbox auto-deletes.
|
# remain downloadable after the Daytona sandbox auto-deletes.
|
||||||
if stream_result and stream_result.sandbox_files:
|
if stream_result and stream_result.sandbox_files:
|
||||||
|
|
@ -3937,6 +4236,50 @@ async def stream_resume_chat(
|
||||||
)
|
)
|
||||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
|
# Pre-write a fresh assistant row for this resume turn. The
|
||||||
|
# original (interrupted) ``stream_new_chat`` invocation already
|
||||||
|
# persisted its own assistant row anchored to a different
|
||||||
|
# ``turn_id``; resume allocates a new ``turn_id`` (above) so we
|
||||||
|
# need a separate row keyed on the same ``(thread_id, turn_id,
|
||||||
|
# ASSISTANT)`` invariant. Idempotent against migration 141's
|
||||||
|
# partial unique index — recovers existing id on retry.
|
||||||
|
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||||
|
from app.tasks.chat.persistence import persist_assistant_shell
|
||||||
|
|
||||||
|
assistant_message_id = await persist_assistant_shell(
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_id=user_id,
|
||||||
|
turn_id=stream_result.turn_id,
|
||||||
|
)
|
||||||
|
if assistant_message_id is None:
|
||||||
|
yield _emit_stream_error(
|
||||||
|
message=(
|
||||||
|
"We couldn't initialize the assistant message. Please try again."
|
||||||
|
),
|
||||||
|
error_kind="server_error",
|
||||||
|
error_code="MESSAGE_PERSIST_FAILED",
|
||||||
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
|
yield streaming_service.format_finish_step()
|
||||||
|
yield streaming_service.format_finish()
|
||||||
|
yield streaming_service.format_done()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Emit canonical assistant message id BEFORE any LLM streaming
|
||||||
|
# so the FE can rename ``pendingInterrupt.assistantMsgId`` to
|
||||||
|
# ``msg-{assistant_message_id}`` immediately. Resume does NOT
|
||||||
|
# emit ``data-user-message-id`` because the user row is from
|
||||||
|
# the original interrupted turn (different ``turn_id``) and is
|
||||||
|
# never re-persisted here. See B5 in the
|
||||||
|
# ``sse-based_message_id_handshake`` plan.
|
||||||
|
yield streaming_service.format_data(
|
||||||
|
"assistant-message-id",
|
||||||
|
{"message_id": assistant_message_id, "turn_id": stream_result.turn_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
stream_result.assistant_message_id = assistant_message_id
|
||||||
|
stream_result.content_builder = AssistantContentBuilder()
|
||||||
|
|
||||||
# Resume path doesn't carry new ``mentioned_document_ids`` —
|
# Resume path doesn't carry new ``mentioned_document_ids`` —
|
||||||
# those are seeded in the original turn. We still pass a
|
# those are seeded in the original turn. We still pass a
|
||||||
# context so future middleware extensions (Phase 2) can rely on
|
# context so future middleware extensions (Phase 2) can rely on
|
||||||
|
|
@ -3968,6 +4311,7 @@ async def stream_resume_chat(
|
||||||
),
|
),
|
||||||
fallback_commit_thread_id=chat_id,
|
fallback_commit_thread_id=chat_id,
|
||||||
runtime_context=runtime_context,
|
runtime_context=runtime_context,
|
||||||
|
content_builder=stream_result.content_builder,
|
||||||
):
|
):
|
||||||
if not _first_event_logged:
|
if not _first_event_logged:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
@ -4219,6 +4563,64 @@ async def stream_resume_chat(
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
# Server-side assistant-message + token_usage finalization for
|
||||||
|
# the resume flow. The original user message was persisted by
|
||||||
|
# the original (interrupted) ``stream_new_chat`` invocation;
|
||||||
|
# the resume's own ``persist_assistant_shell`` write lives at
|
||||||
|
# the new ``turn_id`` above. This finalize updates that row
|
||||||
|
# with the rich ContentPart[] from the builder and writes
|
||||||
|
# token_usage idempotently via migration 142's partial
|
||||||
|
# unique index. Best-effort, never raises.
|
||||||
|
if (
|
||||||
|
stream_result
|
||||||
|
and stream_result.turn_id
|
||||||
|
and stream_result.assistant_message_id
|
||||||
|
):
|
||||||
|
from app.tasks.chat.persistence import finalize_assistant_turn
|
||||||
|
|
||||||
|
builder_stats: dict[str, int] | None = None
|
||||||
|
if stream_result.content_builder is not None:
|
||||||
|
stream_result.content_builder.mark_interrupted()
|
||||||
|
builder_stats = stream_result.content_builder.stats()
|
||||||
|
content_payload = stream_result.content_builder.snapshot()
|
||||||
|
else:
|
||||||
|
content_payload = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": stream_result.accumulated_text or "",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
if builder_stats is not None:
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_resume] finalize_payload chat_id=%s "
|
||||||
|
"message_id=%s parts=%d bytes=%d text=%d "
|
||||||
|
"reasoning=%d tool_calls=%d "
|
||||||
|
"tool_calls_completed=%d tool_calls_aborted=%d "
|
||||||
|
"thinking_step_parts=%d step_separators=%d",
|
||||||
|
chat_id,
|
||||||
|
stream_result.assistant_message_id,
|
||||||
|
builder_stats["parts"],
|
||||||
|
builder_stats["bytes"],
|
||||||
|
builder_stats["text"],
|
||||||
|
builder_stats["reasoning"],
|
||||||
|
builder_stats["tool_calls"],
|
||||||
|
builder_stats["tool_calls_completed"],
|
||||||
|
builder_stats["tool_calls_aborted"],
|
||||||
|
builder_stats["thinking_step_parts"],
|
||||||
|
builder_stats["step_separators"],
|
||||||
|
)
|
||||||
|
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=stream_result.assistant_message_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
turn_id=stream_result.turn_id,
|
||||||
|
content=content_payload,
|
||||||
|
accumulator=accumulator,
|
||||||
|
)
|
||||||
|
|
||||||
agent = llm = connector_service = None
|
agent = llm = connector_service = None
|
||||||
stream_result = None
|
stream_result = None
|
||||||
session = None
|
session = None
|
||||||
|
|
|
||||||
0
surfsense_backend/tests/integration/chat/__init__.py
Normal file
0
surfsense_backend/tests/integration/chat/__init__.py
Normal file
|
|
@ -0,0 +1,573 @@
|
||||||
|
"""Integration tests for the cross-writer integration between the
|
||||||
|
streaming chat task and the legacy ``POST /threads/{id}/messages``
|
||||||
|
(``append_message``) round-trip.
|
||||||
|
|
||||||
|
Two scenarios anchor the contract introduced by the server-side
|
||||||
|
persistence rework:
|
||||||
|
|
||||||
|
(a) **Tool-heavy turn streamed to completion.**
|
||||||
|
|
||||||
|
Drives :class:`AssistantContentBuilder` with synthetic SSE events
|
||||||
|
that mirror what ``_stream_agent_events`` emits for a turn that
|
||||||
|
interleaves text, reasoning, a tool call (start/delta/available/
|
||||||
|
output), and a final text block. Then runs
|
||||||
|
:func:`finalize_assistant_turn` and asserts:
|
||||||
|
|
||||||
|
* ``new_chat_messages.content`` JSONB matches the
|
||||||
|
``ContentPart[]`` shape the FE history loader expects, with full
|
||||||
|
``args``/``argsText``/``result``/``langchainToolCallId`` for the
|
||||||
|
tool call.
|
||||||
|
* Exactly one ``token_usage`` row exists keyed on the assistant
|
||||||
|
``message_id``.
|
||||||
|
|
||||||
|
(b) **Stale FE ``appendMessage`` after server finalize.**
|
||||||
|
|
||||||
|
Verifies the recovery branch of the ``append_message`` route now
|
||||||
|
returns the SERVER's authoritative ``ContentPart[]`` (not the FE's
|
||||||
|
stale payload) when the partial unique index from migration 141
|
||||||
|
blocks the FE's INSERT, and that the ``ON CONFLICT DO NOTHING``
|
||||||
|
clause from migration 142 stops the route from writing a duplicate
|
||||||
|
``token_usage`` row.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import (
|
||||||
|
ChatVisibility,
|
||||||
|
NewChatMessage,
|
||||||
|
NewChatMessageRole,
|
||||||
|
NewChatThread,
|
||||||
|
SearchSpace,
|
||||||
|
TokenUsage,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
from app.routes import new_chat_routes
|
||||||
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||||
|
from app.tasks.chat import persistence as persistence_module
|
||||||
|
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||||
|
from app.tasks.chat.persistence import (
|
||||||
|
finalize_assistant_turn,
|
||||||
|
persist_assistant_shell,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_thread(
|
||||||
|
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||||
|
) -> NewChatThread:
|
||||||
|
thread = NewChatThread(
|
||||||
|
title="Test Chat",
|
||||||
|
search_space_id=db_search_space.id,
|
||||||
|
created_by_id=db_user.id,
|
||||||
|
visibility=ChatVisibility.PRIVATE,
|
||||||
|
)
|
||||||
|
db_session.add(thread)
|
||||||
|
await db_session.flush()
|
||||||
|
return thread
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
|
||||||
|
"""Route persistence helpers to the test's savepoint-bound session.
|
||||||
|
|
||||||
|
Mirrors the helper from ``test_persistence.py`` so the helpers'
|
||||||
|
internal ``ws.commit()`` / ``ws.rollback()`` resolve to SAVEPOINT
|
||||||
|
operations on the test transaction instead of touching real
|
||||||
|
autocommit boundaries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_shielded_session():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
persistence_module,
|
||||||
|
"shielded_async_session",
|
||||||
|
_fake_shielded_session,
|
||||||
|
)
|
||||||
|
return db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def bypass_permission_checks(monkeypatch):
|
||||||
|
"""Replace RBAC + thread access checks with no-ops.
|
||||||
|
|
||||||
|
The append_message route under test calls ``check_permission`` and
|
||||||
|
``check_thread_access``; those rely on a SearchSpaceMembership row
|
||||||
|
that the existing integration fixtures don't create. The contract
|
||||||
|
we want to verify here is the ``IntegrityError`` -> recovery branch,
|
||||||
|
not the RBAC plumbing — so stub them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _allow(*_args, **_kwargs):
|
||||||
|
return True
|
||||||
|
|
||||||
|
monkeypatch.setattr(new_chat_routes, "check_permission", _allow)
|
||||||
|
monkeypatch.setattr(new_chat_routes, "check_thread_access", _allow)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRequest:
|
||||||
|
"""Minimal Request stand-in used by ``append_message``.
|
||||||
|
|
||||||
|
The route only calls ``await request.json()`` — keep the surface
|
||||||
|
area tight so this doesn't accidentally hide future signature
|
||||||
|
changes that we *would* want to break the test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, body: dict):
|
||||||
|
self._body = body
|
||||||
|
|
||||||
|
async def json(self) -> dict:
|
||||||
|
return self._body
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tool_heavy_content() -> list[dict]:
|
||||||
|
"""Drive ``AssistantContentBuilder`` through a tool-heavy turn.
|
||||||
|
|
||||||
|
Produces the same ``ContentPart[]`` shape the streaming layer would
|
||||||
|
persist if ``_stream_agent_events`` ran a turn with: opening
|
||||||
|
reasoning -> text -> tool call (input start/delta/available/output)
|
||||||
|
-> closing text. Centralised here so the (a) and (b) scenarios use
|
||||||
|
the same authoritative payload.
|
||||||
|
"""
|
||||||
|
builder = AssistantContentBuilder()
|
||||||
|
|
||||||
|
builder.on_reasoning_start("r1")
|
||||||
|
builder.on_reasoning_delta("r1", "Let me look up ")
|
||||||
|
builder.on_reasoning_delta("r1", "the file listing.")
|
||||||
|
builder.on_reasoning_end("r1")
|
||||||
|
|
||||||
|
builder.on_text_start("t1")
|
||||||
|
builder.on_text_delta("t1", "Sure, listing files in ")
|
||||||
|
builder.on_text_delta("t1", "/.")
|
||||||
|
builder.on_text_end("t1")
|
||||||
|
|
||||||
|
builder.on_tool_input_start(
|
||||||
|
"tool_call_ui_1",
|
||||||
|
tool_name="ls",
|
||||||
|
langchain_tool_call_id="lc_call_xyz",
|
||||||
|
)
|
||||||
|
builder.on_tool_input_delta("tool_call_ui_1", '{"path"')
|
||||||
|
builder.on_tool_input_delta("tool_call_ui_1", ': "/"}')
|
||||||
|
builder.on_tool_input_available(
|
||||||
|
"tool_call_ui_1",
|
||||||
|
tool_name="ls",
|
||||||
|
args={"path": "/"},
|
||||||
|
langchain_tool_call_id="lc_call_xyz",
|
||||||
|
)
|
||||||
|
builder.on_tool_output_available(
|
||||||
|
"tool_call_ui_1",
|
||||||
|
output={"files": ["a.txt", "b.txt"]},
|
||||||
|
langchain_tool_call_id="lc_call_xyz",
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.on_text_start("t2")
|
||||||
|
builder.on_text_delta("t2", "Found 2 files: a.txt and b.txt.")
|
||||||
|
builder.on_text_end("t2")
|
||||||
|
|
||||||
|
return builder.snapshot()
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulator_with_one_call() -> TurnTokenAccumulator:
|
||||||
|
acc = TurnTokenAccumulator()
|
||||||
|
acc.add(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt_tokens=200,
|
||||||
|
completion_tokens=80,
|
||||||
|
total_tokens=280,
|
||||||
|
cost_micros=22222,
|
||||||
|
)
|
||||||
|
return acc
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (a) Tool-heavy stream finalize
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolHeavyTurnFinalize:
|
||||||
|
async def test_full_tool_call_persisted_and_one_token_usage_row(
|
||||||
|
self,
|
||||||
|
db_session,
|
||||||
|
db_user,
|
||||||
|
db_thread,
|
||||||
|
db_search_space,
|
||||||
|
patched_shielded_session,
|
||||||
|
):
|
||||||
|
"""End-to-end seam: builder snapshot -> finalize -> DB row.
|
||||||
|
|
||||||
|
Matches the production flow's *content* invariant: whatever
|
||||||
|
``AssistantContentBuilder.snapshot()`` produces is what the
|
||||||
|
streaming layer hands to ``finalize_assistant_turn``, so this
|
||||||
|
test catches any drift between the JSONB shape the builder
|
||||||
|
emits and the one the FE history loader expects.
|
||||||
|
"""
|
||||||
|
thread_id = db_thread.id
|
||||||
|
user_id_str = str(db_user.id)
|
||||||
|
search_space_id = db_search_space.id
|
||||||
|
turn_id = f"{thread_id}:tool_heavy"
|
||||||
|
|
||||||
|
msg_id = await persist_assistant_shell(
|
||||||
|
chat_id=thread_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
)
|
||||||
|
assert msg_id is not None
|
||||||
|
|
||||||
|
snapshot = _build_tool_heavy_content()
|
||||||
|
# Sanity-check the snapshot before we hand it to the DB so a
|
||||||
|
# builder regression surfaces here, not deep inside an opaque
|
||||||
|
# JSONB diff.
|
||||||
|
assert any(p.get("type") == "reasoning" for p in snapshot)
|
||||||
|
text_parts = [p for p in snapshot if p.get("type") == "text"]
|
||||||
|
assert len(text_parts) == 2
|
||||||
|
tool_parts = [p for p in snapshot if p.get("type") == "tool-call"]
|
||||||
|
assert len(tool_parts) == 1
|
||||||
|
tool_part = tool_parts[0]
|
||||||
|
assert tool_part["toolCallId"] == "tool_call_ui_1"
|
||||||
|
assert tool_part["toolName"] == "ls"
|
||||||
|
assert tool_part["args"] == {"path": "/"}
|
||||||
|
# ``argsText`` ends up as the pretty-printed final args (the
|
||||||
|
# ``tool-input-available`` event replaces the streamed deltas
|
||||||
|
# with ``json.dumps(args, indent=2)`` to match the FE's
|
||||||
|
# post-stream rendering).
|
||||||
|
assert tool_part["argsText"] == '{\n "path": "/"\n}'
|
||||||
|
assert tool_part["result"] == {"files": ["a.txt", "b.txt"]}
|
||||||
|
# ``langchainToolCallId`` is the agent-side correlation id used
|
||||||
|
# by the regenerate path; a missing one breaks
|
||||||
|
# edit-from-tool-call later.
|
||||||
|
assert tool_part["langchainToolCallId"] == "lc_call_xyz"
|
||||||
|
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
content=snapshot,
|
||||||
|
accumulator=_accumulator_with_one_call(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ``content`` must round-trip byte-for-byte through the JSONB
|
||||||
|
# column. SQLAlchemy doesn't auto-refresh the row that survived
|
||||||
|
# the savepoint commit, so refresh explicitly.
|
||||||
|
row = await db_session.get(NewChatMessage, msg_id)
|
||||||
|
await db_session.refresh(row)
|
||||||
|
|
||||||
|
# The history loader reads ``content`` straight into the FE's
|
||||||
|
# parts array, so a strict equality comparison is the right
|
||||||
|
# invariant here.
|
||||||
|
assert row.content == snapshot
|
||||||
|
# Tool-call parts must JSON-serialise cleanly — nothing in
|
||||||
|
# ``args`` / ``argsText`` / ``result`` should accidentally be a
|
||||||
|
# non-JSON-safe value (datetime, set, custom class).
|
||||||
|
assert json.dumps(row.content)
|
||||||
|
|
||||||
|
usage_count = (
|
||||||
|
await db_session.execute(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(TokenUsage)
|
||||||
|
.where(TokenUsage.message_id == msg_id)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert usage_count == 1
|
||||||
|
|
||||||
|
usage = (
|
||||||
|
await db_session.execute(
|
||||||
|
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert usage.usage_type == "chat"
|
||||||
|
assert usage.prompt_tokens == 200
|
||||||
|
assert usage.completion_tokens == 80
|
||||||
|
assert usage.total_tokens == 280
|
||||||
|
assert usage.cost_micros == 22222
|
||||||
|
assert usage.thread_id == thread_id
|
||||||
|
assert usage.search_space_id == search_space_id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (b) FE appendMessage after server finalize
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppendMessageRecoveryAfterFinalize:
|
||||||
|
async def test_returns_server_content_and_does_not_duplicate_token_usage(
|
||||||
|
self,
|
||||||
|
db_session,
|
||||||
|
db_user,
|
||||||
|
db_thread,
|
||||||
|
db_search_space,
|
||||||
|
patched_shielded_session,
|
||||||
|
bypass_permission_checks,
|
||||||
|
):
|
||||||
|
"""FE's stale ``appendMessage`` after server finalize.
|
||||||
|
|
||||||
|
The frontend used to be the authoritative writer for assistant
|
||||||
|
``content``. Now the server is. When the legacy FE round-trip
|
||||||
|
fires *after* the server has already finalized:
|
||||||
|
|
||||||
|
* the route's INSERT trips the (thread_id, turn_id, role)
|
||||||
|
partial unique index from migration 141,
|
||||||
|
* the recovery branch fetches the existing row and returns
|
||||||
|
*its* ``content`` — discarding the FE payload — so the
|
||||||
|
history loader reads the rich server payload (full tool
|
||||||
|
args, argsText, langchainToolCallId, etc.) on next page
|
||||||
|
reload,
|
||||||
|
* the route's optional ``token_usage`` insert is keyed on the
|
||||||
|
partial unique index from migration 142 so it silently
|
||||||
|
no-ops if the server already wrote one.
|
||||||
|
"""
|
||||||
|
thread_id = db_thread.id
|
||||||
|
user_id_str = str(db_user.id)
|
||||||
|
search_space_id = db_search_space.id
|
||||||
|
turn_id = f"{thread_id}:fe_late_append"
|
||||||
|
|
||||||
|
# Step 1: server stream completes. Server-built rich content is
|
||||||
|
# finalized.
|
||||||
|
msg_id = await persist_assistant_shell(
|
||||||
|
chat_id=thread_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
)
|
||||||
|
assert msg_id is not None
|
||||||
|
|
||||||
|
server_content = _build_tool_heavy_content()
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
content=server_content,
|
||||||
|
accumulator=_accumulator_with_one_call(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: simulate the legacy FE ``appendMessage`` round-trip
|
||||||
|
# arriving with stale, lossy content (missing tool args, etc.)
|
||||||
|
# plus a ``token_usage`` body.
|
||||||
|
fe_stale_content = [
|
||||||
|
{"type": "text", "text": "Found 2 files: a.txt and b.txt."},
|
||||||
|
]
|
||||||
|
fe_request_body = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": fe_stale_content,
|
||||||
|
"turn_id": turn_id,
|
||||||
|
"token_usage": {
|
||||||
|
"prompt_tokens": 999,
|
||||||
|
"completion_tokens": 999,
|
||||||
|
"total_tokens": 1998,
|
||||||
|
"cost_micros": 88888,
|
||||||
|
"usage": {"any": "thing"},
|
||||||
|
"call_details": {"calls": []},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
request = _FakeRequest(fe_request_body)
|
||||||
|
|
||||||
|
# ``db_user`` is bound to ``db_session``. The route's
|
||||||
|
# IntegrityError branch calls ``session.rollback()``, which
|
||||||
|
# expires every ORM row attached to the session including this
|
||||||
|
# user — historically causing ``user.id`` to lazy-load
|
||||||
|
# out-of-greenlet and crash the request with ``MissingGreenlet``
|
||||||
|
# (observed in production logs at /api/v1/threads/531/messages
|
||||||
|
# 2026-05-04). The route now captures ``user.id`` to a primitive
|
||||||
|
# UUID at the top of the handler, so the rollback can't reach
|
||||||
|
# it. Pass the *attached* user here on purpose — that's the
|
||||||
|
# production scenario, and this test is the regression guard
|
||||||
|
# against that bug returning.
|
||||||
|
response = await new_chat_routes.append_message(
|
||||||
|
thread_id=thread_id,
|
||||||
|
request=request,
|
||||||
|
session=db_session,
|
||||||
|
user=db_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Response must echo the SERVER's rich payload, not the FE's
|
||||||
|
# stale snapshot. This is the user-visible part of the
|
||||||
|
# contract: history reload + ThreadHistoryAdapter.append both
|
||||||
|
# read from the same authoritative source.
|
||||||
|
assert response.id == msg_id
|
||||||
|
assert response.role == NewChatMessageRole.ASSISTANT
|
||||||
|
assert response.turn_id == turn_id
|
||||||
|
assert response.content == server_content
|
||||||
|
assert response.content != fe_stale_content
|
||||||
|
|
||||||
|
# The on-disk row must agree with the response.
|
||||||
|
row = await db_session.get(NewChatMessage, msg_id)
|
||||||
|
await db_session.refresh(row)
|
||||||
|
assert row.content == server_content
|
||||||
|
|
||||||
|
# ``token_usage``: exactly one row, with the *server's* values
|
||||||
|
# (the FE's much larger token counts must not have overwritten
|
||||||
|
# them).
|
||||||
|
usage_count = (
|
||||||
|
await db_session.execute(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(TokenUsage)
|
||||||
|
.where(TokenUsage.message_id == msg_id)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert usage_count == 1
|
||||||
|
|
||||||
|
usage = (
|
||||||
|
await db_session.execute(
|
||||||
|
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert usage.cost_micros == 22222 # Server's value, not 88888
|
||||||
|
assert usage.total_tokens == 280 # Server's value, not 1998
|
||||||
|
|
||||||
|
async def test_legacy_fe_first_appendmessage_then_server_no_dupe(
|
||||||
|
self,
|
||||||
|
db_session,
|
||||||
|
db_user,
|
||||||
|
db_thread,
|
||||||
|
db_search_space,
|
||||||
|
patched_shielded_session,
|
||||||
|
bypass_permission_checks,
|
||||||
|
):
|
||||||
|
"""Inverse race: legacy FE writes first, server finalize second.
|
||||||
|
|
||||||
|
Some clients still post ``appendMessage`` before the streaming
|
||||||
|
``finally`` runs. The contract is symmetric: whichever writer
|
||||||
|
loses the (thread_id, turn_id, role) race silently lets the
|
||||||
|
winner keep its content. In particular the *server's*
|
||||||
|
finalize must NOT raise — it must look up the existing row and
|
||||||
|
UPDATE its content with the server-built payload (which is
|
||||||
|
always richer/more authoritative than whatever the FE
|
||||||
|
snapshot held).
|
||||||
|
"""
|
||||||
|
thread_id = db_thread.id
|
||||||
|
user_id_str = str(db_user.id)
|
||||||
|
search_space_id = db_search_space.id
|
||||||
|
turn_id = f"{thread_id}:fe_first"
|
||||||
|
|
||||||
|
# Step 1: legacy FE appendMessage lands first. No prior shell
|
||||||
|
# row exists; the route does the INSERT itself.
|
||||||
|
fe_request_body = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": "early FE write"}],
|
||||||
|
"turn_id": turn_id,
|
||||||
|
}
|
||||||
|
fe_response = await new_chat_routes.append_message(
|
||||||
|
thread_id=thread_id,
|
||||||
|
request=_FakeRequest(fe_request_body),
|
||||||
|
session=db_session,
|
||||||
|
user=db_user,
|
||||||
|
)
|
||||||
|
assert fe_response.role == NewChatMessageRole.ASSISTANT
|
||||||
|
|
||||||
|
# Step 2: server stream's persist_assistant_shell now races
|
||||||
|
# behind. It must adopt the existing row id, not raise.
|
||||||
|
adopted_id = await persist_assistant_shell(
|
||||||
|
chat_id=thread_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
)
|
||||||
|
assert adopted_id == fe_response.id
|
||||||
|
|
||||||
|
# Step 3: server finalize then overwrites the FE's stub with
|
||||||
|
# the rich content (which is the correct, more authoritative
|
||||||
|
# payload).
|
||||||
|
server_content = _build_tool_heavy_content()
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=adopted_id,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
content=server_content,
|
||||||
|
accumulator=_accumulator_with_one_call(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Final state: one row, server content, one token_usage row.
|
||||||
|
msg_count = (
|
||||||
|
await db_session.execute(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(NewChatMessage)
|
||||||
|
.where(
|
||||||
|
NewChatMessage.thread_id == thread_id,
|
||||||
|
NewChatMessage.turn_id == turn_id,
|
||||||
|
NewChatMessage.role == NewChatMessageRole.ASSISTANT,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert msg_count == 1
|
||||||
|
|
||||||
|
row = await db_session.get(NewChatMessage, adopted_id)
|
||||||
|
await db_session.refresh(row)
|
||||||
|
assert row.content == server_content
|
||||||
|
|
||||||
|
usage_count = (
|
||||||
|
await db_session.execute(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(TokenUsage)
|
||||||
|
.where(TokenUsage.message_id == adopted_id)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert usage_count == 1
|
||||||
|
|
||||||
|
async def test_appendmessage_without_turn_id_legacy_400(
    self,
    db_session,
    db_user,
    db_thread,
    patched_shielded_session,
    bypass_permission_checks,
):
    """A bare appendMessage carrying no turn_id is a plain INSERT and
    must succeed outright.

    Also pins the defensive legacy behaviour: with no turn_id to
    collide on, the route must not detour through the unique-index
    recovery branch, and an unrelated foreign-key / IntegrityError
    should still fall through to the legacy 400 path — so a future
    schema change cannot silently regress it.
    """
    tid = db_thread.id

    # No turn_id in the payload — the happy path is a normal INSERT
    # that never touches the recovery branch.
    created = await new_chat_routes.append_message(
        thread_id=tid,
        request=_FakeRequest(
            {
                "role": "user",
                "content": [{"type": "text", "text": "hi"}],
            }
        ),
        session=db_session,
        user=db_user,
    )
    assert created.role == NewChatMessageRole.USER
    assert created.turn_id is None

    # Exactly one USER row proves the route did NOT paper over the
    # missing turn_id via the unique-index recovery branch.
    count_stmt = (
        select(func.count())
        .select_from(NewChatMessage)
        .where(
            NewChatMessage.thread_id == tid,
            NewChatMessage.role == NewChatMessageRole.USER,
        )
    )
    total = (await db_session.execute(count_stmt)).scalar_one()
    assert total == 1
|
||||||
332
surfsense_backend/tests/integration/chat/test_message_id_sse.py
Normal file
332
surfsense_backend/tests/integration/chat/test_message_id_sse.py
Normal file
|
|
@ -0,0 +1,332 @@
|
||||||
|
"""Integration tests for the SSE-based message ID handshake.
|
||||||
|
|
||||||
|
The streaming generators (``stream_new_chat`` / ``stream_resume_chat``)
|
||||||
|
emit two new events after their respective persistence helpers resolve
|
||||||
|
the canonical ``new_chat_messages.id``:
|
||||||
|
|
||||||
|
* ``data-user-message-id`` — emitted only by ``stream_new_chat``,
|
||||||
|
AFTER ``persist_user_turn`` and BEFORE any LLM streaming.
|
||||||
|
* ``data-assistant-message-id`` — emitted by both
|
||||||
|
``stream_new_chat`` and ``stream_resume_chat``, AFTER
|
||||||
|
``persist_assistant_shell`` and BEFORE any LLM streaming.
|
||||||
|
|
||||||
|
The frontend renames its optimistic ``msg-user-XXX`` /
|
||||||
|
``msg-assistant-XXX`` placeholder ids to ``msg-{db_id}`` upon receiving
|
||||||
|
these events. This test suite anchors three contracts:
|
||||||
|
|
||||||
|
1. ``format_data`` produces SSE bytes in the precise shape
|
||||||
|
``data: {"type":"data-<suffix>","data":{...}}\\n\\n`` that the FE's
|
||||||
|
``readSSEStream`` consumer parses (matches ``surfsense_web/lib/chat/streaming-state.ts``).
|
||||||
|
2. The ``message_id`` carried in the SSE payload exactly equals the
|
||||||
|
primary key the persistence helper inserted into
|
||||||
|
``new_chat_messages`` — so the FE rename produces ``msg-{real_pk}``,
|
||||||
|
which in turn unlocks DB-id-gated UI (comments, edit-from-message).
|
||||||
|
3. The same ``message_id`` is used for the ``token_usage.message_id``
|
||||||
|
foreign key, so ``finalize_assistant_turn``'s row binds correctly.
|
||||||
|
|
||||||
|
Direct end-to-end testing of ``stream_new_chat`` requires a fully
|
||||||
|
mocked agent + LLM stack (out-of-scope here); those flows are covered
|
||||||
|
by the harness-driven integration tests under
|
||||||
|
``tests/integration/agents/new_chat/`` plus the assertion in
|
||||||
|
``test_persistence.py`` that the helpers themselves return ``int``
|
||||||
|
ids. The contracts above close the remaining gap between the persist
|
||||||
|
helpers and the bytes that ship to the FE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import (
|
||||||
|
ChatVisibility,
|
||||||
|
NewChatMessage,
|
||||||
|
NewChatMessageRole,
|
||||||
|
NewChatThread,
|
||||||
|
SearchSpace,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
from app.tasks.chat import persistence as persistence_module
|
||||||
|
from app.tasks.chat.persistence import (
|
||||||
|
persist_assistant_shell,
|
||||||
|
persist_user_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures (mirror test_persistence.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def db_thread(
    db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
) -> NewChatThread:
    """Provide a flushed, private chat thread owned by the test user."""
    created = NewChatThread(
        visibility=ChatVisibility.PRIVATE,
        title="Test Chat",
        created_by_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    db_session.add(created)
    await db_session.flush()
    return created
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
    """Route persistence helpers to the test's savepoint-bound session."""

    @asynccontextmanager
    async def _reuse_test_session():
        # Hand the helpers the very session the test owns so their
        # commits land as savepoints inside the outer transaction.
        yield db_session

    monkeypatch.setattr(
        persistence_module, "shielded_async_session", _reuse_test_session
    )
    return db_session
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (1) SSE byte-shape contract
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_sse_data_line(blob: str) -> dict:
|
||||||
|
"""Unwrap a single ``data: <json>\\n\\n`` SSE frame.
|
||||||
|
|
||||||
|
Raises if there's more than one frame or the prefix is wrong — keeps
|
||||||
|
the parser strict so a regression in ``format_data`` produces a
|
||||||
|
test failure here, not in a downstream consumer.
|
||||||
|
"""
|
||||||
|
assert blob.endswith("\n\n"), f"missing terminator: {blob!r}"
|
||||||
|
line = blob.removesuffix("\n\n")
|
||||||
|
assert line.startswith("data: "), f"missing data prefix: {line!r}"
|
||||||
|
return json.loads(line.removeprefix("data: "))
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEByteShape:
    """Pin the exact SSE wire format the FE's ``readSSEStream`` parses."""

    def test_data_user_message_id_byte_shape(self):
        """``format_data("user-message-id", {...})`` must emit the
        precise frame the FE's ``readSSEStream`` ->
        ``data-user-message-id`` case consumes.
        """
        payload = {"message_id": 1843, "turn_id": "533:1762900000000"}
        frame = VercelStreamingService().format_data("user-message-id", payload)
        assert _parse_sse_data_line(frame) == {
            "type": "data-user-message-id",
            "data": {"message_id": 1843, "turn_id": "533:1762900000000"},
        }

    def test_data_assistant_message_id_byte_shape(self):
        payload = {"message_id": 1844, "turn_id": "533:1762900000000"}
        frame = VercelStreamingService().format_data(
            "assistant-message-id", payload
        )
        assert _parse_sse_data_line(frame) == {
            "type": "data-assistant-message-id",
            "data": {"message_id": 1844, "turn_id": "533:1762900000000"},
        }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (2) Helper-id <-> DB-pk coherence
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandshakeIdMatchesDB:
    """The SSE handshake is only correct if the integer inside
    ``data-{user,assistant}-message-id`` equals the EXACT primary key
    the persistence helper inserted. Any divergence makes the FE
    rename produce ``msg-{wrong_id}``: the comment regex stops
    matching and edit / regenerate target the wrong row. These tests
    anchor that invariant.
    """

    async def test_user_message_id_matches_new_chat_messages_pk(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:9000"

        # This is the very value the streaming generator later feeds
        # into ``streaming_service.format_data("user-message-id", {...})``.
        helper_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
            user_query="hello",
        )
        assert isinstance(helper_id, int)

        # Resolve the inserted row through the same
        # ``(thread_id, turn_id, role)`` composite the FE uses to
        # identify a turn, then compare primary keys.
        stmt = select(NewChatMessage).where(
            NewChatMessage.thread_id == tid,
            NewChatMessage.turn_id == turn_id,
            NewChatMessage.role == NewChatMessageRole.USER,
        )
        row = (await db_session.execute(stmt)).scalar_one()
        assert row.id == helper_id

        # Round-trip through the actual byte stream the FE receives,
        # confirming helper return value -> SSE payload coherence.
        frame = VercelStreamingService().format_data(
            "user-message-id",
            {"message_id": helper_id, "turn_id": turn_id},
        )
        assert _parse_sse_data_line(frame)["data"]["message_id"] == row.id

    async def test_assistant_message_id_matches_new_chat_messages_pk(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:9100"

        helper_id = await persist_assistant_shell(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
        )
        assert isinstance(helper_id, int)

        stmt = select(NewChatMessage).where(
            NewChatMessage.thread_id == tid,
            NewChatMessage.turn_id == turn_id,
            NewChatMessage.role == NewChatMessageRole.ASSISTANT,
        )
        row = (await db_session.execute(stmt)).scalar_one()
        assert row.id == helper_id

        frame = VercelStreamingService().format_data(
            "assistant-message-id",
            {"message_id": helper_id, "turn_id": turn_id},
        )
        assert _parse_sse_data_line(frame)["data"]["message_id"] == row.id

    async def test_handshake_ids_for_full_turn_are_distinct_and_paired(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """A full new-chat turn's two SSE events carry two DIFFERENT
        ids (user row PK ≠ assistant row PK), both anchored to the
        SAME ``turn_id`` in the DB — the pairing that
        ``finalize_assistant_turn`` and the regenerate / edit flows
        rely on.
        """
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:9200"

        user_msg_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
            user_query="hi",
        )
        assistant_msg_id = await persist_assistant_shell(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
        )
        assert user_msg_id is not None and assistant_msg_id is not None
        assert user_msg_id != assistant_msg_id

        stmt = (
            select(NewChatMessage)
            .where(
                NewChatMessage.thread_id == tid,
                NewChatMessage.turn_id == turn_id,
            )
            .order_by(NewChatMessage.id)
        )
        rows = (await db_session.execute(stmt)).scalars().all()
        assert len(rows) == 2
        by_role = {r.role: r.id for r in rows}
        assert by_role[NewChatMessageRole.USER] == user_msg_id
        assert by_role[NewChatMessageRole.ASSISTANT] == assistant_msg_id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (3) Parse helpers used by the FE — sanity-check our payload shape
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPayloadShapeMatchesFEReader:
    """BE side of the contract enforced by the FE's
    ``readStreamedMessageId`` (in
    ``surfsense_web/lib/chat/stream-side-effects.ts``):

    * ``message_id`` must arrive as a JSON ``number`` (null / string /
      NaN are rejected).
    * ``turn_id`` is an optional non-empty string (anything else is
      coerced to ``null``).

    The FE consumes ``format_data`` output verbatim, so these tests
    inspect exactly that output.
    """

    def test_message_id_is_serialised_as_a_json_number(self):
        frame = VercelStreamingService().format_data(
            "user-message-id", {"message_id": 42, "turn_id": "t"}
        )
        data = _parse_sse_data_line(frame)["data"]
        assert isinstance(data["message_id"], int)
        assert data["message_id"] == 42

    def test_turn_id_round_trips_as_string(self):
        # Production format: f"{chat_id}:{int(time.time()*1000)}"
        production_turn_id = "533:1762900000000"
        frame = VercelStreamingService().format_data(
            "assistant-message-id",
            {"message_id": 1, "turn_id": production_turn_id},
        )
        assert _parse_sse_data_line(frame)["data"]["turn_id"] == production_turn_id
|
||||||
747
surfsense_backend/tests/integration/chat/test_persistence.py
Normal file
747
surfsense_backend/tests/integration/chat/test_persistence.py
Normal file
|
|
@ -0,0 +1,747 @@
|
||||||
|
"""Integration tests for ``app.tasks.chat.persistence``.
|
||||||
|
|
||||||
|
Verifies the DB-side guarantees the streaming chat task relies on:
|
||||||
|
|
||||||
|
* ``persist_assistant_shell`` is idempotent against the
|
||||||
|
``(thread_id, turn_id, ASSISTANT)`` partial unique index from
|
||||||
|
migration 141. Two calls with the same ``turn_id`` return the SAME
|
||||||
|
``message_id`` and never create a duplicate ``new_chat_messages`` row.
|
||||||
|
* ``finalize_assistant_turn`` writes a status-marker payload when given
|
||||||
|
empty content, never raises, and is safe to call twice on the same
|
||||||
|
``message_id`` — the partial unique index from migration 142
|
||||||
|
(``uq_token_usage_message_id``) prevents the second insert from
|
||||||
|
producing a duplicate ``token_usage`` row.
|
||||||
|
* The same ``ON CONFLICT DO NOTHING`` invariant covers the cross-writer
|
||||||
|
race where ``finalize_assistant_turn`` and the ``append_message``
|
||||||
|
recovery branch both target the same ``message_id``.
|
||||||
|
|
||||||
|
All tests run inside the conftest's outer-transaction-with-savepoint
|
||||||
|
fixture so commits inside the helpers (which open their own
|
||||||
|
``shielded_async_session``) are released as savepoints and rolled back
|
||||||
|
at test end. We monkey-patch ``shielded_async_session`` to yield the
|
||||||
|
same pooled test session so the integration transaction stays
|
||||||
|
in-scope.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import (
|
||||||
|
ChatVisibility,
|
||||||
|
NewChatMessage,
|
||||||
|
NewChatMessageRole,
|
||||||
|
NewChatThread,
|
||||||
|
SearchSpace,
|
||||||
|
TokenUsage,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
from app.services.token_tracking_service import TurnTokenAccumulator
|
||||||
|
from app.tasks.chat import persistence as persistence_module
|
||||||
|
from app.tasks.chat.persistence import (
|
||||||
|
finalize_assistant_turn,
|
||||||
|
persist_assistant_shell,
|
||||||
|
persist_user_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def db_thread(
    db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
) -> NewChatThread:
    """Provide a flushed, private chat thread owned by the test user."""
    created = NewChatThread(
        visibility=ChatVisibility.PRIVATE,
        title="Test Chat",
        created_by_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    db_session.add(created)
    await db_session.flush()
    return created
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def patched_shielded_session(monkeypatch, db_session: AsyncSession):
    """Route persistence helpers to the test's savepoint-bound session.

    The helpers run ``async with shielded_async_session() as ws`` and
    call ``ws.commit()`` themselves. Under the conftest's
    ``join_transaction_mode="create_savepoint"`` setup those commits
    merely release a SAVEPOINT instead of committing the outer
    transaction — so the test session sees helper-staged rows
    immediately, and the outer rollback at end of test wipes them.
    """

    @asynccontextmanager
    async def _reuse_test_session():
        yield db_session
        # Intentionally no close(): the outer fixture owns the
        # session lifecycle.

    monkeypatch.setattr(
        persistence_module, "shielded_async_session", _reuse_test_session
    )
    return db_session
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulator_with_one_call() -> TurnTokenAccumulator:
    """Build an accumulator holding exactly one recorded LLM call."""
    accumulator = TurnTokenAccumulator()
    accumulator.add(
        model="gpt-4o-mini",
        prompt_tokens=100,
        completion_tokens=50,
        total_tokens=150,
        cost_micros=12345,
    )
    return accumulator
|
||||||
|
|
||||||
|
|
||||||
|
async def _count_assistant_rows(
    session: AsyncSession, thread_id: int, turn_id: str
) -> int:
    """Count ASSISTANT rows for one ``(thread_id, turn_id)`` pair."""
    stmt = (
        select(func.count())
        .select_from(NewChatMessage)
        .where(
            NewChatMessage.thread_id == thread_id,
            NewChatMessage.turn_id == turn_id,
            NewChatMessage.role == NewChatMessageRole.ASSISTANT,
        )
    )
    outcome = await session.execute(stmt)
    return int(outcome.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _count_token_usage_rows(session: AsyncSession, message_id: int) -> int:
    """Count ``token_usage`` rows bound to ``message_id``."""
    stmt = (
        select(func.count())
        .select_from(TokenUsage)
        .where(TokenUsage.message_id == message_id)
    )
    outcome = await session.execute(stmt)
    return int(outcome.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# persist_assistant_shell
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPersistAssistantShell:
    """Idempotence of the assistant shell against the
    ``(thread_id, turn_id, ASSISTANT)`` partial unique index.
    """

    async def test_first_call_inserts_empty_shell_and_returns_id(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        # Capture primitive ids up front: the helpers commit/rollback
        # the shared test session, which can detach ORM rows mid-test.
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:1000"

        msg_id = await persist_assistant_shell(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
        )
        assert msg_id is not None and isinstance(msg_id, int)

        row = await db_session.get(NewChatMessage, msg_id)
        assert row is not None
        assert row.thread_id == tid
        assert row.role == NewChatMessageRole.ASSISTANT
        assert row.turn_id == turn_id
        # The shell payload is deliberately empty —
        # finalize_assistant_turn overwrites it later.
        assert row.content == [{"type": "text", "text": ""}]

    async def test_second_call_with_same_turn_id_returns_same_id(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        # Primitive ids first — helper commits can detach ORM rows.
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:2000"

        first_id = await persist_assistant_shell(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
        )
        second_id = await persist_assistant_shell(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
        )

        assert first_id is not None
        assert first_id == second_id
        # Idempotence at the DB level: exactly one row for this turn.
        assert await _count_assistant_rows(db_session, tid, turn_id) == 1

    async def test_missing_turn_id_returns_none(
        self,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        tid = db_thread.id
        uid = str(db_user.id)

        msg_id = await persist_assistant_shell(
            chat_id=tid,
            user_id=uid,
            turn_id="",
        )
        assert msg_id is None

    async def test_after_persist_user_turn_resolves_assistant_id(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:3000"

        # The streaming layer always persists the user turn first —
        # smoke-test that canonical sequence.
        user_msg_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
            user_query="hello",
        )
        assert isinstance(user_msg_id, int)

        shell_id = await persist_assistant_shell(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
        )
        assert shell_id is not None
        # One user row plus one assistant shell row for this turn.
        stmt = (
            select(func.count())
            .select_from(NewChatMessage)
            .where(
                NewChatMessage.thread_id == tid,
                NewChatMessage.turn_id == turn_id,
            )
        )
        assert (await db_session.execute(stmt)).scalar_one() == 2

    async def test_double_call_with_same_turn_id_uses_on_conflict(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """The ON CONFLICT DO NOTHING path on the assistant shell must
        not raise ``IntegrityError`` even when the second writer races
        the first within a tight loop.
        ``test_second_call_with_same_turn_id_returns_same_id`` already
        covers the same-id semantics; this test additionally asserts
        neither call raises, so a debugger configured to break on
        ``IntegrityError`` never pauses the streaming path under
        contention.
        """
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:3500"

        # Both calls route through
        # ``pg_insert(...).on_conflict_do_nothing``; the second gets an
        # empty RETURNING and falls into the SELECT branch. Neither
        # path raises.
        first_id = await persist_assistant_shell(
            chat_id=tid, user_id=uid, turn_id=turn_id
        )
        second_id = await persist_assistant_shell(
            chat_id=tid, user_id=uid, turn_id=turn_id
        )
        assert first_id is not None
        assert first_id == second_id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# persist_user_turn
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPersistUserTurn:
    """Insert/conflict semantics of the user-turn persistence helper."""

    async def test_returns_message_id_on_first_insert(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:8000"

        msg_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
            user_query="hello",
        )
        assert isinstance(msg_id, int) and msg_id > 0

        row = await db_session.get(NewChatMessage, msg_id)
        assert row is not None
        assert row.thread_id == tid
        assert row.role == NewChatMessageRole.USER
        assert row.turn_id == turn_id
        assert row.content == [{"type": "text", "text": "hello"}]

    async def test_returns_existing_id_on_conflict(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:8100"

        first_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
            user_query="hello",
        )
        # The second call simulates a legacy FE ``appendMessage``
        # racing the SSE stream: ON CONFLICT DO NOTHING short-circuits
        # at the DB level, the helper recovers the existing id via
        # SELECT, and — crucially — never raises ``IntegrityError``
        # (which the debugger would otherwise pause on).
        second_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
            user_query="ignored on conflict",
        )
        assert first_id is not None
        assert first_id == second_id

        # Exactly one user row survives for this turn.
        stmt = (
            select(func.count())
            .select_from(NewChatMessage)
            .where(
                NewChatMessage.thread_id == tid,
                NewChatMessage.turn_id == turn_id,
                NewChatMessage.role == NewChatMessageRole.USER,
            )
        )
        assert (await db_session.execute(stmt)).scalar_one() == 1

    async def test_embeds_mentioned_documents_part(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """The full ``{id, title, document_type}`` triple forwarded by
        the FE must round-trip into a single ``mentioned-documents``
        ContentPart on the persisted user message — the history loader
        renders the chips on reload straight from this part.
        """
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id = f"{tid}:8200"

        mentioned = [
            {"id": 11, "title": "Alpha", "document_type": "GENERAL"},
            {"id": 22, "title": "Beta", "document_type": "GENERAL"},
        ]
        msg_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id,
            user_query="hello",
            mentioned_documents=mentioned,
        )
        assert isinstance(msg_id, int)

        row = await db_session.get(NewChatMessage, msg_id)
        assert row is not None
        # Two parts: the text, then the mentioned-documents chip data.
        assert isinstance(row.content, list)
        assert row.content[0] == {"type": "text", "text": "hello"}
        assert row.content[1] == {
            "type": "mentioned-documents",
            "documents": [
                {"id": 11, "title": "Alpha", "document_type": "GENERAL"},
                {"id": 22, "title": "Beta", "document_type": "GENERAL"},
            ],
        }

    async def test_skips_mentioned_documents_when_empty_or_invalid(
        self,
        db_session,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        """An empty list, and entries missing required fields, are
        dropped; a ``mentioned-documents`` part is only emitted when
        at least one normalised entry survived.
        """
        tid = db_thread.id
        uid = str(db_user.id)
        turn_id_empty = f"{tid}:8300"
        turn_id_invalid = f"{tid}:8301"

        empty_msg_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id_empty,
            user_query="hi",
            mentioned_documents=[],
        )
        assert isinstance(empty_msg_id, int)
        empty_row = await db_session.get(NewChatMessage, empty_msg_id)
        assert empty_row is not None
        assert empty_row.content == [{"type": "text", "text": "hi"}]

        # Each entry is missing one required field — all skipped.
        invalid_msg_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id=turn_id_invalid,
            user_query="hi",
            mentioned_documents=[
                {"title": "no id", "document_type": "GENERAL"},  # missing id
                {"id": 99, "document_type": "GENERAL"},  # missing title
                {"id": 100, "title": "no type"},  # missing document_type
            ],
        )
        assert isinstance(invalid_msg_id, int)
        invalid_row = await db_session.get(NewChatMessage, invalid_msg_id)
        assert invalid_row is not None
        assert invalid_row.content == [{"type": "text", "text": "hi"}]

    async def test_missing_turn_id_returns_none(
        self,
        db_user,
        db_thread,
        patched_shielded_session,
    ):
        tid = db_thread.id
        uid = str(db_user.id)

        msg_id = await persist_user_turn(
            chat_id=tid,
            user_id=uid,
            turn_id="",
            user_query="hello",
        )
        assert msg_id is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# finalize_assistant_turn
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFinalizeAssistantTurn:
|
||||||
|
async def test_writes_content_and_token_usage(
|
||||||
|
self,
|
||||||
|
db_session,
|
||||||
|
db_user,
|
||||||
|
db_thread,
|
||||||
|
db_search_space,
|
||||||
|
patched_shielded_session,
|
||||||
|
):
|
||||||
|
thread_id = db_thread.id
|
||||||
|
user_id_uuid = db_user.id
|
||||||
|
user_id_str = str(user_id_uuid)
|
||||||
|
search_space_id = db_search_space.id
|
||||||
|
turn_id = f"{thread_id}:4000"
|
||||||
|
|
||||||
|
msg_id = await persist_assistant_shell(
|
||||||
|
chat_id=thread_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
)
|
||||||
|
assert msg_id is not None
|
||||||
|
|
||||||
|
rich_content = [
|
||||||
|
{"type": "text", "text": "Hello world"},
|
||||||
|
{
|
||||||
|
"type": "tool-call",
|
||||||
|
"toolCallId": "call_x",
|
||||||
|
"toolName": "ls",
|
||||||
|
"args": {"path": "/"},
|
||||||
|
"argsText": '{\n "path": "/"\n}',
|
||||||
|
"result": {"files": []},
|
||||||
|
"langchainToolCallId": "lc_x",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
content=rich_content,
|
||||||
|
accumulator=_accumulator_with_one_call(),
|
||||||
|
)
|
||||||
|
|
||||||
|
row = await db_session.get(NewChatMessage, msg_id)
|
||||||
|
await db_session.refresh(row)
|
||||||
|
assert row.content == rich_content
|
||||||
|
|
||||||
|
# Exactly one token_usage row keyed on this message_id.
|
||||||
|
usage_rows = (
|
||||||
|
(
|
||||||
|
await db_session.execute(
|
||||||
|
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
assert len(usage_rows) == 1
|
||||||
|
usage = usage_rows[0]
|
||||||
|
assert usage.usage_type == "chat"
|
||||||
|
assert usage.prompt_tokens == 100
|
||||||
|
assert usage.completion_tokens == 50
|
||||||
|
assert usage.total_tokens == 150
|
||||||
|
assert usage.cost_micros == 12345
|
||||||
|
assert usage.thread_id == thread_id
|
||||||
|
assert usage.search_space_id == search_space_id
|
||||||
|
|
||||||
|
async def test_empty_content_writes_status_marker(
|
||||||
|
self,
|
||||||
|
db_session,
|
||||||
|
db_user,
|
||||||
|
db_thread,
|
||||||
|
db_search_space,
|
||||||
|
patched_shielded_session,
|
||||||
|
):
|
||||||
|
thread_id = db_thread.id
|
||||||
|
user_id_str = str(db_user.id)
|
||||||
|
search_space_id = db_search_space.id
|
||||||
|
turn_id = f"{thread_id}:5000"
|
||||||
|
|
||||||
|
msg_id = await persist_assistant_shell(
|
||||||
|
chat_id=thread_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
)
|
||||||
|
assert msg_id is not None
|
||||||
|
|
||||||
|
# Pure tool-call turn that aborted before any output, or
|
||||||
|
# interrupt before any event arrived — empty list.
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
content=[],
|
||||||
|
accumulator=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
row = await db_session.get(NewChatMessage, msg_id)
|
||||||
|
await db_session.refresh(row)
|
||||||
|
assert row.content == [{"type": "status", "text": "(no text response)"}]
|
||||||
|
|
||||||
|
async def test_double_call_safe_via_on_conflict(
|
||||||
|
self,
|
||||||
|
db_session,
|
||||||
|
db_user,
|
||||||
|
db_thread,
|
||||||
|
db_search_space,
|
||||||
|
patched_shielded_session,
|
||||||
|
):
|
||||||
|
thread_id = db_thread.id
|
||||||
|
user_id_str = str(db_user.id)
|
||||||
|
search_space_id = db_search_space.id
|
||||||
|
turn_id = f"{thread_id}:6000"
|
||||||
|
|
||||||
|
msg_id = await persist_assistant_shell(
|
||||||
|
chat_id=thread_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
)
|
||||||
|
assert msg_id is not None
|
||||||
|
|
||||||
|
first_acc = _accumulator_with_one_call()
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
content=[{"type": "text", "text": "first finalize"}],
|
||||||
|
accumulator=first_acc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simulate a follow-up finalize (e.g., resume retry within the
|
||||||
|
# shielded finally block firing twice). Different content, but
|
||||||
|
# ON CONFLICT DO NOTHING on token_usage means the cost from the
|
||||||
|
# first finalize stays authoritative.
|
||||||
|
second_acc = TurnTokenAccumulator()
|
||||||
|
second_acc.add(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt_tokens=999,
|
||||||
|
completion_tokens=999,
|
||||||
|
total_tokens=1998,
|
||||||
|
cost_micros=99999,
|
||||||
|
)
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
content=[{"type": "text", "text": "second finalize"}],
|
||||||
|
accumulator=second_acc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Content was overwritten by the second UPDATE.
|
||||||
|
row = await db_session.get(NewChatMessage, msg_id)
|
||||||
|
await db_session.refresh(row)
|
||||||
|
assert row.content == [{"type": "text", "text": "second finalize"}]
|
||||||
|
|
||||||
|
# But token_usage stayed at exactly one row, preserving the
|
||||||
|
# first finalize's authoritative cost.
|
||||||
|
assert await _count_token_usage_rows(db_session, msg_id) == 1
|
||||||
|
usage = (
|
||||||
|
await db_session.execute(
|
||||||
|
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert usage.cost_micros == 12345 # First finalize's value
|
||||||
|
|
||||||
|
async def test_append_message_style_insert_after_finalize_no_dupe(
|
||||||
|
self,
|
||||||
|
db_session,
|
||||||
|
db_user,
|
||||||
|
db_thread,
|
||||||
|
db_search_space,
|
||||||
|
patched_shielded_session,
|
||||||
|
):
|
||||||
|
"""Cross-writer race: ``append_message`` arrives after ``finalize_assistant_turn``.
|
||||||
|
|
||||||
|
Both target the same ``message_id``; the partial unique index
|
||||||
|
``uq_token_usage_message_id`` (migration 142) makes the second
|
||||||
|
insert a no-op via ``ON CONFLICT DO NOTHING``.
|
||||||
|
"""
|
||||||
|
from sqlalchemy import text as sa_text
|
||||||
|
|
||||||
|
thread_id = db_thread.id
|
||||||
|
user_uuid = db_user.id
|
||||||
|
user_id_str = str(user_uuid)
|
||||||
|
search_space_id = db_search_space.id
|
||||||
|
turn_id = f"{thread_id}:7000"
|
||||||
|
|
||||||
|
msg_id = await persist_assistant_shell(
|
||||||
|
chat_id=thread_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
)
|
||||||
|
assert msg_id is not None
|
||||||
|
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id=turn_id,
|
||||||
|
content=[{"type": "text", "text": "from server"}],
|
||||||
|
accumulator=_accumulator_with_one_call(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now simulate the FE's append_message branch firing AFTER —
|
||||||
|
# the same INSERT ... ON CONFLICT DO NOTHING shape used by the
|
||||||
|
# route handler, keyed on the migration-142 partial unique
|
||||||
|
# index.
|
||||||
|
late_insert = (
|
||||||
|
pg_insert(TokenUsage)
|
||||||
|
.values(
|
||||||
|
usage_type="chat",
|
||||||
|
prompt_tokens=42,
|
||||||
|
completion_tokens=42,
|
||||||
|
total_tokens=84,
|
||||||
|
cost_micros=1,
|
||||||
|
model_breakdown=None,
|
||||||
|
call_details=None,
|
||||||
|
thread_id=thread_id,
|
||||||
|
message_id=msg_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_uuid,
|
||||||
|
)
|
||||||
|
.on_conflict_do_nothing(
|
||||||
|
index_elements=["message_id"],
|
||||||
|
index_where=sa_text("message_id IS NOT NULL"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db_session.execute(late_insert)
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
# Still exactly one row, with the original (server) cost value.
|
||||||
|
assert await _count_token_usage_rows(db_session, msg_id) == 1
|
||||||
|
usage = (
|
||||||
|
await db_session.execute(
|
||||||
|
select(TokenUsage).where(TokenUsage.message_id == msg_id)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
assert usage.cost_micros == 12345
|
||||||
|
|
||||||
|
async def test_helper_never_raises_on_missing_message_id(
|
||||||
|
self,
|
||||||
|
db_session,
|
||||||
|
db_user,
|
||||||
|
db_thread,
|
||||||
|
db_search_space,
|
||||||
|
patched_shielded_session,
|
||||||
|
):
|
||||||
|
thread_id = db_thread.id
|
||||||
|
user_id_str = str(db_user.id)
|
||||||
|
search_space_id = db_search_space.id
|
||||||
|
|
||||||
|
# message_id that doesn't exist — finalize must log+return,
|
||||||
|
# never raise (called from shielded finally).
|
||||||
|
await finalize_assistant_turn(
|
||||||
|
message_id=999_999_999,
|
||||||
|
chat_id=thread_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id_str,
|
||||||
|
turn_id="anything",
|
||||||
|
content=[{"type": "text", "text": "x"}],
|
||||||
|
accumulator=_accumulator_with_one_call(),
|
||||||
|
)
|
||||||
|
# If we got here without an exception, the test passes.
|
||||||
|
# Sanity: no token_usage row created (FK to message would have
|
||||||
|
# been rejected anyway, but ON CONFLICT path may swallow
|
||||||
|
# FK errors as well; check directly).
|
||||||
|
assert await _count_token_usage_rows(db_session, 999_999_999) == 0
|
||||||
526
surfsense_backend/tests/unit/tasks/chat/test_content_builder.py
Normal file
526
surfsense_backend/tests/unit/tasks/chat/test_content_builder.py
Normal file
|
|
@ -0,0 +1,526 @@
|
||||||
|
"""Unit tests for ``AssistantContentBuilder``.
|
||||||
|
|
||||||
|
Pins the in-memory ``ContentPart[]`` projection so the JSONB the server
|
||||||
|
persists matches what the frontend renders live (see
|
||||||
|
``surfsense_web/lib/chat/streaming-state.ts``). Every test asserts both
|
||||||
|
the structural shape of ``snapshot()`` and that the snapshot is
|
||||||
|
``json.dumps``-safe (the streaming finally block writes it directly to
|
||||||
|
``new_chat_messages.content`` without an explicit serialization round
|
||||||
|
trip).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_jsonb_safe(parts: list[dict]) -> None:
|
||||||
|
"""Sanity check: any snapshot must round-trip through ``json.dumps``."""
|
||||||
|
serialized = json.dumps(parts)
|
||||||
|
assert json.loads(serialized) == parts
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Text turns
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextOnly:
|
||||||
|
def test_single_text_block_collapses_consecutive_deltas(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_delta("text-1", "Hello")
|
||||||
|
b.on_text_delta("text-1", " ")
|
||||||
|
b.on_text_delta("text-1", "world")
|
||||||
|
b.on_text_end("text-1")
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap == [{"type": "text", "text": "Hello world"}]
|
||||||
|
assert not b.is_empty()
|
||||||
|
_assert_jsonb_safe(snap)
|
||||||
|
|
||||||
|
def test_empty_text_start_end_pair_leaves_no_part(self):
|
||||||
|
# Mirrors the FE: a text-start without any deltas should
|
||||||
|
# not materialise an empty ``{"type":"text","text":""}`` part.
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_end("text-1")
|
||||||
|
|
||||||
|
assert b.snapshot() == []
|
||||||
|
assert b.is_empty()
|
||||||
|
|
||||||
|
def test_text_after_text_end_starts_fresh_part(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_delta("text-1", "first")
|
||||||
|
b.on_text_end("text-1")
|
||||||
|
|
||||||
|
b.on_text_start("text-2")
|
||||||
|
b.on_text_delta("text-2", "second")
|
||||||
|
b.on_text_end("text-2")
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap == [
|
||||||
|
{"type": "text", "text": "first"},
|
||||||
|
{"type": "text", "text": "second"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestReasoningThenText:
|
||||||
|
def test_reasoning_followed_by_text_yields_two_parts_in_order(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_reasoning_start("r-1")
|
||||||
|
b.on_reasoning_delta("r-1", "Considering options...")
|
||||||
|
b.on_reasoning_end("r-1")
|
||||||
|
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_delta("text-1", "The answer is 42.")
|
||||||
|
b.on_text_end("text-1")
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap == [
|
||||||
|
{"type": "reasoning", "text": "Considering options..."},
|
||||||
|
{"type": "text", "text": "The answer is 42."},
|
||||||
|
]
|
||||||
|
_assert_jsonb_safe(snap)
|
||||||
|
|
||||||
|
def test_text_delta_after_reasoning_implicitly_closes_reasoning(self):
|
||||||
|
# Mirrors FE ``appendText``: a text delta arriving while a
|
||||||
|
# reasoning part is "active" still produces a fresh text
|
||||||
|
# part, never appends into the reasoning block.
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_reasoning_start("r-1")
|
||||||
|
b.on_reasoning_delta("r-1", "thinking")
|
||||||
|
# No explicit reasoning_end — text delta should close it.
|
||||||
|
b.on_text_delta("text-1", "answer")
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap == [
|
||||||
|
{"type": "reasoning", "text": "thinking"},
|
||||||
|
{"type": "text", "text": "answer"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool calls
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolHeavyTurn:
|
||||||
|
def test_full_tool_lifecycle_produces_complete_tool_call_part(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
# Some narration before the tool fires.
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_delta("text-1", "Searching...")
|
||||||
|
b.on_text_end("text-1")
|
||||||
|
|
||||||
|
b.on_tool_input_start(
|
||||||
|
ui_id="call_run123",
|
||||||
|
tool_name="web_search",
|
||||||
|
langchain_tool_call_id="lc_tool_abc",
|
||||||
|
)
|
||||||
|
b.on_tool_input_delta("call_run123", '{"query":')
|
||||||
|
b.on_tool_input_delta("call_run123", '"surfsense"}')
|
||||||
|
b.on_tool_input_available(
|
||||||
|
ui_id="call_run123",
|
||||||
|
tool_name="web_search",
|
||||||
|
args={"query": "surfsense"},
|
||||||
|
langchain_tool_call_id="lc_tool_abc",
|
||||||
|
)
|
||||||
|
b.on_tool_output_available(
|
||||||
|
ui_id="call_run123",
|
||||||
|
output={"status": "completed", "citations": {}},
|
||||||
|
langchain_tool_call_id="lc_tool_abc",
|
||||||
|
)
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap[0] == {"type": "text", "text": "Searching..."}
|
||||||
|
tool_part = snap[1]
|
||||||
|
assert tool_part["type"] == "tool-call"
|
||||||
|
assert tool_part["toolCallId"] == "call_run123"
|
||||||
|
assert tool_part["toolName"] == "web_search"
|
||||||
|
assert tool_part["args"] == {"query": "surfsense"}
|
||||||
|
# ``argsText`` is the pretty-printed final JSON, not the raw
|
||||||
|
# streaming buffer (FE ``stream-pipeline.ts:128``).
|
||||||
|
assert tool_part["argsText"] == json.dumps(
|
||||||
|
{"query": "surfsense"}, indent=2, ensure_ascii=False
|
||||||
|
)
|
||||||
|
assert tool_part["langchainToolCallId"] == "lc_tool_abc"
|
||||||
|
assert tool_part["result"] == {"status": "completed", "citations": {}}
|
||||||
|
_assert_jsonb_safe(snap)
|
||||||
|
|
||||||
|
def test_tool_input_available_without_prior_start_creates_card(self):
|
||||||
|
# Legacy / parity_v2-OFF path: tool-input-available may be
|
||||||
|
# emitted without a prior tool-input-start (no streamed
|
||||||
|
# tool_call_chunks). The card should still be created.
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_available(
|
||||||
|
ui_id="call_run42",
|
||||||
|
tool_name="grep",
|
||||||
|
args={"pattern": "TODO"},
|
||||||
|
langchain_tool_call_id="lc_x",
|
||||||
|
)
|
||||||
|
b.on_tool_output_available(
|
||||||
|
ui_id="call_run42",
|
||||||
|
output={"matches": 3},
|
||||||
|
langchain_tool_call_id="lc_x",
|
||||||
|
)
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert len(snap) == 1
|
||||||
|
part = snap[0]
|
||||||
|
assert part["type"] == "tool-call"
|
||||||
|
assert part["toolCallId"] == "call_run42"
|
||||||
|
assert part["args"] == {"pattern": "TODO"}
|
||||||
|
assert part["langchainToolCallId"] == "lc_x"
|
||||||
|
assert part["result"] == {"matches": 3}
|
||||||
|
|
||||||
|
def test_tool_input_start_idempotent_for_same_ui_id(self):
|
||||||
|
# parity_v2: tool-input-start can fire from BOTH the chunk
|
||||||
|
# registration path AND the canonical ``on_tool_start`` path.
|
||||||
|
# The second call must not create a duplicate part.
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_start("call_x", "ls", "lc_x")
|
||||||
|
b.on_tool_input_start("call_x", "ls", "lc_x")
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert len(snap) == 1
|
||||||
|
|
||||||
|
def test_tool_input_delta_without_prior_start_is_silently_dropped(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_delta("call_unknown", '{"orphan": "delta"}')
|
||||||
|
assert b.snapshot() == []
|
||||||
|
|
||||||
|
def test_langchain_tool_call_id_backfills_only_when_absent(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_start("call_x", "ls", "lc_first")
|
||||||
|
# Late event must NOT clobber an already-set lc id.
|
||||||
|
b.on_tool_input_start("call_x", "ls", "lc_late")
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap[0]["langchainToolCallId"] == "lc_first"
|
||||||
|
|
||||||
|
def test_args_text_streaming_buffer_reflects_concatenation(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_start("call_x", "save_doc", "lc_y")
|
||||||
|
b.on_tool_input_delta("call_x", '{"title":')
|
||||||
|
b.on_tool_input_delta("call_x", '"Hi"}')
|
||||||
|
# Snapshot mid-stream should see the partial buffer (the FE
|
||||||
|
# tolerates invalid JSON and renders it as-is).
|
||||||
|
mid = b.snapshot()
|
||||||
|
assert mid[0]["argsText"] == '{"title":"Hi"}'
|
||||||
|
# Then tool-input-available replaces with pretty-printed.
|
||||||
|
b.on_tool_input_available(
|
||||||
|
"call_x",
|
||||||
|
"save_doc",
|
||||||
|
{"title": "Hi"},
|
||||||
|
"lc_y",
|
||||||
|
)
|
||||||
|
final = b.snapshot()
|
||||||
|
assert final[0]["argsText"] == json.dumps(
|
||||||
|
{"title": "Hi"}, indent=2, ensure_ascii=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Thinking steps & separators
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestThinkingSteps:
|
||||||
|
def test_first_thinking_step_unshifts_singleton_to_index_zero(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_delta("text-1", "Hello")
|
||||||
|
b.on_text_end("text-1")
|
||||||
|
|
||||||
|
b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item-a"])
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
# Singleton goes to index 0 (FE ``updateThinkingSteps`` unshift).
|
||||||
|
assert snap[0]["type"] == "data-thinking-steps"
|
||||||
|
assert snap[0]["data"]["steps"] == [
|
||||||
|
{
|
||||||
|
"id": "step-1",
|
||||||
|
"title": "Analyzing",
|
||||||
|
"status": "in_progress",
|
||||||
|
"items": ["item-a"],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert snap[1] == {"type": "text", "text": "Hello"}
|
||||||
|
|
||||||
|
def test_subsequent_thinking_steps_mutate_the_singleton_in_place(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_thinking_step("step-1", "Analyzing", "in_progress", [])
|
||||||
|
b.on_thinking_step("step-2", "Searching", "in_progress", ["q"])
|
||||||
|
b.on_thinking_step("step-1", "Analyzing", "completed", ["done"])
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert len([p for p in snap if p["type"] == "data-thinking-steps"]) == 1
|
||||||
|
steps = snap[0]["data"]["steps"]
|
||||||
|
assert len(steps) == 2
|
||||||
|
assert steps[0]["id"] == "step-1"
|
||||||
|
assert steps[0]["status"] == "completed"
|
||||||
|
assert steps[0]["items"] == ["done"]
|
||||||
|
assert steps[1]["id"] == "step-2"
|
||||||
|
|
||||||
|
def test_thinking_step_with_text_continues_appending_to_text(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_delta("text-1", "first")
|
||||||
|
|
||||||
|
# Thinking step inserts at index 0, bumps text idx from 0 to 1.
|
||||||
|
b.on_thinking_step("step-1", "Working", "in_progress", [])
|
||||||
|
b.on_text_delta("text-1", " second")
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
text_parts = [p for p in snap if p["type"] == "text"]
|
||||||
|
assert text_parts == [{"type": "text", "text": "first second"}]
|
||||||
|
|
||||||
|
def test_thinking_step_without_id_is_dropped(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_thinking_step("", "noop", "in_progress", None)
|
||||||
|
assert b.snapshot() == []
|
||||||
|
assert b.is_empty()
|
||||||
|
|
||||||
|
|
||||||
|
class TestStepSeparators:
|
||||||
|
def test_separator_no_op_before_any_content(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_step_separator()
|
||||||
|
assert b.snapshot() == []
|
||||||
|
|
||||||
|
def test_separator_after_text_appends_with_step_index_zero(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_delta("text-1", "first")
|
||||||
|
b.on_text_end("text-1")
|
||||||
|
|
||||||
|
b.on_step_separator()
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap[-1] == {
|
||||||
|
"type": "data-step-separator",
|
||||||
|
"data": {"stepIndex": 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_consecutive_separators_collapse_to_one(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_delta("text-1", "x")
|
||||||
|
b.on_step_separator()
|
||||||
|
b.on_step_separator() # No-op: previous part is already a separator.
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert sum(1 for p in snap if p["type"] == "data-step-separator") == 1
|
||||||
|
|
||||||
|
def test_step_index_increments_across_separators(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_delta("text-1", "a")
|
||||||
|
b.on_step_separator()
|
||||||
|
b.on_text_delta("text-2", "b")
|
||||||
|
b.on_step_separator()
|
||||||
|
snap = b.snapshot()
|
||||||
|
seps = [p for p in snap if p["type"] == "data-step-separator"]
|
||||||
|
assert [s["data"]["stepIndex"] for s in seps] == [0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Interruption handling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarkInterrupted:
|
||||||
|
def test_running_tool_calls_get_state_aborted(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_start("call_a", "ls", "lc_a")
|
||||||
|
b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
|
||||||
|
# No tool-output-available — simulates client disconnect mid-tool.
|
||||||
|
|
||||||
|
b.mark_interrupted()
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap[0]["state"] == "aborted"
|
||||||
|
assert "result" not in snap[0]
|
||||||
|
|
||||||
|
def test_completed_tool_calls_are_not_marked_aborted(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_start("call_a", "ls", "lc_a")
|
||||||
|
b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
|
||||||
|
b.on_tool_output_available("call_a", {"files": []}, "lc_a")
|
||||||
|
|
||||||
|
b.mark_interrupted()
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert "state" not in snap[0]
|
||||||
|
assert snap[0]["result"] == {"files": []}
|
||||||
|
|
||||||
|
def test_open_text_block_keeps_accumulated_content(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("text-1")
|
||||||
|
b.on_text_delta("text-1", "partial")
|
||||||
|
# No on_text_end — disconnect mid-stream.
|
||||||
|
|
||||||
|
b.mark_interrupted()
|
||||||
|
|
||||||
|
snap = b.snapshot()
|
||||||
|
assert snap == [{"type": "text", "text": "partial"}]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# is_empty / snapshot semantics
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsEmpty:
|
||||||
|
def test_fresh_builder_is_empty(self):
|
||||||
|
assert AssistantContentBuilder().is_empty()
|
||||||
|
|
||||||
|
def test_text_part_breaks_emptiness(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_delta("text-1", "x")
|
||||||
|
assert not b.is_empty()
|
||||||
|
|
||||||
|
def test_tool_call_breaks_emptiness(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_start("call_x", "ls", None)
|
||||||
|
assert not b.is_empty()
|
||||||
|
|
||||||
|
def test_thinking_step_alone_does_not_break_emptiness(self):
|
||||||
|
# Mirrors the "status marker fallback" semantic: a turn that
|
||||||
|
# only emitted a thinking step before being interrupted should
|
||||||
|
# still be treated as empty for finalize_assistant_turn's
|
||||||
|
# status-marker substitution.
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_thinking_step("step-1", "Working", "in_progress", [])
|
||||||
|
assert b.is_empty()
|
||||||
|
|
||||||
|
def test_step_separator_alone_does_not_break_emptiness(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
# Force a separator (it would normally no-op without content,
|
||||||
|
# but we simulate the underlying state to verify is_empty is
|
||||||
|
# not fooled by a stray separator).
|
||||||
|
b.parts.append({"type": "data-step-separator", "data": {"stepIndex": 0}})
|
||||||
|
assert b.is_empty()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSnapshotSemantics:
|
||||||
|
def test_snapshot_is_deep_copied_so_mutations_do_not_leak(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_start("call_x", "ls", "lc_x")
|
||||||
|
b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
|
||||||
|
snap = b.snapshot()
|
||||||
|
# Mutate the returned snapshot — original should be untouched.
|
||||||
|
snap[0]["args"]["mutated"] = True
|
||||||
|
snap[0]["state"] = "tampered"
|
||||||
|
|
||||||
|
again = b.snapshot()
|
||||||
|
assert "mutated" not in again[0]["args"]
|
||||||
|
assert "state" not in again[0]
|
||||||
|
|
||||||
|
def test_snapshot_round_trips_through_json(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item"])
|
||||||
|
b.on_text_delta("text-1", "answer")
|
||||||
|
b.on_tool_input_start("call_x", "ls", "lc_x")
|
||||||
|
b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
|
||||||
|
b.on_tool_output_available("call_x", {"files": ["a.txt"]}, "lc_x")
|
||||||
|
b.on_step_separator()
|
||||||
|
snap = b.snapshot()
|
||||||
|
|
||||||
|
encoded = json.dumps(snap)
|
||||||
|
assert json.loads(encoded) == snap
|
||||||
|
|
||||||
|
|
||||||
|
class TestStats:
|
||||||
|
"""``stats()`` is the perf-log handle for [PERF] [stream_*]
|
||||||
|
finalize_payload lines. Pin the schema so an ops dashboard can
|
||||||
|
rely on these keys being present and meaningful.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_fresh_builder_reports_all_zeros(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
s = b.stats()
|
||||||
|
assert s == {
|
||||||
|
"parts": 0,
|
||||||
|
"bytes": 2, # ``[]`` is two bytes
|
||||||
|
"text": 0,
|
||||||
|
"reasoning": 0,
|
||||||
|
"tool_calls": 0,
|
||||||
|
"tool_calls_completed": 0,
|
||||||
|
"tool_calls_aborted": 0,
|
||||||
|
"thinking_step_parts": 0,
|
||||||
|
"step_separators": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_counts_each_part_type_independently(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("t1")
|
||||||
|
b.on_text_delta("t1", "hi")
|
||||||
|
b.on_text_end("t1")
|
||||||
|
b.on_reasoning_start("r1")
|
||||||
|
b.on_reasoning_delta("r1", "thinking")
|
||||||
|
b.on_reasoning_end("r1")
|
||||||
|
b.on_thinking_step("step-1", "Analyzing", "completed", ["item"])
|
||||||
|
b.on_step_separator()
|
||||||
|
b.on_tool_input_start("call_done", "ls", "lc_done")
|
||||||
|
b.on_tool_input_available("call_done", "ls", {}, "lc_done")
|
||||||
|
b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
|
||||||
|
b.on_tool_input_start("call_running", "rm", "lc_running")
|
||||||
|
b.on_tool_input_available("call_running", "rm", {}, "lc_running")
|
||||||
|
|
||||||
|
s = b.stats()
|
||||||
|
assert s["text"] == 1
|
||||||
|
assert s["reasoning"] == 1
|
||||||
|
assert s["tool_calls"] == 2
|
||||||
|
assert s["tool_calls_completed"] == 1
|
||||||
|
assert s["tool_calls_aborted"] == 0
|
||||||
|
assert s["thinking_step_parts"] == 1
|
||||||
|
assert s["step_separators"] == 1
|
||||||
|
assert s["parts"] == sum(
|
||||||
|
[
|
||||||
|
s["text"],
|
||||||
|
s["reasoning"],
|
||||||
|
s["tool_calls"],
|
||||||
|
s["thinking_step_parts"],
|
||||||
|
s["step_separators"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert s["bytes"] > 0
|
||||||
|
|
||||||
|
def test_mark_interrupted_flips_running_calls_to_aborted_in_stats(self):
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_tool_input_start("call_done", "ls", "lc_done")
|
||||||
|
b.on_tool_input_available("call_done", "ls", {}, "lc_done")
|
||||||
|
b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
|
||||||
|
b.on_tool_input_start("call_running", "rm", "lc_running")
|
||||||
|
b.on_tool_input_available("call_running", "rm", {}, "lc_running")
|
||||||
|
|
||||||
|
# Pre-interrupt: one completed, one still running (no result).
|
||||||
|
pre = b.stats()
|
||||||
|
assert pre["tool_calls_completed"] == 1
|
||||||
|
assert pre["tool_calls_aborted"] == 0
|
||||||
|
|
||||||
|
b.mark_interrupted()
|
||||||
|
post = b.stats()
|
||||||
|
assert post["tool_calls_completed"] == 1
|
||||||
|
assert post["tool_calls_aborted"] == 1
|
||||||
|
assert post["tool_calls"] == 2
|
||||||
|
|
||||||
|
def test_bytes_reflects_jsonb_payload_size(self):
|
||||||
|
# Each text-delta adds bytes monotonically — useful for catching
|
||||||
|
# an unbounded delta buffer regression in the perf signal.
|
||||||
|
b = AssistantContentBuilder()
|
||||||
|
b.on_text_start("t1")
|
||||||
|
b.on_text_delta("t1", "x" * 10)
|
||||||
|
small = b.stats()["bytes"]
|
||||||
|
b.on_text_delta("t1", "x" * 1000)
|
||||||
|
large = b.stats()["bytes"]
|
||||||
|
assert large > small + 900
|
||||||
|
|
@ -457,6 +457,9 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
|
||||||
source = page_path.read_text(encoding="utf-8")
|
source = page_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
# Each flow tracks accepted boundary and passes it into shared terminal handling.
|
# Each flow tracks accepted boundary and passes it into shared terminal handling.
|
||||||
|
# The acceptance boundary is still meaningful post-refactor: it gates
|
||||||
|
# local-state cleanup (onPreAcceptFailure path) and lets the shared
|
||||||
|
# terminal handler distinguish pre-accept aborts from in-stream errors.
|
||||||
assert "let newAccepted = false;" in source
|
assert "let newAccepted = false;" in source
|
||||||
assert "let resumeAccepted = false;" in source
|
assert "let resumeAccepted = false;" in source
|
||||||
assert "let regenerateAccepted = false;" in source
|
assert "let regenerateAccepted = false;" in source
|
||||||
|
|
@ -464,12 +467,23 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
|
||||||
assert "accepted: resumeAccepted," in source
|
assert "accepted: resumeAccepted," in source
|
||||||
assert "accepted: regenerateAccepted," in source
|
assert "accepted: regenerateAccepted," in source
|
||||||
|
|
||||||
# Pre-accept abort in resume/regenerate exits without persistence.
|
# NOTE: The FE-side persistence guards previously asserted here
|
||||||
assert "if (!resumeAccepted) return;" in source
|
# ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
|
||||||
assert "if (!regenerateAccepted) return;" in source
|
# "if (newAccepted && !userPersisted) {") have been intentionally
|
||||||
|
# removed by the SSE-based message-id handshake refactor. Persistence
|
||||||
|
# is now server-authoritative: persist_user_turn / persist_assistant_shell
|
||||||
|
# run inside stream_new_chat / stream_resume_chat unconditionally and
|
||||||
|
# the FE consumes data-user-message-id / data-assistant-message-id
|
||||||
|
# SSE events to learn the canonical primary keys. There is therefore
|
||||||
|
# no FE call-site to guard, and the shared terminal handler relies
|
||||||
|
# purely on the `accepted` field above (forwarded to onAbort /
|
||||||
|
# onAcceptedStreamError) to drive UI cleanup. See
|
||||||
|
# tests/integration/chat/test_message_id_sse.py for the new
|
||||||
|
# cross-tier ID coherence guarantees.
|
||||||
|
|
||||||
# New flow persists only when accepted and not already persisted.
|
# The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
|
||||||
assert "if (newAccepted && !userPersisted) {" in source
|
# of the persistence refactor and must still exist on every
|
||||||
|
# start-stream fetch.
|
||||||
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
|
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
|
||||||
assert "computeFallbackTurnCancellingRetryDelay" in source
|
assert "computeFallbackTurnCancellingRetryDelay" in source
|
||||||
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
|
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
|
||||||
|
|
|
||||||
|
|
@ -82,6 +82,7 @@ import {
|
||||||
mergeChatTurnIdIntoMessage,
|
mergeChatTurnIdIntoMessage,
|
||||||
mergeEditedInterruptAction,
|
mergeEditedInterruptAction,
|
||||||
readStreamedChatTurnId,
|
readStreamedChatTurnId,
|
||||||
|
readStreamedMessageId,
|
||||||
} from "@/lib/chat/stream-side-effects";
|
} from "@/lib/chat/stream-side-effects";
|
||||||
import {
|
import {
|
||||||
buildContentForPersistence,
|
buildContentForPersistence,
|
||||||
|
|
@ -256,110 +257,17 @@ export default function NewChatPage() {
|
||||||
[tokenUsageStore]
|
[tokenUsageStore]
|
||||||
);
|
);
|
||||||
|
|
||||||
const persistUserTurn = useCallback(
|
// NOTE: ``persistUserTurn`` / ``persistAssistantTurn`` callbacks
|
||||||
async ({
|
// were removed in the SSE-based message ID handshake refactor.
|
||||||
threadId,
|
// ``stream_new_chat`` and ``stream_resume_chat`` now persist both
|
||||||
userMsgId,
|
// the user and assistant rows server-side via
|
||||||
content,
|
// ``persist_user_turn`` / ``persist_assistant_shell`` and emit
|
||||||
mentionedDocs,
|
// ``data-user-message-id`` / ``data-assistant-message-id`` SSE
|
||||||
turnId,
|
// events; the consumers below rename the optimistic ids in real
|
||||||
logContext,
|
// time. ``persistAssistantErrorMessage`` (above) is intentionally
|
||||||
}: {
|
// kept — it is the pre-stream-error fallback fired when the
|
||||||
threadId: number | null;
|
// server NEVER accepted the request, and the BE has nothing to
|
||||||
userMsgId: string;
|
// persist in that case.
|
||||||
content: unknown;
|
|
||||||
mentionedDocs?: MentionedDocumentInfo[];
|
|
||||||
turnId?: string | null;
|
|
||||||
logContext: string;
|
|
||||||
}) => {
|
|
||||||
if (!threadId) return null;
|
|
||||||
try {
|
|
||||||
const normalizedContent = Array.isArray(content) ? ([...content] as unknown[]) : [content];
|
|
||||||
const hasMentionedDocumentsPart = normalizedContent.some(
|
|
||||||
(part) => MentionedDocumentsPartSchema.safeParse(part).success
|
|
||||||
);
|
|
||||||
if (mentionedDocs && mentionedDocs.length > 0 && !hasMentionedDocumentsPart) {
|
|
||||||
normalizedContent.push({
|
|
||||||
type: "mentioned-documents",
|
|
||||||
documents: mentionedDocs,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
const savedUserMessage = await appendMessage(threadId, {
|
|
||||||
role: "user",
|
|
||||||
content: normalizedContent as AppendMessage["content"],
|
|
||||||
turn_id: turnId,
|
|
||||||
});
|
|
||||||
const newUserMsgId = `msg-${savedUserMessage.id}`;
|
|
||||||
setMessages((prev) =>
|
|
||||||
prev.map((m) =>
|
|
||||||
m.id === userMsgId
|
|
||||||
? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id)
|
|
||||||
: m
|
|
||||||
)
|
|
||||||
);
|
|
||||||
if (mentionedDocs && mentionedDocs.length > 0) {
|
|
||||||
setMessageDocumentsMap((prev) => {
|
|
||||||
const { [userMsgId]: _, ...rest } = prev;
|
|
||||||
return {
|
|
||||||
...rest,
|
|
||||||
[newUserMsgId]: mentionedDocs,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return newUserMsgId;
|
|
||||||
} catch (err) {
|
|
||||||
console.error(`Failed to persist ${logContext} user message:`, err);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[setMessageDocumentsMap]
|
|
||||||
);
|
|
||||||
|
|
||||||
const persistAssistantTurn = useCallback(
|
|
||||||
async ({
|
|
||||||
threadId,
|
|
||||||
assistantMsgId,
|
|
||||||
content,
|
|
||||||
tokenUsage,
|
|
||||||
turnId,
|
|
||||||
logContext,
|
|
||||||
onRemapped,
|
|
||||||
}: {
|
|
||||||
threadId: number | null;
|
|
||||||
assistantMsgId: string;
|
|
||||||
content: unknown;
|
|
||||||
tokenUsage?: TokenUsageData;
|
|
||||||
turnId?: string | null;
|
|
||||||
logContext: string;
|
|
||||||
onRemapped?: (newMsgId: string) => void;
|
|
||||||
}) => {
|
|
||||||
if (!threadId) return null;
|
|
||||||
try {
|
|
||||||
const savedMessage = await appendMessage(threadId, {
|
|
||||||
role: "assistant",
|
|
||||||
content: content as AppendMessage["content"],
|
|
||||||
token_usage: tokenUsage,
|
|
||||||
turn_id: turnId,
|
|
||||||
});
|
|
||||||
const newMsgId = `msg-${savedMessage.id}`;
|
|
||||||
tokenUsageStore.rename(assistantMsgId, newMsgId);
|
|
||||||
setMessages((prev) =>
|
|
||||||
prev.map((m) =>
|
|
||||||
m.id === assistantMsgId
|
|
||||||
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
|
|
||||||
: m
|
|
||||||
)
|
|
||||||
);
|
|
||||||
onRemapped?.(newMsgId);
|
|
||||||
return newMsgId;
|
|
||||||
} catch (err) {
|
|
||||||
console.error(`Failed to persist ${logContext} assistant message:`, err);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[tokenUsageStore]
|
|
||||||
);
|
|
||||||
|
|
||||||
// Get disabled tools from the tool toggle UI
|
// Get disabled tools from the tool toggle UI
|
||||||
const disabledTools = useAtomValue(disabledToolsAtom);
|
const disabledTools = useAtomValue(disabledToolsAtom);
|
||||||
|
|
@ -891,8 +799,13 @@ export default function NewChatPage() {
|
||||||
setPendingUserImageUrls((prev) => prev.filter((u) => !urlsSnapshot.includes(u)));
|
setPendingUserImageUrls((prev) => prev.filter((u) => !urlsSnapshot.includes(u)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add user message to state
|
// Add user message to state. Mutable because the SSE
|
||||||
const userMsgId = `msg-user-${Date.now()}`;
|
// ``data-user-message-id`` handler (below) renames this
|
||||||
|
// optimistic id to the canonical ``msg-{db_id}`` once the
|
||||||
|
// backend's ``persist_user_turn`` resolves the row, and
|
||||||
|
// the in-stream flush / interrupt closures need to see
|
||||||
|
// the post-rename value via this live ``let`` binding.
|
||||||
|
let userMsgId = `msg-user-${Date.now()}`;
|
||||||
|
|
||||||
// Always include author metadata so the UI layer can decide visibility
|
// Always include author metadata so the UI layer can decide visibility
|
||||||
const authorMetadata = currentUser
|
const authorMetadata = currentUser
|
||||||
|
|
@ -958,22 +871,16 @@ export default function NewChatPage() {
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
const persistContent: unknown[] = [...userDisplayContent];
|
|
||||||
|
|
||||||
if (allMentionedDocs.length > 0) {
|
|
||||||
persistContent.push({
|
|
||||||
type: "mentioned-documents",
|
|
||||||
documents: allMentionedDocs,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start streaming response
|
// Start streaming response
|
||||||
setIsRunning(true);
|
setIsRunning(true);
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
abortControllerRef.current = controller;
|
abortControllerRef.current = controller;
|
||||||
|
|
||||||
// Prepare assistant message
|
// Prepare assistant message. Mutable for the same reason
|
||||||
const assistantMsgId = `msg-assistant-${Date.now()}`;
|
// as ``userMsgId`` above — the ``data-assistant-message-id``
|
||||||
|
// SSE handler reassigns this once
|
||||||
|
// ``persist_assistant_shell`` returns its canonical id.
|
||||||
|
let assistantMsgId = `msg-assistant-${Date.now()}`;
|
||||||
const currentThinkingSteps = new Map<string, ThinkingStepData>();
|
const currentThinkingSteps = new Map<string, ThinkingStepData>();
|
||||||
const contentPartsState: ContentPartsState = {
|
const contentPartsState: ContentPartsState = {
|
||||||
contentParts: [],
|
contentParts: [],
|
||||||
|
|
@ -983,11 +890,7 @@ export default function NewChatPage() {
|
||||||
};
|
};
|
||||||
const { contentParts } = contentPartsState;
|
const { contentParts } = contentPartsState;
|
||||||
let wasInterrupted = false;
|
let wasInterrupted = false;
|
||||||
let tokenUsageData: TokenUsageData | null = null;
|
|
||||||
let newAccepted = false;
|
let newAccepted = false;
|
||||||
let userPersisted = false;
|
|
||||||
// Captured from ``data-turn-info`` at stream start.
|
|
||||||
let streamedChatTurnId: string | null = null;
|
|
||||||
let streamBatcher: FrameBatchedUpdater | null = null;
|
let streamBatcher: FrameBatchedUpdater | null = null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
@ -1047,6 +950,18 @@ export default function NewChatPage() {
|
||||||
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
|
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
|
||||||
? mentionedDocumentIds.surfsense_doc_ids
|
? mentionedDocumentIds.surfsense_doc_ids
|
||||||
: undefined,
|
: undefined,
|
||||||
|
// Full mention metadata so the BE can embed a
|
||||||
|
// ``mentioned-documents`` ContentPart on the
|
||||||
|
// persisted user message (replaces the old FE-side
|
||||||
|
// injection in ``persistUserTurn``).
|
||||||
|
mentioned_documents:
|
||||||
|
allMentionedDocs.length > 0
|
||||||
|
? allMentionedDocs.map((d) => ({
|
||||||
|
id: d.id,
|
||||||
|
title: d.title,
|
||||||
|
document_type: d.document_type,
|
||||||
|
}))
|
||||||
|
: undefined,
|
||||||
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
|
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
|
||||||
...(userImages.length > 0 ? { user_images: userImages } : {}),
|
...(userImages.length > 0 ? { user_images: userImages } : {}),
|
||||||
}),
|
}),
|
||||||
|
|
@ -1089,7 +1004,6 @@ export default function NewChatPage() {
|
||||||
scheduleFlush,
|
scheduleFlush,
|
||||||
forceFlush,
|
forceFlush,
|
||||||
onTokenUsage: (data) => {
|
onTokenUsage: (data) => {
|
||||||
tokenUsageData = data;
|
|
||||||
tokenUsageStore.set(assistantMsgId, data);
|
tokenUsageStore.set(assistantMsgId, data);
|
||||||
},
|
},
|
||||||
onTurnStatus: (data) => {
|
onTurnStatus: (data) => {
|
||||||
|
|
@ -1189,7 +1103,6 @@ export default function NewChatPage() {
|
||||||
|
|
||||||
case "data-turn-info": {
|
case "data-turn-info": {
|
||||||
const turnId = readStreamedChatTurnId(parsed.data);
|
const turnId = readStreamedChatTurnId(parsed.data);
|
||||||
streamedChatTurnId = turnId;
|
|
||||||
if (turnId) {
|
if (turnId) {
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
|
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
|
||||||
|
|
@ -1197,46 +1110,96 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "data-user-message-id": {
|
||||||
|
// Server-authoritative user message id resolved by
|
||||||
|
// ``persist_user_turn`` (or recovered via ON CONFLICT).
|
||||||
|
// Rename the optimistic ``msg-user-XXX`` placeholder to
|
||||||
|
// the canonical ``msg-{db_id}`` so DB-id-gated UI
|
||||||
|
// (comments, edit-from-this-message) unlocks immediately,
|
||||||
|
// migrate the local mentioned-documents map, and reassign
|
||||||
|
// the closure variable so all downstream
|
||||||
|
// ``m.id === userMsgId`` checks see the new value.
|
||||||
|
const parsedMsg = readStreamedMessageId(parsed.data);
|
||||||
|
if (!parsedMsg) break;
|
||||||
|
const newUserMsgId = `msg-${parsedMsg.messageId}`;
|
||||||
|
const oldUserMsgId = userMsgId;
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === oldUserMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage(
|
||||||
|
{ ...m, id: newUserMsgId },
|
||||||
|
parsedMsg.turnId
|
||||||
|
)
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
if (allMentionedDocs.length > 0) {
|
||||||
|
setMessageDocumentsMap((prev) => {
|
||||||
|
if (!(oldUserMsgId in prev)) {
|
||||||
|
return { ...prev, [newUserMsgId]: allMentionedDocs };
|
||||||
|
}
|
||||||
|
const { [oldUserMsgId]: _removed, ...rest } = prev;
|
||||||
|
return { ...rest, [newUserMsgId]: allMentionedDocs };
|
||||||
|
});
|
||||||
|
}
|
||||||
|
userMsgId = newUserMsgId;
|
||||||
|
if (isNewThread) {
|
||||||
|
// First user-side row landed in ``new_chat_messages``;
|
||||||
|
// refresh the sidebar so the freshly-bumped
|
||||||
|
// ``thread.updated_at`` reorders this thread.
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: ["threads", String(searchSpaceId)],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-assistant-message-id": {
|
||||||
|
// Server-authoritative assistant message id resolved
|
||||||
|
// by ``persist_assistant_shell``. Rename the optimistic
|
||||||
|
// id, migrate ``tokenUsageStore`` so any pending
|
||||||
|
// ``data-token-usage`` payload binds to the new id,
|
||||||
|
// remap any in-flight ``pendingInterrupt`` reference,
|
||||||
|
// and reassign the closure variable so the in-stream
|
||||||
|
// flush callback (line ~1074) keeps writing to the
|
||||||
|
// renamed message.
|
||||||
|
const parsedMsg = readStreamedMessageId(parsed.data);
|
||||||
|
if (!parsedMsg) break;
|
||||||
|
const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
|
||||||
|
const oldAssistantMsgId = assistantMsgId;
|
||||||
|
tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === oldAssistantMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage(
|
||||||
|
{ ...m, id: newAssistantMsgId },
|
||||||
|
parsedMsg.turnId
|
||||||
|
)
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
setPendingInterrupt((prev) =>
|
||||||
|
prev && prev.assistantMsgId === oldAssistantMsgId
|
||||||
|
? { ...prev, assistantMsgId: newAssistantMsgId }
|
||||||
|
: prev
|
||||||
|
);
|
||||||
|
assistantMsgId = newAssistantMsgId;
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
||||||
// Skip persistence for interrupted messages -- handleResume will persist the final version
|
// Server-authoritative persistence: ``stream_new_chat``
|
||||||
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
// already wrote the user row in ``persist_user_turn``
|
||||||
|
// (the FE renamed the optimistic id mid-stream via
|
||||||
|
// ``data-user-message-id``) and finalises the assistant
|
||||||
|
// row in ``finalize_assistant_turn`` from a shielded
|
||||||
|
// ``finally`` block. Nothing left for the FE to persist
|
||||||
|
// here — track the response and unblock the UI.
|
||||||
if (contentParts.length > 0 && !wasInterrupted) {
|
if (contentParts.length > 0 && !wasInterrupted) {
|
||||||
if (!userPersisted) {
|
|
||||||
const persistedUserMsgId = await persistUserTurn({
|
|
||||||
threadId: currentThreadId,
|
|
||||||
userMsgId,
|
|
||||||
content: persistContent,
|
|
||||||
mentionedDocs: allMentionedDocs,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "new chat",
|
|
||||||
});
|
|
||||||
userPersisted = Boolean(persistedUserMsgId);
|
|
||||||
if (userPersisted && isNewThread) {
|
|
||||||
queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
await persistAssistantTurn({
|
|
||||||
threadId: currentThreadId,
|
|
||||||
assistantMsgId,
|
|
||||||
content: finalContent,
|
|
||||||
tokenUsage: tokenUsageData ?? undefined,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "new chat",
|
|
||||||
onRemapped: (newMsgId) => {
|
|
||||||
setPendingInterrupt((prev) =>
|
|
||||||
prev && prev.assistantMsgId === assistantMsgId
|
|
||||||
? { ...prev, assistantMsgId: newMsgId }
|
|
||||||
: prev
|
|
||||||
);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Track successful response
|
|
||||||
trackChatResponseReceived(searchSpaceId, currentThreadId);
|
trackChatResponseReceived(searchSpaceId, currentThreadId);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
@ -1247,51 +1210,21 @@ export default function NewChatPage() {
|
||||||
threadId: currentThreadId,
|
threadId: currentThreadId,
|
||||||
assistantMsgId,
|
assistantMsgId,
|
||||||
accepted: newAccepted,
|
accepted: newAccepted,
|
||||||
onAbort: async () => {
|
// Server-side ``finalize_assistant_turn`` runs from a
|
||||||
if (newAccepted && !userPersisted) {
|
// shielded ``anyio.CancelScope(shield=True)`` finally
|
||||||
const persistedUserMsgId = await persistUserTurn({
|
// block, so partial content (incl. abort-mid-stream)
|
||||||
threadId: currentThreadId,
|
// is already persisted by the BE for the assistant
|
||||||
userMsgId,
|
// row, and ``persist_user_turn`` ran before any LLM
|
||||||
content: persistContent,
|
// call. The FE's only remaining responsibility on
|
||||||
mentionedDocs: allMentionedDocs,
|
// abort / accepted-stream-error is to surface the
|
||||||
turnId: streamedChatTurnId,
|
// error toast (handled by ``handleStreamTerminalError``
|
||||||
logContext: "new chat (aborted)",
|
// itself).
|
||||||
});
|
|
||||||
userPersisted = Boolean(persistedUserMsgId);
|
|
||||||
if (userPersisted && isNewThread) {
|
|
||||||
queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const hasContent = hasPersistableContent(contentParts, toolsWithUI);
|
|
||||||
if (hasContent && currentThreadId) {
|
|
||||||
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
|
||||||
await persistAssistantTurn({
|
|
||||||
threadId: currentThreadId,
|
|
||||||
assistantMsgId,
|
|
||||||
content: partialContent,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "partial new chat",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
},
|
|
||||||
onAcceptedStreamError: async () => {
|
|
||||||
if (!userPersisted) {
|
|
||||||
const persistedUserMsgId = await persistUserTurn({
|
|
||||||
threadId: currentThreadId,
|
|
||||||
userMsgId,
|
|
||||||
content: persistContent,
|
|
||||||
mentionedDocs: allMentionedDocs,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "new chat (stream error)",
|
|
||||||
});
|
|
||||||
userPersisted = Boolean(persistedUserMsgId);
|
|
||||||
if (userPersisted && isNewThread) {
|
|
||||||
queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
onPreAcceptFailure: async () => {
|
onPreAcceptFailure: async () => {
|
||||||
|
// Pre-accept failure means the BE never accepted the
|
||||||
|
// request — no server-side persistence ran. Roll
|
||||||
|
// back the optimistic UI insertions we made before
|
||||||
|
// the fetch so the user message and any local
|
||||||
|
// mentioned-docs metadata don't linger.
|
||||||
setMessages((prev) => prev.filter((m) => m.id !== userMsgId));
|
setMessages((prev) => prev.filter((m) => m.id !== userMsgId));
|
||||||
setMessageDocumentsMap((prev) => {
|
setMessageDocumentsMap((prev) => {
|
||||||
if (!(userMsgId in prev)) return prev;
|
if (!(userMsgId in prev)) return prev;
|
||||||
|
|
@ -1325,8 +1258,6 @@ export default function NewChatPage() {
|
||||||
fetchWithTurnCancellingRetry,
|
fetchWithTurnCancellingRetry,
|
||||||
handleStreamTerminalError,
|
handleStreamTerminalError,
|
||||||
handleChatFailure,
|
handleChatFailure,
|
||||||
persistAssistantTurn,
|
|
||||||
persistUserTurn,
|
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -1339,7 +1270,12 @@ export default function NewChatPage() {
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
if (!pendingInterrupt) return;
|
if (!pendingInterrupt) return;
|
||||||
const { threadId: resumeThreadId, assistantMsgId } = pendingInterrupt;
|
const { threadId: resumeThreadId } = pendingInterrupt;
|
||||||
|
// Destructured separately as ``let`` so the SSE
|
||||||
|
// ``data-assistant-message-id`` handler (resume always
|
||||||
|
// allocates a fresh server-side row) can rename it to
|
||||||
|
// the canonical ``msg-{db_id}`` mid-stream.
|
||||||
|
let assistantMsgId = pendingInterrupt.assistantMsgId;
|
||||||
setPendingInterrupt(null);
|
setPendingInterrupt(null);
|
||||||
setIsRunning(true);
|
setIsRunning(true);
|
||||||
|
|
||||||
|
|
@ -1362,10 +1298,7 @@ export default function NewChatPage() {
|
||||||
toolCallIndices: new Map(),
|
toolCallIndices: new Map(),
|
||||||
};
|
};
|
||||||
const { contentParts, toolCallIndices } = contentPartsState;
|
const { contentParts, toolCallIndices } = contentPartsState;
|
||||||
let tokenUsageData: TokenUsageData | null = null;
|
|
||||||
let resumeAccepted = false;
|
let resumeAccepted = false;
|
||||||
// Captured from ``data-turn-info`` at stream start.
|
|
||||||
let streamedChatTurnId: string | null = null;
|
|
||||||
let streamBatcher: FrameBatchedUpdater | null = null;
|
let streamBatcher: FrameBatchedUpdater | null = null;
|
||||||
|
|
||||||
const existingMsg = messages.find((m) => m.id === assistantMsgId);
|
const existingMsg = messages.find((m) => m.id === assistantMsgId);
|
||||||
|
|
@ -1466,7 +1399,6 @@ export default function NewChatPage() {
|
||||||
scheduleFlush,
|
scheduleFlush,
|
||||||
forceFlush,
|
forceFlush,
|
||||||
onTokenUsage: (data) => {
|
onTokenUsage: (data) => {
|
||||||
tokenUsageData = data;
|
|
||||||
tokenUsageStore.set(assistantMsgId, data);
|
tokenUsageStore.set(assistantMsgId, data);
|
||||||
},
|
},
|
||||||
onTurnStatus: (data) => {
|
onTurnStatus: (data) => {
|
||||||
|
|
@ -1514,7 +1446,6 @@ export default function NewChatPage() {
|
||||||
|
|
||||||
case "data-turn-info": {
|
case "data-turn-info": {
|
||||||
const turnId = readStreamedChatTurnId(parsed.data);
|
const turnId = readStreamedChatTurnId(parsed.data);
|
||||||
streamedChatTurnId = turnId;
|
|
||||||
if (turnId) {
|
if (turnId) {
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
|
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
|
||||||
|
|
@ -1522,22 +1453,44 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "data-assistant-message-id": {
|
||||||
|
// Resume always allocates a fresh ``new_chat_messages``
|
||||||
|
// row anchored to a new ``turn_id`` (the original
|
||||||
|
// interrupted turn's row stays as-is), so this is a
|
||||||
|
// real id swap. Rename the optimistic placeholder to
|
||||||
|
// ``msg-{db_id}`` and reassign closure state. Resume
|
||||||
|
// does NOT emit ``data-user-message-id`` — the user
|
||||||
|
// row belongs to the original interrupted turn.
|
||||||
|
const parsedMsg = readStreamedMessageId(parsed.data);
|
||||||
|
if (!parsedMsg) break;
|
||||||
|
const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
|
||||||
|
const oldAssistantMsgId = assistantMsgId;
|
||||||
|
tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === oldAssistantMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage(
|
||||||
|
{ ...m, id: newAssistantMsgId },
|
||||||
|
parsedMsg.turnId
|
||||||
|
)
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
assistantMsgId = newAssistantMsgId;
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
||||||
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
// Server-authoritative persistence: ``stream_resume_chat``
|
||||||
if (contentParts.length > 0) {
|
// finalises the assistant row in
|
||||||
await persistAssistantTurn({
|
// ``finalize_assistant_turn`` from a shielded
|
||||||
threadId: resumeThreadId,
|
// ``finally`` block (covers both happy-path and
|
||||||
assistantMsgId,
|
// abort-mid-stream). FE has no remaining persistence
|
||||||
content: finalContent,
|
// work here.
|
||||||
tokenUsage: tokenUsageData ?? undefined,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "resumed chat",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
streamBatcher?.dispose();
|
streamBatcher?.dispose();
|
||||||
await handleStreamTerminalError({
|
await handleStreamTerminalError({
|
||||||
|
|
@ -1546,19 +1499,6 @@ export default function NewChatPage() {
|
||||||
threadId: resumeThreadId,
|
threadId: resumeThreadId,
|
||||||
assistantMsgId,
|
assistantMsgId,
|
||||||
accepted: resumeAccepted,
|
accepted: resumeAccepted,
|
||||||
onAbort: async () => {
|
|
||||||
if (!resumeAccepted) return;
|
|
||||||
const hasContent = hasPersistableContent(contentParts, toolsWithUI);
|
|
||||||
if (!hasContent) return;
|
|
||||||
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
|
||||||
await persistAssistantTurn({
|
|
||||||
threadId: resumeThreadId,
|
|
||||||
assistantMsgId,
|
|
||||||
content: partialContent,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "partial resumed chat",
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
} finally {
|
} finally {
|
||||||
setIsRunning(false);
|
setIsRunning(false);
|
||||||
|
|
@ -1574,7 +1514,6 @@ export default function NewChatPage() {
|
||||||
tokenUsageStore,
|
tokenUsageStore,
|
||||||
fetchWithTurnCancellingRetry,
|
fetchWithTurnCancellingRetry,
|
||||||
handleStreamTerminalError,
|
handleStreamTerminalError,
|
||||||
persistAssistantTurn,
|
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -1715,9 +1654,12 @@ export default function NewChatPage() {
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
abortControllerRef.current = controller;
|
abortControllerRef.current = controller;
|
||||||
|
|
||||||
// Add placeholder user message if we have a new query (edit mode)
|
// Add placeholder user message if we have a new query (edit mode).
|
||||||
const userMsgId = `msg-user-${Date.now()}`;
|
// Mutable for the same reason as in ``onNew`` — both ids are
|
||||||
const assistantMsgId = `msg-assistant-${Date.now()}`;
|
// renamed mid-stream by the new ``data-user-message-id`` /
|
||||||
|
// ``data-assistant-message-id`` SSE handlers below.
|
||||||
|
let userMsgId = `msg-user-${Date.now()}`;
|
||||||
|
let assistantMsgId = `msg-assistant-${Date.now()}`;
|
||||||
const currentThinkingSteps = new Map<string, ThinkingStepData>();
|
const currentThinkingSteps = new Map<string, ThinkingStepData>();
|
||||||
|
|
||||||
const contentPartsState: ContentPartsState = {
|
const contentPartsState: ContentPartsState = {
|
||||||
|
|
@ -1727,13 +1669,7 @@ export default function NewChatPage() {
|
||||||
toolCallIndices: new Map(),
|
toolCallIndices: new Map(),
|
||||||
};
|
};
|
||||||
const { contentParts } = contentPartsState;
|
const { contentParts } = contentPartsState;
|
||||||
let tokenUsageData: TokenUsageData | null = null;
|
|
||||||
let regenerateAccepted = false;
|
let regenerateAccepted = false;
|
||||||
let userPersisted = false;
|
|
||||||
// Captured from ``data-turn-info`` at stream start; stamped
|
|
||||||
// onto persisted messages so future edits can locate the
|
|
||||||
// right LangGraph checkpoint.
|
|
||||||
let streamedChatTurnId: string | null = null;
|
|
||||||
let streamBatcher: FrameBatchedUpdater | null = null;
|
let streamBatcher: FrameBatchedUpdater | null = null;
|
||||||
|
|
||||||
// Add placeholder messages to UI
|
// Add placeholder messages to UI
|
||||||
|
|
@ -1747,9 +1683,6 @@ export default function NewChatPage() {
|
||||||
createdAt: new Date(),
|
createdAt: new Date(),
|
||||||
metadata: isEdit ? undefined : originalUserMessageMetadata,
|
metadata: isEdit ? undefined : originalUserMessageMetadata,
|
||||||
};
|
};
|
||||||
const userContentToPersist = isEdit
|
|
||||||
? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }])
|
|
||||||
: originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }];
|
|
||||||
const sourceMentionedDocs =
|
const sourceMentionedDocs =
|
||||||
sourceUserMessageId && messageDocumentsMap[sourceUserMessageId]
|
sourceUserMessageId && messageDocumentsMap[sourceUserMessageId]
|
||||||
? messageDocumentsMap[sourceUserMessageId]
|
? messageDocumentsMap[sourceUserMessageId]
|
||||||
|
|
@ -1765,6 +1698,18 @@ export default function NewChatPage() {
|
||||||
filesystem_mode: selection.filesystem_mode,
|
filesystem_mode: selection.filesystem_mode,
|
||||||
client_platform: selection.client_platform,
|
client_platform: selection.client_platform,
|
||||||
local_filesystem_mounts: selection.local_filesystem_mounts,
|
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||||
|
// Full mention metadata for the regenerate-specific
|
||||||
|
// source list. Only meaningful for edit (the BE only
|
||||||
|
// re-persists a user row when ``user_query`` is set);
|
||||||
|
// reload reuses the original turn's mentioned_documents.
|
||||||
|
mentioned_documents:
|
||||||
|
sourceMentionedDocs.length > 0
|
||||||
|
? sourceMentionedDocs.map((d) => ({
|
||||||
|
id: d.id,
|
||||||
|
title: d.title,
|
||||||
|
document_type: d.document_type,
|
||||||
|
}))
|
||||||
|
: undefined,
|
||||||
};
|
};
|
||||||
if (isEdit) {
|
if (isEdit) {
|
||||||
requestBody.user_images = editExtras?.userImages ?? [];
|
requestBody.user_images = editExtras?.userImages ?? [];
|
||||||
|
|
@ -1852,7 +1797,6 @@ export default function NewChatPage() {
|
||||||
scheduleFlush,
|
scheduleFlush,
|
||||||
forceFlush,
|
forceFlush,
|
||||||
onTokenUsage: (data) => {
|
onTokenUsage: (data) => {
|
||||||
tokenUsageData = data;
|
|
||||||
tokenUsageStore.set(assistantMsgId, data);
|
tokenUsageStore.set(assistantMsgId, data);
|
||||||
},
|
},
|
||||||
onTurnStatus: (data) => {
|
onTurnStatus: (data) => {
|
||||||
|
|
@ -1897,7 +1841,6 @@ export default function NewChatPage() {
|
||||||
|
|
||||||
case "data-turn-info": {
|
case "data-turn-info": {
|
||||||
const turnId = readStreamedChatTurnId(parsed.data);
|
const turnId = readStreamedChatTurnId(parsed.data);
|
||||||
streamedChatTurnId = turnId;
|
|
||||||
if (turnId) {
|
if (turnId) {
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
|
applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId)
|
||||||
|
|
@ -1906,6 +1849,57 @@ export default function NewChatPage() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "data-user-message-id": {
|
||||||
|
// Same role as in ``onNew`` but the regenerate-specific
|
||||||
|
// mention metadata (``sourceMentionedDocs``) is the
|
||||||
|
// list to migrate onto the canonical id key.
|
||||||
|
const parsedMsg = readStreamedMessageId(parsed.data);
|
||||||
|
if (!parsedMsg) break;
|
||||||
|
const newUserMsgId = `msg-${parsedMsg.messageId}`;
|
||||||
|
const oldUserMsgId = userMsgId;
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === oldUserMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage(
|
||||||
|
{ ...m, id: newUserMsgId },
|
||||||
|
parsedMsg.turnId
|
||||||
|
)
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
if (sourceMentionedDocs.length > 0) {
|
||||||
|
setMessageDocumentsMap((prev) => {
|
||||||
|
if (!(oldUserMsgId in prev)) {
|
||||||
|
return { ...prev, [newUserMsgId]: sourceMentionedDocs };
|
||||||
|
}
|
||||||
|
const { [oldUserMsgId]: _removed, ...rest } = prev;
|
||||||
|
return { ...rest, [newUserMsgId]: sourceMentionedDocs };
|
||||||
|
});
|
||||||
|
}
|
||||||
|
userMsgId = newUserMsgId;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "data-assistant-message-id": {
|
||||||
|
const parsedMsg = readStreamedMessageId(parsed.data);
|
||||||
|
if (!parsedMsg) break;
|
||||||
|
const newAssistantMsgId = `msg-${parsedMsg.messageId}`;
|
||||||
|
const oldAssistantMsgId = assistantMsgId;
|
||||||
|
tokenUsageStore.rename(oldAssistantMsgId, newAssistantMsgId);
|
||||||
|
setMessages((prev) =>
|
||||||
|
prev.map((m) =>
|
||||||
|
m.id === oldAssistantMsgId
|
||||||
|
? mergeChatTurnIdIntoMessage(
|
||||||
|
{ ...m, id: newAssistantMsgId },
|
||||||
|
parsedMsg.turnId
|
||||||
|
)
|
||||||
|
: m
|
||||||
|
)
|
||||||
|
);
|
||||||
|
assistantMsgId = newAssistantMsgId;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case "data-revert-results": {
|
case "data-revert-results": {
|
||||||
const summary = parsed.data;
|
const summary = parsed.data;
|
||||||
// failureCount must include every "not undone" bucket
|
// failureCount must include every "not undone" bucket
|
||||||
|
|
@ -1946,28 +1940,14 @@ export default function NewChatPage() {
|
||||||
|
|
||||||
batcher.flush();
|
batcher.flush();
|
||||||
|
|
||||||
// Persist messages after streaming completes
|
// Server-authoritative persistence: ``stream_new_chat``
|
||||||
const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
// (regenerate flow) wrote the user row in
|
||||||
|
// ``persist_user_turn`` and finalises the assistant row
|
||||||
|
// in ``finalize_assistant_turn`` from a shielded
|
||||||
|
// ``finally`` block (covers both happy-path and
|
||||||
|
// abort-mid-stream). FE only needs to track the
|
||||||
|
// successful response here.
|
||||||
if (contentParts.length > 0) {
|
if (contentParts.length > 0) {
|
||||||
const persistedUserMsgId = await persistUserTurn({
|
|
||||||
threadId,
|
|
||||||
userMsgId,
|
|
||||||
content: userContentToPersist,
|
|
||||||
mentionedDocs: sourceMentionedDocs,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "regenerated",
|
|
||||||
});
|
|
||||||
userPersisted = Boolean(persistedUserMsgId);
|
|
||||||
|
|
||||||
await persistAssistantTurn({
|
|
||||||
threadId,
|
|
||||||
assistantMsgId,
|
|
||||||
content: finalContent,
|
|
||||||
tokenUsage: tokenUsageData ?? undefined,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "regenerated",
|
|
||||||
});
|
|
||||||
|
|
||||||
trackChatResponseReceived(searchSpaceId, threadId);
|
trackChatResponseReceived(searchSpaceId, threadId);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
@ -1978,44 +1958,6 @@ export default function NewChatPage() {
|
||||||
threadId,
|
threadId,
|
||||||
assistantMsgId,
|
assistantMsgId,
|
||||||
accepted: regenerateAccepted,
|
accepted: regenerateAccepted,
|
||||||
onAbort: async () => {
|
|
||||||
if (!regenerateAccepted) return;
|
|
||||||
if (!userPersisted) {
|
|
||||||
const persistedUserMsgId = await persistUserTurn({
|
|
||||||
threadId,
|
|
||||||
userMsgId,
|
|
||||||
content: userContentToPersist,
|
|
||||||
mentionedDocs: sourceMentionedDocs,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "regenerated (aborted)",
|
|
||||||
});
|
|
||||||
userPersisted = Boolean(persistedUserMsgId);
|
|
||||||
}
|
|
||||||
const hasContent = hasPersistableContent(contentParts, toolsWithUI);
|
|
||||||
if (!hasContent) return;
|
|
||||||
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
|
|
||||||
await persistAssistantTurn({
|
|
||||||
threadId,
|
|
||||||
assistantMsgId,
|
|
||||||
content: partialContent,
|
|
||||||
tokenUsage: tokenUsageData ?? undefined,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "partial regenerated chat",
|
|
||||||
});
|
|
||||||
},
|
|
||||||
onAcceptedStreamError: async () => {
|
|
||||||
if (!userPersisted) {
|
|
||||||
const persistedUserMsgId = await persistUserTurn({
|
|
||||||
threadId,
|
|
||||||
userMsgId,
|
|
||||||
content: userContentToPersist,
|
|
||||||
mentionedDocs: sourceMentionedDocs,
|
|
||||||
turnId: streamedChatTurnId,
|
|
||||||
logContext: "regenerated (stream error)",
|
|
||||||
});
|
|
||||||
userPersisted = Boolean(persistedUserMsgId);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
} finally {
|
} finally {
|
||||||
setIsRunning(false);
|
setIsRunning(false);
|
||||||
|
|
@ -2034,8 +1976,6 @@ export default function NewChatPage() {
|
||||||
tokenUsageStore,
|
tokenUsageStore,
|
||||||
fetchWithTurnCancellingRetry,
|
fetchWithTurnCancellingRetry,
|
||||||
handleStreamTerminalError,
|
handleStreamTerminalError,
|
||||||
persistAssistantTurn,
|
|
||||||
persistUserTurn,
|
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -114,6 +114,29 @@ export function readStreamedChatTurnId(data: unknown): string | null {
|
||||||
return typeof value === "string" && value.length > 0 ? value : null;
|
return typeof value === "string" && value.length > 0 ? value : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse the payload of `data-user-message-id` / `data-assistant-message-id`
|
||||||
|
* SSE events emitted by `stream_new_chat` and `stream_resume_chat` after
|
||||||
|
* `persist_user_turn` / `persist_assistant_shell` resolve a canonical
|
||||||
|
* `new_chat_messages.id`. Mirrors {@link readStreamedChatTurnId}.
|
||||||
|
*
|
||||||
|
* Returns `null` when the payload is malformed (missing or non-numeric
|
||||||
|
* `message_id`); callers should treat this as "ignore the event" so a
|
||||||
|
* malformed BE payload never overwrites the optimistic id with a bogus
|
||||||
|
* value.
|
||||||
|
*/
|
||||||
|
export function readStreamedMessageId(
|
||||||
|
data: unknown
|
||||||
|
): { messageId: number; turnId: string | null } | null {
|
||||||
|
if (typeof data !== "object" || data === null) return null;
|
||||||
|
const obj = data as { message_id?: unknown; turn_id?: unknown };
|
||||||
|
if (typeof obj.message_id !== "number" || !Number.isFinite(obj.message_id)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const turnId = typeof obj.turn_id === "string" && obj.turn_id.length > 0 ? obj.turn_id : null;
|
||||||
|
return { messageId: obj.message_id, turnId };
|
||||||
|
}
|
||||||
|
|
||||||
export function applyTurnIdToAssistantMessageList(
|
export function applyTurnIdToAssistantMessageList(
|
||||||
messages: ThreadMessageLike[],
|
messages: ThreadMessageLike[],
|
||||||
assistantMsgId: string,
|
assistantMsgId: string,
|
||||||
|
|
|
||||||
|
|
@ -487,6 +487,37 @@ export type SSEEvent =
|
||||||
type: "data-turn-info";
|
type: "data-turn-info";
|
||||||
data: { chat_turn_id: string };
|
data: { chat_turn_id: string };
|
||||||
}
|
}
|
||||||
|
| {
|
||||||
|
/**
|
||||||
|
* Emitted by ``stream_new_chat`` AFTER ``data-turn-info`` /
|
||||||
|
* ``data-turn-status`` and BEFORE any LLM streaming events,
|
||||||
|
* once ``persist_user_turn`` has resolved the canonical
|
||||||
|
* ``new_chat_messages.id`` for the user-side row of the
|
||||||
|
* current turn. The frontend renames its optimistic
|
||||||
|
* ``msg-user-XXX`` placeholder id to ``msg-{message_id}``
|
||||||
|
* so DB-id-gated UI (comments, edit-from-this-message)
|
||||||
|
* unlocks immediately. Not emitted by ``stream_resume_chat``
|
||||||
|
* (resume reuses the original turn's user message).
|
||||||
|
*/
|
||||||
|
type: "data-user-message-id";
|
||||||
|
data: { message_id: number; turn_id: string };
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
/**
|
||||||
|
* Emitted by ``stream_new_chat`` AND ``stream_resume_chat``
|
||||||
|
* AFTER ``data-turn-info`` / ``data-turn-status`` and BEFORE
|
||||||
|
* any LLM streaming events, once ``persist_assistant_shell``
|
||||||
|
* has resolved the canonical ``new_chat_messages.id`` for
|
||||||
|
* the assistant-side row of the current turn. The frontend
|
||||||
|
* renames its optimistic ``msg-assistant-XXX`` placeholder
|
||||||
|
* id, migrates the local ``tokenUsageStore`` and
|
||||||
|
* ``pendingInterrupt`` references, and binds the running
|
||||||
|
* mutable ``assistantMsgId`` closure variable to the
|
||||||
|
* canonical id for the rest of the stream.
|
||||||
|
*/
|
||||||
|
type: "data-assistant-message-id";
|
||||||
|
data: { message_id: number; turn_id: string };
|
||||||
|
}
|
||||||
| {
|
| {
|
||||||
/**
|
/**
|
||||||
* Best-effort revert pass that ran BEFORE this regeneration.
|
* Best-effort revert pass that ran BEFORE this regeneration.
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,17 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory
|
||||||
* via ``data-turn-info``. Persisting it lets later edits locate the
|
* via ``data-turn-info``. Persisting it lets later edits locate the
|
||||||
* matching LangGraph checkpoint without HumanMessage scanning. Older
|
* matching LangGraph checkpoint without HumanMessage scanning. Older
|
||||||
* callers can still omit it for back-compat.
|
* callers can still omit it for back-compat.
|
||||||
|
*
|
||||||
|
* @deprecated Replaced by the SSE-based message ID handshake. The
|
||||||
|
* streaming generator (`stream_new_chat` / `stream_resume_chat`) now
|
||||||
|
* persists both the user and assistant rows server-side via
|
||||||
|
* `persist_user_turn` / `persist_assistant_shell` and emits
|
||||||
|
* `data-user-message-id` / `data-assistant-message-id` SSE events so
|
||||||
|
* the UI renames its optimistic IDs in real time. The only remaining
|
||||||
|
* caller is `persistAssistantErrorMessage` (pre-stream error fallback
|
||||||
|
* for requests the server never accepted — the server has nothing to
|
||||||
|
* persist in that case). After the legacy route is removed in a
|
||||||
|
* follow-up PR this function will be deleted entirely.
|
||||||
*/
|
*/
|
||||||
export async function appendMessage(
|
export async function appendMessage(
|
||||||
threadId: number,
|
threadId: number,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue