Merge pull request #1327 from AnishSarkar22/feat/chat-state-unification

refactor(chat): unify streaming state flow and improve chat viewport + mention UX
Rohan Verma 2026-04-30 16:24:51 -07:00 committed by GitHub
commit d335e96ec2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 1953 additions and 1647 deletions

View file

@ -33,6 +33,7 @@ from __future__ import annotations
import asyncio
import logging
import time
import weakref
from typing import Any
@ -58,6 +59,8 @@ class _ThreadLockManager:
weakref.WeakValueDictionary()
)
self._cancel_events: dict[str, asyncio.Event] = {}
self._cancel_requested_at_ms: dict[str, int] = {}
self._cancel_attempt_count: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id)
@ -76,14 +79,45 @@ class _ThreadLockManager:
def request_cancel(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id)
if event is None:
return False
event = asyncio.Event()
self._cancel_events[thread_id] = event
event.set()
now_ms = int(time.time() * 1000)
self._cancel_requested_at_ms[thread_id] = now_ms
self._cancel_attempt_count[thread_id] = (
self._cancel_attempt_count.get(thread_id, 0) + 1
)
return True
def is_cancel_requested(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id)
return bool(event and event.is_set())
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
if not self.is_cancel_requested(thread_id):
return None
attempts = self._cancel_attempt_count.get(thread_id, 1)
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
return attempts, requested_at_ms
def reset(self, thread_id: str) -> None:
event = self._cancel_events.get(thread_id)
if event is not None:
event.clear()
self._cancel_requested_at_ms.pop(thread_id, None)
self._cancel_attempt_count.pop(thread_id, None)
def end_turn(self, thread_id: str) -> None:
"""Best-effort terminal cleanup for a thread turn.
This is intentionally idempotent and safe to call from outer stream
finally-blocks where middleware teardown might be skipped due to abort
or disconnect edge-cases.
"""
lock = self._locks.get(thread_id)
if lock is not None and lock.locked():
lock.release()
self.reset(thread_id)
# Module-level singleton — process-local but reused across all agent
@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
def request_cancel(thread_id: str) -> bool:
"""Trip the cancel event for ``thread_id``. Returns True if found."""
"""Trip the cancel event for ``thread_id``. Always returns True."""
return manager.request_cancel(thread_id)
def is_cancel_requested(thread_id: str) -> bool:
"""Return whether ``thread_id`` currently has a pending cancel signal."""
return manager.is_cancel_requested(thread_id)
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
return manager.cancel_state(thread_id)
def reset_cancel(thread_id: str) -> None:
"""Reset the cancel event for ``thread_id`` (called between turns)."""
manager.reset(thread_id)
def end_turn(thread_id: str) -> None:
"""Force end-of-turn cleanup for lock + cancel state."""
manager.end_turn(thread_id)
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Block concurrent prompts on the same thread.
@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
__all__ = [
"BusyMutexMiddleware",
"end_turn",
"get_cancel_event",
"get_cancel_state",
"is_cancel_requested",
"manager",
"request_cancel",
"reset_cancel",

View file

@ -15,7 +15,7 @@ import json
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy.exc import IntegrityError, OperationalError
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
FilesystemSelection,
LocalFilesystemMount,
)
from app.agents.new_chat.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
manager,
request_cancel,
)
from app.config import config
from app.db import (
ChatComment,
@ -44,6 +50,7 @@ from app.db import (
)
from app.schemas.new_chat import (
AgentToolInfo,
CancelActiveTurnResponse,
LocalFilesystemMountPayload,
NewChatMessageRead,
NewChatRequest,
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
ThreadListItem,
ThreadListResponse,
TokenUsageSummary,
TurnStatusResponse,
)
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
_logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task] = set()
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
router = APIRouter()
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
)
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
"""Bounded exponential delay for TURN_CANCELLING retry hints."""
if attempt < 1:
attempt = 1
delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
)
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
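# Illustration (not part of the diff): the retry hints this produces for
# successive cancel attempts, given the constants above.
assert [_compute_turn_cancelling_retry_delay(n) for n in range(1, 6)] == [
    200,  # attempt 1
    400,  # attempt 2
    800,  # attempt 3
    1500,  # attempt 4 (1600 capped at the max)
    1500,  # attempt 5 and beyond stay at the cap
]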
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
lock = manager.lock_for(str(thread_id))
if not lock.locked():
return {"status": "idle"}
if is_cancel_requested(str(thread_id)):
cancel_state = get_cancel_state(str(thread_id))
attempt = cancel_state[0] if cancel_state else 1
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms
return {
"status": "cancelling",
"retry_after_ms": retry_after_ms,
"retry_after_at": retry_after_at,
}
return {"status": "busy"}
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
response.headers["retry-after-ms"] = str(retry_after_ms)
response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
status_payload = _build_turn_status_payload(thread_id)
status = status_payload["status"]
if status == "idle":
return
if status == "cancelling":
retry_after_ms = int(status_payload.get("retry_after_ms") or 0)
detail = {
"errorCode": "TURN_CANCELLING",
"message": "A previous response is still stopping. Please try again in a moment.",
"retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
"retry_after_at": status_payload.get("retry_after_at"),
}
headers = (
{
"retry-after-ms": str(retry_after_ms),
"Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
}
if retry_after_ms > 0
else None
)
raise HTTPException(status_code=409, detail=detail, headers=headers)
raise HTTPException(
status_code=409,
detail={
"errorCode": "THREAD_BUSY",
"message": "Another response is still finishing for this thread. Please try again in a moment.",
},
)
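# Illustration (values hypothetical): FastAPI serializes the two 409 cases
# raised above as
#   {"detail": {"errorCode": "TURN_CANCELLING",
#               "message": "A previous response is still stopping. Please try again in a moment.",
#               "retry_after_ms": 400, "retry_after_at": 1767225600400}}
#   with headers "retry-after-ms: 400" and "Retry-After: 1", and
#   {"detail": {"errorCode": "THREAD_BUSY",
#               "message": "Another response is still finishing for this thread. Please try again in a moment."}}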
def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list,
*,
@ -1476,6 +1553,7 @@ async def handle_new_chat(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(request.chat_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@ -1550,6 +1628,93 @@ async def handle_new_chat(
) from None
@router.post(
"/threads/{thread_id}/cancel-active-turn",
response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
thread_id: int,
response: Response,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Signal cancellation for the currently running turn on ``thread_id``."""
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
await check_permission(
session,
user,
thread.search_space_id,
Permission.CHATS_UPDATE.value,
"You don't have permission to update chats in this search space",
)
await check_thread_access(session, thread, user)
status_payload = _build_turn_status_payload(thread_id)
if status_payload["status"] == "idle":
return CancelActiveTurnResponse(
status="idle",
error_code="NO_ACTIVE_TURN",
)
request_cancel(str(thread_id))
response.status_code = 202
updated_payload = _build_turn_status_payload(thread_id)
retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
retry_after_at = (
int(updated_payload["retry_after_at"])
if "retry_after_at" in updated_payload
else None
)
if retry_after_ms > 0:
_set_retry_after_headers(response, retry_after_ms)
return CancelActiveTurnResponse(
status="cancelling",
error_code="TURN_CANCELLING",
retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
retry_after_at=retry_after_at,
)
@router.get(
"/threads/{thread_id}/turn-status",
response_model=TurnStatusResponse,
)
async def get_turn_status(
thread_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
await check_permission(
session,
user,
thread.search_space_id,
Permission.CHATS_READ.value,
"You don't have permission to view chats in this search space",
)
await check_thread_access(session, thread, user)
status_payload = _build_turn_status_payload(thread_id)
return TurnStatusResponse(
status=status_payload["status"], # type: ignore[arg-type]
active_turn_id=None,
retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type]
retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type]
)
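# --- Client-side sketch (illustration only, not part of this diff) ----------
# httpx, the bearer-token auth, and the bare route prefix are assumptions.
# Signals cancellation, then polls turn-status using the server retry hint.
import httpx


async def stop_active_turn(base_url: str, thread_id: int, token: str) -> None:
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(base_url=base_url, headers=headers) as client:
        body = (await client.post(f"/threads/{thread_id}/cancel-active-turn")).json()
        while body.get("status") == "cancelling":
            await asyncio.sleep((body.get("retry_after_ms") or 200) / 1000)
            body = (await client.get(f"/threads/{thread_id}/turn-status")).json()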
# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================
@ -1605,6 +1770,7 @@ async def regenerate_response(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@ -2012,6 +2178,7 @@ async def resume_chat(
)
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,

View file

@ -335,6 +335,24 @@ class ResumeRequest(BaseModel):
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
class CancelActiveTurnResponse(BaseModel):
"""Response for canceling an active turn on a chat thread."""
status: Literal["cancelling", "idle"]
error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
retry_after_ms: int | None = None
retry_after_at: int | None = None
class TurnStatusResponse(BaseModel):
"""Current turn execution status for a thread."""
status: Literal["idle", "busy", "cancelling"]
active_turn_id: str | None = None
retry_after_ms: int | None = None
retry_after_at: int | None = None
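# Illustration (not part of the diff) of the wire shapes these models produce,
# assuming Pydantic v2 serialization; timestamps are hypothetical epoch millis.
# CancelActiveTurnResponse(status="cancelling", error_code="TURN_CANCELLING",
#                          retry_after_ms=400, retry_after_at=1767225600400)
#   -> {"status": "cancelling", "error_code": "TURN_CANCELLING",
#       "retry_after_ms": 400, "retry_after_at": 1767225600400}
# TurnStatusResponse(status="idle")
#   -> {"status": "idle", "active_turn_id": None,
#       "retry_after_ms": None, "retry_after_at": None}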
# =============================================================================
# Public Chat Snapshot Schemas
# =============================================================================

View file

@ -565,7 +565,12 @@ class VercelStreamingService:
# Error Part
# =========================================================================
def format_error(self, error_text: str, error_code: str | None = None) -> str:
def format_error(
self,
error_text: str,
error_code: str | None = None,
extra: dict[str, object] | None = None,
) -> str:
"""
Format an error message.
@ -579,9 +584,11 @@ class VercelStreamingService:
Example output:
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
"""
payload: dict[str, str] = {"type": "error", "errorText": error_text}
payload: dict[str, object] = {"type": "error", "errorText": error_text}
if error_code:
payload["errorCode"] = error_code
if extra:
payload.update(extra)
return self._format_sse(payload)
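# Illustration (not part of the diff): the ``extra`` mapping is merged into
# the same error part, e.g.
#   format_error(
#       "A previous response is still stopping. Please try again in a moment.",
#       error_code="TURN_CANCELLING",
#       extra={"retry_after_ms": 400, "retry_after_at": 1767225600400},
#   )
#   -> data: {"type":"error","errorText":"A previous response is still stopping. Please try again in a moment.","errorCode":"TURN_CANCELLING","retry_after_ms":400,"retry_after_at":1767225600400}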
# =========================================================================

View file

@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.agents.new_chat.middleware.busy_mutex import (
end_turn,
get_cancel_state,
is_cancel_requested,
)
from app.agents.new_chat.middleware.kb_persistence import (
commit_staged_filesystem_state,
)
@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content
_background_tasks: set[asyncio.Task] = set()
_perf_log = get_perf_logger()
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
if attempt < 1:
attempt = 1
delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
)
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
@ -401,15 +418,35 @@ def _classify_stream_exception(
exc: Exception,
*,
flow_label: str,
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]:
) -> tuple[
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
]:
raw = str(exc)
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
if busy_thread_id and is_cancel_requested(busy_thread_id):
cancel_state = get_cancel_state(busy_thread_id)
attempt = cancel_state[0] if cancel_state else 1
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
retry_after_at = int(time.time() * 1000) + retry_after_ms
return (
"thread_busy",
"TURN_CANCELLING",
"info",
True,
"A previous response is still stopping. Please try again in a moment.",
{
"retry_after_ms": retry_after_ms,
"retry_after_at": retry_after_at,
},
)
return (
"thread_busy",
"THREAD_BUSY",
"warn",
True,
"Another response is still finishing for this thread. Please try again in a moment.",
None,
)
parsed = _parse_error_payload(raw)
@ -431,6 +468,7 @@ def _classify_stream_exception(
"warn",
True,
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
None,
)
return (
@ -439,6 +477,7 @@ def _classify_stream_exception(
"error",
False,
f"Error during {flow_label}: {raw}",
None,
)
@ -470,7 +509,7 @@ def _emit_stream_terminal_error(
message=message,
extra=extra,
)
return streaming_service.format_error(message, error_code=error_code)
return streaming_service.format_error(message, error_code=error_code, extra=extra)
def _legacy_match_lc_id(
@ -2497,6 +2536,7 @@ async def stream_new_chat(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
# Initial thinking step - analyzing the request
if mentioned_surfsense_docs:
@ -2805,6 +2845,7 @@ async def stream_new_chat(
task.add_done_callback(_background_tasks.discard)
# Finish the step and message
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@ -2819,11 +2860,19 @@ async def stream_new_chat(
severity,
is_expected,
user_message,
error_extra,
) = _classify_stream_exception(e, flow_label="chat")
error_message = f"Error during chat: {e!s}"
print(f"[stream_new_chat] {error_message}")
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
if error_code == "TURN_CANCELLING":
status_payload: dict[str, Any] = {"status": "cancelling"}
if error_extra:
status_payload.update(error_extra)
yield streaming_service.format_data("turn-status", status_payload)
else:
yield streaming_service.format_data("turn-status", {"status": "busy"})
yield _emit_stream_error(
message=user_message,
@ -2831,7 +2880,9 @@ async def stream_new_chat(
error_code=error_code,
severity=severity,
is_expected=is_expected,
extra=error_extra,
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@ -2847,6 +2898,10 @@ async def stream_new_chat(
# (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
# Authoritative fallback cleanup for lock/cancel state. Middleware
# teardown can be skipped on some client-abort paths.
end_turn(str(chat_id))
# Release premium reservation if not finalized
if _premium_request_id and _premium_reserved > 0 and user_id:
try:
@ -3206,6 +3261,7 @@ async def stream_resume_chat(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
_t_stream_start = time.perf_counter()
_first_event_logged = False
@ -3305,6 +3361,7 @@ async def stream_resume_chat(
},
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@ -3318,23 +3375,37 @@ async def stream_resume_chat(
severity,
is_expected,
user_message,
error_extra,
) = _classify_stream_exception(e, flow_label="resume")
error_message = f"Error during resume: {e!s}"
print(f"[stream_resume_chat] {error_message}")
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
if error_code == "TURN_CANCELLING":
status_payload: dict[str, Any] = {"status": "cancelling"}
if error_extra:
status_payload.update(error_extra)
yield streaming_service.format_data("turn-status", status_payload)
else:
yield streaming_service.format_data("turn-status", {"status": "busy"})
yield _emit_stream_error(
message=user_message,
error_kind=error_kind,
error_code=error_code,
severity=severity,
is_expected=is_expected,
extra=error_extra,
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
finally:
with anyio.CancelScope(shield=True):
# Authoritative fallback cleanup for lock/cancel state. Middleware
# teardown can be skipped on some client-abort paths.
end_turn(str(chat_id))
# Release premium reservation if not finalized
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
try:

View file

@ -7,7 +7,9 @@ import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
BusyMutexMiddleware,
end_turn,
get_cancel_event,
is_cancel_requested,
manager,
request_cancel,
reset_cancel,
@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
def test_reset_cancel_idempotent() -> None:
# Should not raise even if event was never created
reset_cancel("never-seen")
def test_request_cancel_creates_event_for_unseen_thread() -> None:
thread_id = "never-seen-cancel"
reset_cancel(thread_id)
assert request_cancel(thread_id) is True
assert get_cancel_event(thread_id).is_set()
assert is_cancel_requested(thread_id) is True
@pytest.mark.asyncio
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
thread_id = "forced-end-turn"
mw = BusyMutexMiddleware()
runtime = _Runtime(thread_id)
await mw.abefore_agent({}, runtime)
assert manager.lock_for(thread_id).locked()
request_cancel(thread_id)
assert is_cancel_requested(thread_id) is True
end_turn(thread_id)
assert not manager.lock_for(thread_id).locked()
assert not get_cancel_event(thread_id).is_set()
assert is_cancel_requested(thread_id) is False

View file

@ -8,6 +8,7 @@ import pytest
import app.tasks.chat.stream_new_chat as stream_new_chat_module
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
StreamResult,
_classify_stream_exception,
@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited():
exc = Exception(
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
)
kind, code, severity, is_expected, user_message = _classify_stream_exception(
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "rate_limited"
@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited():
assert severity == "warn"
assert is_expected is True
assert "temporarily rate-limited" in user_message
assert extra is None
def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message = _classify_stream_exception(
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "thread_busy"
@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy():
assert severity == "warn"
assert is_expected is True
assert "still finishing for this thread" in user_message
assert extra is None
def test_stream_exception_classifies_thread_busy_from_message():
exc = Exception("Thread is busy with another request")
kind, code, severity, is_expected, user_message = _classify_stream_exception(
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "thread_busy"
@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message():
assert severity == "warn"
assert is_expected is True
assert "still finishing for this thread" in user_message
assert extra is None
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
thread_id = "thread-cancelling-1"
reset_cancel(thread_id)
request_cancel(thread_id)
exc = BusyError(request_id=thread_id)
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "thread_busy"
assert code == "TURN_CANCELLING"
assert severity == "info"
assert is_expected is True
assert "stopping" in user_message
assert isinstance(extra, dict)
assert "retry_after_ms" in extra
def test_premium_classification_is_error_code_driven():
@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
def test_network_send_failures_use_unified_retry_toast_message():
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
classifier_source = classifier_path.read_text(encoding="utf-8")
page_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
request_errors_path = (
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts"
)
page_source = page_path.read_text(encoding="utf-8")
request_errors_source = request_errors_path.read_text(encoding="utf-8")
assert '"send_failed_pre_accept"' in classifier_source
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
assert 'errorCode === "TURN_CANCELLING"' in classifier_source
assert "if (withCode.code) return withCode.code;" in classifier_source
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
assert "tagPreAcceptSendFailure(error)" in page_source
assert "const passthroughCodes = new Set([" in page_source
assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source
assert '"THREAD_BUSY"' in page_source
assert '"AUTH_EXPIRED"' in page_source
assert '"UNAUTHORIZED"' in page_source
assert '"RATE_LIMITED"' in page_source
assert '"NETWORK_ERROR"' in page_source
assert '"STREAM_PARSE_ERROR"' in page_source
assert '"TOOL_EXECUTION_ERROR"' in page_source
assert '"PERSIST_MESSAGE_FAILED"' in page_source
assert '"SERVER_ERROR"' in page_source
assert "passthroughCodes.has(existingCode)" in page_source
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source
assert 'errorCode: "NETWORK_ERROR"' not in page_source
assert "Failed to start chat. Please try again." not in page_source
assert "const passthroughCodes = new Set([" in request_errors_source
assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
assert '"THREAD_BUSY"' in request_errors_source
assert '"TURN_CANCELLING"' in request_errors_source
assert '"AUTH_EXPIRED"' in request_errors_source
assert '"UNAUTHORIZED"' in request_errors_source
assert '"RATE_LIMITED"' in request_errors_source
assert '"NETWORK_ERROR"' in request_errors_source
assert '"STREAM_PARSE_ERROR"' in request_errors_source
assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
assert '"SERVER_ERROR"' in request_errors_source
assert "passthroughCodes.has(existingCode)" in request_errors_source
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
assert "Failed to start chat. Please try again." not in classifier_source
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
# New flow persists only when accepted and not already persisted.
assert "if (newAccepted && !userPersisted) {" in source
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
assert "computeFallbackTurnCancellingRetryDelay" in source
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
assert "await fetchWithTurnCancellingRetry(() =>" in source
def test_cancel_active_turn_route_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
assert "response_model=CancelActiveTurnResponse" in source
assert 'status="cancelling",' in source
assert 'error_code="TURN_CANCELLING",' in source
assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
assert "retry_after_at=" in source
assert 'status="idle",' in source
assert 'error_code="NO_ACTIVE_TURN",' in source
def test_turn_status_route_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
assert "response_model=TurnStatusResponse" in source
assert "_build_turn_status_payload(thread_id)" in source
assert "Permission.CHATS_READ.value" in source
assert "_raise_if_thread_busy_for_start(" in source
def test_turn_cancelling_retry_policy_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
assert "def _compute_turn_cancelling_retry_delay(" in source
assert "retry-after-ms" in source
assert '"Retry-After"' in source
assert '"errorCode": "TURN_CANCELLING"' in source
def test_turn_status_sse_contract_exists():
stream_source = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/tasks/chat/stream_new_chat.py"
).read_text(encoding="utf-8")
state_source = (
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts"
).read_text(encoding="utf-8")
pipeline_source = (
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts"
).read_text(encoding="utf-8")
assert '"turn-status"' in stream_source
assert '"status": "busy"' in stream_source
assert '"status": "idle"' in stream_source
assert "type: \"data-turn-status\"" in state_source
assert "case \"data-turn-status\":" in pipeline_source
assert "end_turn(str(chat_id))" in stream_source

View file

@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat
export const resetCurrentThreadAtom = atom(null, (_, set) => {
set(currentThreadAtom, initialState);
set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null });
set(reportPanelAtom, {
isOpen: false,
reportId: null,
title: null,
wordCount: null,
shareToken: null,
contentType: "markdown",
});
});
/** Target comment ID to scroll to (from URL navigation or inbox click) */

View file

@ -548,8 +548,10 @@ const AssistantMessageInner: FC = () => {
</div>
)}
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 flex items-center gap-2">
<AssistantActionBar />
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 h-6">
<div className="h-full opacity-100 transition-opacity">
<AssistantActionBar />
</div>
</div>
</CitationMetadataProvider>
);
@ -642,35 +644,41 @@ export const AssistantMessage: FC = () => {
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
data-role="assistant"
>
{/* Comment trigger — right-aligned, just below user query on all screen sizes */}
{showCommentTrigger && (
<div className="mr-2 mb-1 flex justify-end">
<button
ref={isDesktop ? commentTriggerRef : undefined}
type="button"
onClick={
isDesktop ? () => setIsInlineOpen((prev) => !prev) : () => setIsSheetOpen(true)
}
className={cn(
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
isDesktop && isInlineOpen
? "bg-primary/10 text-primary"
: hasComments
? "text-primary hover:bg-primary/10"
: "text-muted-foreground hover:text-foreground hover:bg-muted"
)}
>
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
{hasComments ? (
<span>
{commentCount} {commentCount === 1 ? "comment" : "comments"}
</span>
) : (
<span>Add comment</span>
)}
</button>
</div>
)}
{/* Fixed trigger slot prevents any vertical reflow when visibility changes */}
<div className="mr-2 mb-1 flex h-7 justify-end">
<button
ref={isDesktop ? commentTriggerRef : undefined}
type="button"
onClick={
showCommentTrigger
? isDesktop
? () => setIsInlineOpen((prev) => !prev)
: () => setIsSheetOpen(true)
: undefined
}
aria-hidden={!showCommentTrigger}
tabIndex={showCommentTrigger ? 0 : -1}
className={cn(
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
"opacity-0 pointer-events-none",
showCommentTrigger && "opacity-100 pointer-events-auto",
isDesktop && isInlineOpen
? "bg-primary/10 text-primary"
: hasComments
? "text-primary hover:bg-primary/10"
: "text-muted-foreground hover:text-foreground hover:bg-muted"
)}
>
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
{hasComments ? (
<span>
{commentCount} {commentCount === 1 ? "comment" : "comments"}
</span>
) : (
<span>Add comment</span>
)}
</button>
</div>
{/* Desktop floating comment panel — overlays on top of chat content */}
{showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && (

View file

@ -0,0 +1,52 @@
"use client";
import { ThreadPrimitive } from "@assistant-ui/react";
import { ArrowDownIcon } from "lucide-react";
import type { FC, ReactNode } from "react";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
const ChatScrollToBottom: FC = () => (
<ThreadPrimitive.ScrollToBottom asChild>
<TooltipIconButton
tooltip="Scroll to bottom"
variant="outline"
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
>
<ArrowDownIcon />
</TooltipIconButton>
</ThreadPrimitive.ScrollToBottom>
);
export interface ChatViewportProps {
children: ReactNode;
footer?: ReactNode;
}
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
<ThreadPrimitive.Viewport
turnAnchor="top"
autoScroll
scrollToBottomOnRunStart
scrollToBottomOnInitialize
scrollToBottomOnThreadSwitch
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
style={{ scrollbarGutter: "stable" }}
>
<div
aria-hidden
className="aui-chat-viewport-top-fade pointer-events-none sticky top-0 z-10 -mx-4 h-2 shrink-0 bg-gradient-to-b from-main-panel from-20% to-transparent"
/>
{children}
{footer ? (
<ThreadPrimitive.ViewportFooter
className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 mt-auto flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6"
style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }}
>
<div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible">
<ChatScrollToBottom />
{footer}
</div>
</ThreadPrimitive.ViewportFooter>
) : null}
</ThreadPrimitive.Viewport>
);

File diff suppressed because it is too large.

View file

@ -0,0 +1,24 @@
"use client";
import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react";
export type NestedScrollProps = ComponentPropsWithoutRef<"div">;
export const NestedScroll = forwardRef<HTMLDivElement, NestedScrollProps>(
({ onWheel, ...props }, ref) => {
const handleWheel = (event: WheelEvent<HTMLDivElement>) => {
const el = event.currentTarget;
const canScrollUp = el.scrollTop > 0;
const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1;
const goingUp = event.deltaY < 0;
const goingDown = event.deltaY > 0;
if ((goingUp && canScrollUp) || (goingDown && canScrollDown)) {
event.stopPropagation();
}
onWheel?.(event);
};
return <div ref={ref} onWheel={handleWheel} {...props} />;
}
);
NestedScroll.displayName = "NestedScroll";

View file

@ -1,18 +0,0 @@
import { ThreadPrimitive } from "@assistant-ui/react";
import { ArrowDownIcon } from "lucide-react";
import type { FC } from "react";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
export const ThreadScrollToBottom: FC = () => {
return (
<ThreadPrimitive.ScrollToBottom asChild>
<TooltipIconButton
tooltip="Scroll to bottom"
variant="outline"
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
>
<ArrowDownIcon />
</TooltipIconButton>
</ThreadPrimitive.ScrollToBottom>
);
};

View file

@ -5,12 +5,10 @@ import {
ThreadPrimitive,
useAui,
useAuiState,
useThreadViewportStore,
} from "@assistant-ui/react";
import { useAtom, useAtomValue, useSetAtom } from "jotai";
import {
AlertCircle,
ArrowDownIcon,
ArrowUpIcon,
Camera,
ChevronDown,
@ -55,6 +53,7 @@ import {
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
import {
@ -112,10 +111,13 @@ const ThreadContent: FC = () => {
["--thread-max-width" as string]: "44rem",
}}
>
<ThreadPrimitive.Viewport
turnAnchor="top"
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
style={{ scrollbarGutter: "stable" }}
<ChatViewport
footer={
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<PremiumQuotaPinnedAlert />
<Composer />
</AuiIf>
}
>
<AuiIf condition={({ thread }) => thread.isEmpty}>
<ThreadWelcome />
@ -128,24 +130,7 @@ const ThreadContent: FC = () => {
AssistantMessage,
}}
/>
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<div className="grow" />
</AuiIf>
<ThreadPrimitive.ViewportFooter
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
>
<ThreadScrollToBottom />
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<PremiumQuotaPinnedAlert />
</AuiIf>
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<Composer />
</AuiIf>
</ThreadPrimitive.ViewportFooter>
</ThreadPrimitive.Viewport>
</ChatViewport>
</ThreadPrimitive.Root>
);
};
@ -181,20 +166,6 @@ const PremiumQuotaPinnedAlert: FC = () => {
);
};
const ThreadScrollToBottom: FC = () => {
return (
<ThreadPrimitive.ScrollToBottom asChild>
<TooltipIconButton
tooltip="Scroll to bottom"
variant="outline"
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
>
<ArrowDownIcon />
</TooltipIconButton>
</ThreadPrimitive.ScrollToBottom>
);
};
const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => {
const hour = new Date().getHours();
@ -411,23 +382,9 @@ const Composer: FC = () => {
>(new Map());
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
const promptPickerRef = useRef<PromptPickerRef>(null);
const viewportRef = useRef<Element | null>(null);
const { search_space_id, chat_id } = useParams();
const aui = useAui();
const threadViewportStore = useThreadViewportStore();
const hasAutoFocusedRef = useRef(false);
const submitCleanupRef = useRef<(() => void) | null>(null);
useEffect(() => {
return () => {
submitCleanupRef.current?.();
};
}, []);
// Store viewport element reference on mount
useEffect(() => {
viewportRef.current = document.querySelector(".aui-thread-viewport");
}, []);
const electronAPI = useElectronAPI();
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
@ -626,7 +583,6 @@ const Composer: FC = () => {
[showDocumentPopover, showPromptPicker]
);
// Submit message (blocked during streaming, document picker open, or AI responding to another user)
const handleSubmit = useCallback(() => {
if (isThreadRunning || isBlockedByOtherUser) return;
if (showDocumentPopover || showPromptPicker) return;
@ -638,50 +594,9 @@ const Composer: FC = () => {
setClipboardInitialText(undefined);
}
const viewportEl = viewportRef.current;
const heightBefore = viewportEl?.scrollHeight ?? 0;
aui.composer().send();
editorRef.current?.clear();
setMentionedDocuments([]);
// With turnAnchor="top", ViewportSlack adds min-height to the last
// assistant message so that scrolling-to-bottom actually positions the
// user message at the TOP of the viewport. That slack height is
// calculated asynchronously (ResizeObserver → style → layout).
// Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes.
const scrollToBottom = () =>
threadViewportStore.getState().scrollToBottom({ behavior: "instant" });
let lastHeight = heightBefore;
let frames = 0;
let cancelled = false;
const POLL_FRAMES = 30;
const pollAndScroll = () => {
if (cancelled) return;
const el = viewportRef.current;
if (el) {
const h = el.scrollHeight;
if (h !== lastHeight) {
lastHeight = h;
scrollToBottom();
}
}
if (++frames < POLL_FRAMES) {
requestAnimationFrame(pollAndScroll);
}
};
requestAnimationFrame(pollAndScroll);
const t1 = setTimeout(scrollToBottom, 100);
const t2 = setTimeout(scrollToBottom, 300);
submitCleanupRef.current = () => {
cancelled = true;
clearTimeout(t1);
clearTimeout(t2);
};
}, [
showDocumentPopover,
showPromptPicker,
@ -690,7 +605,6 @@ const Composer: FC = () => {
clipboardInitialText,
aui,
setMentionedDocuments,
threadViewportStore,
]);
const handleDocumentRemove = useCallback(

View file

@ -13,6 +13,7 @@ import {
isDoomLoopInterrupt,
} from "@/components/tool-ui/doom-loop-approval";
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
import { NestedScroll } from "@/components/assistant-ui/nested-scroll";
import {
AlertDialog,
AlertDialogAction,
@ -475,7 +476,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
{(argsText || isRunning) && (
<div className="flex flex-col gap-1 min-w-0">
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
<div className="max-h-48 overflow-auto rounded-md bg-muted/40">
<NestedScroll className="max-h-48 overflow-auto rounded-md bg-muted/40">
{argsText ? (
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
{argsText}
@ -489,7 +490,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
Waiting for input
</p>
)}
</div>
</NestedScroll>
</div>
)}
{!isCancelled && result !== undefined && (
@ -497,11 +498,11 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
<Separator />
<div className="flex flex-col gap-1 min-w-0">
<p className="text-xs font-medium text-muted-foreground">Result</p>
<div className="max-h-64 overflow-auto rounded-md bg-muted/40">
<NestedScroll className="max-h-64 overflow-auto rounded-md bg-muted/40">
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
{typeof result === "string" ? result : serializedResult}
</pre>
</div>
</NestedScroll>
</div>
</>
)}

View file

@ -1,4 +1,10 @@
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
import {
ActionBarPrimitive,
AuiIf,
MessagePrimitive,
useAuiState,
useMessagePartText,
} from "@assistant-ui/react";
import { useAtomValue } from "jotai";
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
import Image from "next/image";
@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
import { parseMentionSegments } from "@/lib/chat/parse-mention-segments";
interface AuthorMetadata {
displayName: string | null;
@ -47,23 +55,40 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
);
};
export const UserMessage: FC = () => {
const UserTextPart: FC = () => {
const messageId = useAuiState(({ message }) => message?.id);
const messageText = useAuiState(({ message }) =>
(message?.content ?? [])
.map((part) =>
typeof part === "object" &&
part !== null &&
"type" in part &&
(part as { type?: string }).type === "text" &&
"text" in part
? String((part as { text?: string }).text ?? "")
: ""
)
.join("")
);
const part = useMessagePartText();
const text = (part as { text?: string }).text ?? "";
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? [];
const segments = parseMentionSegments(text, mentionedDocs);
return (
<p style={{ whiteSpace: "pre-line" }} className="break-words">
{segments.map((segment) =>
segment.type === "text" ? (
<span key={`txt-${segment.start}`}>{segment.value}</span>
) : (
<span
key={`mention-${getMentionDocKey(segment.doc)}-${segment.start}`}
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-middle leading-none"
title={segment.doc.title}
>
<span className="flex items-center text-muted-foreground">
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
</span>
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
</span>
)
)}
</p>
);
};
const userMessageParts = { Text: UserTextPart };
export const UserMessage: FC = () => {
const metadata = useAuiState(({ message }) => message?.metadata);
const author = metadata?.custom?.author as AuthorMetadata | undefined;
const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE";
@ -78,11 +103,7 @@ export const UserMessage: FC = () => {
<div className="aui-user-message-content-wrapper flex items-end gap-2">
<div className="relative flex-1 min-w-0">
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
{mentionedDocs && mentionedDocs.length > 0 ? (
<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
) : (
<MessagePrimitive.Parts />
)}
<MessagePrimitive.Parts components={userMessageParts} />
</div>
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
<UserActionBar />
@ -99,64 +120,6 @@ export const UserMessage: FC = () => {
);
};
const UserMessageWithMentionChips: FC<{
text: string;
mentionedDocs: { id: number; title: string; document_type: string }[];
}> = ({ text, mentionedDocs }) => {
type Segment =
| { type: "text"; value: string; start: number }
| { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number };
const tokens = mentionedDocs
.map((doc) => ({ doc, token: `@${doc.title}` }))
.sort((a, b) => b.token.length - a.token.length);
const segments: Segment[] = [];
let i = 0;
let buffer = "";
let bufferStart = 0;
while (i < text.length) {
const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i));
if (tokenMatch) {
if (buffer) {
segments.push({ type: "text", value: buffer, start: bufferStart });
buffer = "";
}
segments.push({ type: "mention", doc: tokenMatch.doc, start: i });
i += tokenMatch.token.length;
bufferStart = i;
continue;
}
if (!buffer) bufferStart = i;
buffer += text[i];
i += 1;
}
if (buffer) {
segments.push({ type: "text", value: buffer, start: bufferStart });
}
return (
<span className="whitespace-pre-wrap break-words">
{segments.map((segment) =>
segment.type === "text" ? (
<span key={`txt-${segment.start}`}>{segment.value}</span>
) : (
<span
key={`mention-${segment.doc.document_type}:${segment.doc.id}-${segment.start}`}
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline"
title={segment.doc.title}
>
<span className="flex items-center text-muted-foreground">
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
</span>
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
</span>
)
)}
</span>
);
};
const UserActionBar: FC = () => {
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);

View file

@ -1,11 +1,10 @@
"use client";
import { AuiIf, ThreadPrimitive } from "@assistant-ui/react";
import { ArrowDownIcon } from "lucide-react";
import type { FC } from "react";
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
import { EditComposer } from "@/components/assistant-ui/edit-composer";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { UserMessage } from "@/components/assistant-ui/user-message";
import { FreeComposer } from "./free-composer";
@ -24,20 +23,6 @@ const FreeThreadWelcome: FC = () => {
);
};
const ThreadScrollToBottom: FC = () => {
return (
<ThreadPrimitive.ScrollToBottom asChild>
<TooltipIconButton
tooltip="Scroll to bottom"
variant="outline"
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
>
<ArrowDownIcon />
</TooltipIconButton>
</ThreadPrimitive.ScrollToBottom>
);
};
export const FreeThread: FC = () => {
return (
<ThreadPrimitive.Root
@ -46,10 +31,12 @@ export const FreeThread: FC = () => {
["--thread-max-width" as string]: "44rem",
}}
>
<ThreadPrimitive.Viewport
turnAnchor="top"
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
style={{ scrollbarGutter: "stable" }}
<ChatViewport
footer={
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<FreeComposer />
</AuiIf>
}
>
<AuiIf condition={({ thread }) => thread.isEmpty}>
<FreeThreadWelcome />
@ -62,21 +49,7 @@ export const FreeThread: FC = () => {
AssistantMessage,
}}
/>
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<div className="grow" />
</AuiIf>
<ThreadPrimitive.ViewportFooter
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
>
<ThreadScrollToBottom />
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<FreeComposer />
</AuiIf>
</ThreadPrimitive.ViewportFooter>
</ThreadPrimitive.Viewport>
</ChatViewport>
</ThreadPrimitive.Root>
);
};

View file

@ -236,6 +236,93 @@ interface DisplayItem {
isAutoMode: boolean;
}
const TruncatedNameWithTooltip: React.FC<{
text: string;
className?: string;
enableTooltip: boolean;
}> = ({ text, className, enableTooltip }) => {
const textRef = useRef<HTMLSpanElement>(null);
const openTimerRef = useRef<number | undefined>(undefined);
const [isTruncated, setIsTruncated] = useState(false);
const [open, setOpen] = useState(false);
const recalcTruncation = useCallback(() => {
const el = textRef.current;
if (!el) return;
setIsTruncated(el.scrollWidth > el.clientWidth + 1);
}, []);
useEffect(() => {
if (!enableTooltip) return;
const el = textRef.current;
if (!el) return;
const raf = requestAnimationFrame(recalcTruncation);
recalcTruncation();
const observer = new ResizeObserver(recalcTruncation);
observer.observe(el);
if (el.parentElement) observer.observe(el.parentElement);
window.addEventListener("resize", recalcTruncation);
return () => {
cancelAnimationFrame(raf);
observer.disconnect();
window.removeEventListener("resize", recalcTruncation);
};
}, [enableTooltip, recalcTruncation]);
useEffect(() => {
// Recompute when row text changes.
void text;
requestAnimationFrame(recalcTruncation);
}, [text, recalcTruncation]);
useEffect(
() => () => {
if (openTimerRef.current) window.clearTimeout(openTimerRef.current);
},
[]
);
if (!enableTooltip) {
return (
<span ref={textRef} className={cn("block max-w-full", className)}>
{text}
</span>
);
}
const handleOpenChange = (nextOpen: boolean) => {
if (openTimerRef.current) {
window.clearTimeout(openTimerRef.current);
openTimerRef.current = undefined;
}
if (!nextOpen) {
setOpen(false);
return;
}
if (!isTruncated) return;
openTimerRef.current = window.setTimeout(() => {
setOpen(true);
openTimerRef.current = undefined;
}, 220);
};
return (
<Tooltip open={open} onOpenChange={handleOpenChange}>
<TooltipTrigger asChild>
<span ref={textRef} className={cn("block max-w-full", className)}>
{text}
</span>
</TooltipTrigger>
<TooltipContent side="top" align="start">
{text}
</TooltipContent>
</Tooltip>
);
};
// ─── Component ──────────────────────────────────────────────────────
interface ModelSelectorProps {
@ -936,7 +1023,11 @@ export function ModelSelector({
{/* Model info */}
<div className="flex-1 min-w-0">
<div className="flex items-center gap-1.5">
<span className="font-medium text-sm truncate">{config.name}</span>
<TruncatedNameWithTooltip
text={config.name}
enableTooltip={!isMobile}
className="font-medium text-sm truncate"
/>
{isAutoMode && (
<Badge
variant="secondary"

View file

@ -45,20 +45,21 @@ export const PublicThread: FC<PublicThreadProps> = ({ footer }) => {
["--thread-max-width" as string]: "44rem",
}}
>
<ThreadPrimitive.Viewport className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4">
<ThreadPrimitive.Viewport
scrollToBottomOnInitialize
scrollToBottomOnThreadSwitch
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4 pb-6"
>
<ThreadPrimitive.Messages
components={{
UserMessage: PublicUserMessage,
AssistantMessage: PublicAssistantMessage,
}}
/>
{/* Spacer to ensure footer doesn't overlap last message */}
<div className="h-24" />
</ThreadPrimitive.Viewport>
{footer && (
<div className="sticky bottom-0 z-20 border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60">
<div className="border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60">
{footer}
</div>
)}

View file

@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
};
}
if (
errorCode === "TURN_CANCELLING"
) {
return {
kind: "thread_busy",
channel: "toast",
severity: "info",
telemetryEvent: "chat_blocked",
isExpected: true,
userMessage: "A previous response is still stopping. Please try again in a moment.",
rawMessage,
errorCode: errorCode ?? "TURN_CANCELLING",
details: { flow: input.flow },
};
}
if (
errorCode === "THREAD_BUSY"
) {
@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
severity: "warn",
telemetryEvent: "chat_blocked",
isExpected: true,
userMessage: "A previous response is still stopping. Please try again in a moment.",
userMessage: "Another response is still finishing for this thread. Please try again in a moment.",
rawMessage,
errorCode: errorCode ?? "THREAD_BUSY",
details: { flow: input.flow },

View file

@ -0,0 +1,114 @@
export async function toHttpResponseError(
response: Response
): Promise<Error & { errorCode?: string; retryAfterMs?: number }> {
const statusDefaultCode =
response.status === 409
? "THREAD_BUSY"
: response.status === 429
? "RATE_LIMITED"
: response.status === 401 || response.status === 403
? "AUTH_EXPIRED"
: "SERVER_ERROR";
let rawBody = "";
try {
rawBody = await response.text();
} catch {
// noop
}
let parsedBody: Record<string, unknown> | null = null;
if (rawBody) {
try {
const parsed = JSON.parse(rawBody);
if (typeof parsed === "object" && parsed !== null) {
parsedBody = parsed as Record<string, unknown>;
}
} catch {
// noop
}
}
const detail = parsedBody?.detail;
const detailObject =
typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null;
const detailMessage = typeof detail === "string" ? detail : undefined;
const topLevelMessage =
typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined;
const detailNestedMessage =
typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined;
const topLevelCode =
typeof parsedBody?.errorCode === "string"
? parsedBody.errorCode
: typeof parsedBody?.error_code === "string"
? parsedBody.error_code
: undefined;
const detailCode =
typeof detailObject?.errorCode === "string"
? detailObject.errorCode
: typeof detailObject?.error_code === "string"
? detailObject.error_code
: undefined;
const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode;
const detailRetryAfterMs =
typeof detailObject?.retry_after_ms === "number"
? detailObject.retry_after_ms
: typeof detailObject?.retryAfterMs === "number"
? detailObject.retryAfterMs
: undefined;
const topRetryAfterMs =
typeof parsedBody?.retry_after_ms === "number"
? parsedBody.retry_after_ms
: typeof parsedBody?.retryAfterMs === "number"
? parsedBody.retryAfterMs
: undefined;
const headerRetryAfterMsRaw = response.headers.get("retry-after-ms");
const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN;
const retryAfterHeader = response.headers.get("retry-after");
const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN;
const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs)
? Math.max(0, Math.round(headerRetryAfterMs))
: Number.isFinite(retryAfterSeconds)
? Math.max(0, Math.round(retryAfterSeconds * 1000))
: undefined;
const retryAfterMs =
detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined;
const message =
detailNestedMessage ??
detailMessage ??
topLevelMessage ??
`Backend error: ${response.status}`;
return Object.assign(new Error(message), { errorCode, retryAfterMs });
}
export function tagPreAcceptSendFailure(error: unknown): unknown {
if (error instanceof Error) {
const withCode = error as Error & { errorCode?: string; code?: string };
const existingCode = withCode.errorCode ?? withCode.code;
const passthroughCodes = new Set([
"PREMIUM_QUOTA_EXHAUSTED",
"THREAD_BUSY",
"TURN_CANCELLING",
"AUTH_EXPIRED",
"UNAUTHORIZED",
"RATE_LIMITED",
"NETWORK_ERROR",
"STREAM_PARSE_ERROR",
"TOOL_EXECUTION_ERROR",
"PERSIST_MESSAGE_FAILED",
"SERVER_ERROR",
]);
if (existingCode && passthroughCodes.has(existingCode)) {
return Object.assign(error, { errorCode: existingCode });
}
return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" });
}
return Object.assign(new Error("Failed to send message before stream acceptance"), {
errorCode: "SEND_FAILED_PRE_ACCEPT",
});
}
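
A minimal caller sketch for the two helpers above; the sendChatPrompt name, the endpoint path, and the request shape are illustrative assumptions rather than part of this change:

// Hypothetical call site: only toHttpResponseError and tagPreAcceptSendFailure
// come from this file; everything else is assumed for illustration.
async function sendChatPrompt(threadId: string, prompt: string): Promise<Response> {
  let response: Response;
  try {
    response = await fetch(`/api/chat/${encodeURIComponent(threadId)}/stream`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ prompt }),
    });
  } catch (networkError) {
    // No HTTP response was accepted yet, so tag the failure (known codes pass through).
    throw tagPreAcceptSendFailure(networkError);
  }
  if (!response.ok) {
    const error = await toHttpResponseError(response);
    if (error.errorCode === "THREAD_BUSY" && typeof error.retryAfterMs === "number") {
      // retryAfterMs is normalized from the JSON body or the Retry-After headers.
      console.warn(`Thread busy; retry suggested in ${error.retryAfterMs}ms`);
    }
    throw error;
  }
  return response;
}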

View file

@ -0,0 +1,54 @@
import type { MentionedDocumentInfo } from "@/atoms/chat/mentioned-documents.atom";
export type MentionSegment =
| { type: "text"; value: string; start: number }
| { type: "mention"; doc: MentionedDocumentInfo; start: number };
/**
* Tokenizes a user message into text and `@mention` segments.
*
* Pure: no React, no DOM, no side effects. Safe to unit-test and reuse.
*
* Mentions are matched greedily by longest title first so that a longer title
* (e.g. `@Project Roadmap`) is never shadowed by a shorter prefix
* (e.g. `@Project`).
*/
export function parseMentionSegments(
text: string,
docs: ReadonlyArray<MentionedDocumentInfo>
): MentionSegment[] {
if (text.length === 0) return [];
if (docs.length === 0) return [{ type: "text", value: text, start: 0 }];
const tokens = docs
.map((doc) => ({ doc, token: `@${doc.title}` }))
.sort((a, b) => b.token.length - a.token.length);
const segments: MentionSegment[] = [];
let i = 0;
let buffer = "";
let bufferStart = 0;
while (i < text.length) {
const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i));
if (tokenMatch) {
if (buffer) {
segments.push({ type: "text", value: buffer, start: bufferStart });
buffer = "";
}
segments.push({ type: "mention", doc: tokenMatch.doc, start: i });
i += tokenMatch.token.length;
bufferStart = i;
continue;
}
if (!buffer) bufferStart = i;
buffer += text[i];
i += 1;
}
if (buffer) {
segments.push({ type: "text", value: buffer, start: bufferStart });
}
return segments;
}
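
A small illustration of the longest-title-first matching; the doc objects below assume only a title field, since that is all the tokenizer reads from MentionedDocumentInfo:

// Cast because MentionedDocumentInfo likely carries more fields than the tokenizer needs.
const docs = [{ title: "Project" }, { title: "Project Roadmap" }] as MentionedDocumentInfo[];

const segments = parseMentionSegments("See @Project Roadmap for details", docs);
// Longest-title-first matching means "@Project Roadmap" is never split by "@Project":
// [
//   { type: "text", value: "See ", start: 0 },
//   { type: "mention", doc: { title: "Project Roadmap" }, start: 4 },
//   { type: "text", value: " for details", start: 20 },
// ]
console.log(segments.length); // 3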

View file

@ -0,0 +1,19 @@
import { FrameBatchedUpdater } from "@/lib/chat/streaming-state";
export function createStreamFlushHelpers(flushMessages: () => void): {
batcher: FrameBatchedUpdater;
scheduleFlush: () => void;
forceFlush: () => void;
} {
const batcher = new FrameBatchedUpdater();
const scheduleFlush = () => batcher.schedule(flushMessages);
// Force-flush helper: ``batcher.flush()`` is a no-op when
// ``dirty=false`` (e.g. a tool starts before any text has streamed).
// ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so
// terminal events render promptly without the throttle delay.
const forceFlush = () => {
scheduleFlush();
batcher.flush();
};
return { batcher, scheduleFlush, forceFlush };
}
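
A wiring sketch for the two flush paths; the pending/rendered buffers below are stand-ins for whatever state the chat runtime actually flushes:

// Illustrative only: pendingParts / renderedParts are not real runtime state.
const pendingParts: string[] = [];
let renderedParts: string[] = [];

const { scheduleFlush, forceFlush } = createStreamFlushHelpers(() => {
  renderedParts = [...pendingParts];
});

// Throttled path: repeated text deltas are meant to coalesce into one batched flush.
pendingParts.push("Hel");
scheduleFlush();
pendingParts.push("lo");
scheduleFlush();

// Immediate path: terminal events (tool start, tool output) bypass the throttle.
pendingParts.push(" (running a tool)");
forceFlush();

console.log(renderedParts.join("")); // "Hello (running a tool)" once the forced flush has run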

View file

@ -0,0 +1,196 @@
import {
addStepSeparator,
addToolCall,
appendReasoning,
appendText,
appendToolInputDelta,
type ContentPartsState,
endReasoning,
readSSEStream,
type SSEEvent,
type ThinkingStepData,
type ToolUIGate,
updateThinkingSteps,
updateToolCall,
} from "@/lib/chat/streaming-state";
export type SharedStreamEventContext = {
contentPartsState: ContentPartsState;
toolsWithUI: ToolUIGate;
currentThinkingSteps: Map<string, ThinkingStepData>;
scheduleFlush: () => void;
forceFlush: () => void;
onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void;
onTurnStatus?: (data: Extract<SSEEvent, { type: "data-turn-status" }>["data"]) => void;
onToolOutputAvailable?: (
event: Extract<SSEEvent, { type: "tool-output-available" }>,
context: {
contentPartsState: ContentPartsState;
toolCallIndices: Map<string, number>;
}
) => void;
};
/**
* After a tool produces output, mark any previously-decided interrupt tool
* calls as completed so the ApprovalCard can transition from shimmer to done.
*/
export function markInterruptsCompleted(
contentParts: Array<{ type: string; result?: unknown }>
): void {
for (const part of contentParts) {
if (
part.type === "tool-call" &&
typeof part.result === "object" &&
part.result !== null &&
(part.result as Record<string, unknown>).__interrupt__ === true &&
(part.result as Record<string, unknown>).__decided__ &&
!(part.result as Record<string, unknown>).__completed__
) {
part.result = { ...(part.result as Record<string, unknown>), __completed__: true };
}
}
}
export function hasPersistableContent(
contentParts: ContentPartsState["contentParts"],
toolsWithUI: ToolUIGate
) {
return contentParts.some(
(part) =>
(part.type === "text" && part.text.length > 0) ||
(part.type === "reasoning" && part.text.length > 0) ||
(part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName)))
);
}
function toStreamTerminalError(
event: Extract<SSEEvent, { type: "error" }>
): Error & { errorCode?: string } {
return Object.assign(new Error(event.errorText || "Server error"), {
errorCode: event.errorCode,
});
}
export function processSharedStreamEvent(
  parsed: SSEEvent,
  context: SharedStreamEventContext
): boolean {
const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context;
const { contentParts, toolCallIndices } = contentPartsState;
switch (parsed.type) {
case "text-delta":
appendText(contentPartsState, parsed.delta);
scheduleFlush();
return true;
case "reasoning-delta":
appendReasoning(contentPartsState, parsed.delta);
scheduleFlush();
return true;
case "reasoning-end":
endReasoning(contentPartsState);
scheduleFlush();
return true;
case "start-step":
addStepSeparator(contentPartsState);
scheduleFlush();
return true;
case "finish-step":
return true;
case "tool-input-start":
addToolCall(
contentPartsState,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
{},
false,
parsed.langchainToolCallId
);
forceFlush();
return true;
case "tool-input-delta":
// High-frequency event: deltas can fire dozens of times per call,
// so use throttled scheduleFlush (NOT forceFlush) to coalesce.
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
scheduleFlush();
return true;
case "tool-input-available": {
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {},
argsText: finalArgsText,
langchainToolCallId: parsed.langchainToolCallId,
});
} else {
addToolCall(
contentPartsState,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {},
false,
parsed.langchainToolCallId
);
// addToolCall doesn't accept argsText today; backfill via
// updateToolCall so the new card renders pretty-printed JSON.
updateToolCall(contentPartsState, parsed.toolCallId, {
argsText: finalArgsText,
});
}
forceFlush();
return true;
}
case "tool-output-available":
updateToolCall(contentPartsState, parsed.toolCallId, {
result: parsed.output,
langchainToolCallId: parsed.langchainToolCallId,
});
markInterruptsCompleted(contentParts);
context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices });
forceFlush();
return true;
case "data-thinking-step": {
const stepData = parsed.data as ThinkingStepData;
if (stepData?.id) {
currentThinkingSteps.set(stepData.id, stepData);
const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
if (didUpdate) {
scheduleFlush();
}
}
return true;
}
case "data-token-usage":
context.onTokenUsage?.(parsed.data);
return true;
case "data-turn-status":
context.onTurnStatus?.(parsed.data);
return true;
case "error":
throw toStreamTerminalError(parsed);
default:
return false;
}
}
export async function consumeSseEvents(
response: Response,
onEvent: (event: SSEEvent) => void | Promise<void>
): Promise<void> {
for await (const parsed of readSSEStream(response)) {
await onEvent(parsed);
}
}
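
Putting the pieces together, a consumption loop might look roughly like the sketch below; the bare ContentPartsState initializer and the cross-module import of createStreamFlushHelpers are assumptions for illustration, since real call sites build this state inside the chat runtime:

// Sketch only: assumes ContentPartsState needs nothing beyond these two fields.
async function runSharedStream(response: Response, flushMessages: () => void): Promise<void> {
  const contentPartsState = {
    contentParts: [] as ContentPartsState["contentParts"],
    toolCallIndices: new Map<string, number>(),
  } as ContentPartsState;
  const { scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages);
  const context: SharedStreamEventContext = {
    contentPartsState,
    toolsWithUI: "all",
    currentThinkingSteps: new Map(),
    scheduleFlush,
    forceFlush,
    onTurnStatus: (data) => {
      if (data.status === "cancelling") console.info("backend reports the turn is cancelling");
    },
  };
  await consumeSseEvents(response, (event) => {
    const handled = processSharedStreamEvent(event, context);
    if (!handled) {
      // Caller-specific events (persistence acks, custom data parts) land here.
    }
  });
  // Paint whatever the last batch held once the stream closes.
  forceFlush();
}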

View file

@ -0,0 +1,127 @@
import type { ThreadMessageLike } from "@assistant-ui/react";
import {
addToolCall,
type ContentPartsState,
type ToolUIGate,
updateToolCall,
} from "@/lib/chat/streaming-state";
type InterruptActionRequest = {
name: string;
args: Record<string, unknown>;
};
export type EditedInterruptAction = {
name: string;
args: Record<string, unknown>;
};
function readInterruptActions(
interruptData: Record<string, unknown>
): InterruptActionRequest[] {
return (interruptData.action_requests ?? []) as InterruptActionRequest[];
}
/**
* Applies an interrupt request payload to tool-call parts. Existing tool cards
* are updated in-place; missing ones are upserted so the approval UI always renders.
*/
export function applyInterruptRequestToContentParts(
contentPartsState: ContentPartsState,
toolsWithUI: ToolUIGate,
interruptData: Record<string, unknown>
): void {
const { contentParts, toolCallIndices } = contentPartsState;
const actionRequests = readInterruptActions(interruptData);
for (const action of actionRequests) {
const existingEntry = Array.from(toolCallIndices.entries()).find(([, idx]) => {
const part = contentParts[idx];
return part?.type === "tool-call" && part.toolName === action.name;
});
if (existingEntry) {
updateToolCall(contentPartsState, existingEntry[0], {
result: { __interrupt__: true, ...interruptData },
});
} else {
const toolCallId = `interrupt-${action.name}`;
addToolCall(contentPartsState, toolsWithUI, toolCallId, action.name, action.args, true);
updateToolCall(contentPartsState, toolCallId, {
result: { __interrupt__: true, ...interruptData },
});
}
}
}
export function mergeEditedInterruptAction(
contentParts: ContentPartsState["contentParts"],
editedAction: EditedInterruptAction | undefined
): void {
if (!editedAction) return;
for (const part of contentParts) {
if (part.type === "tool-call" && part.toolName === editedAction.name) {
const mergedArgs = { ...part.args, ...editedAction.args };
part.args = mergedArgs;
// assistant-ui prefers argsText over JSON.stringify(args)
part.argsText = JSON.stringify(mergedArgs, null, 2);
break;
}
}
}
export function markInterruptDecisionOnContentParts(
contentParts: ContentPartsState["contentParts"],
decisionType: "approve" | "reject" | undefined
): void {
if (!decisionType) return;
for (const part of contentParts) {
if (
part.type === "tool-call" &&
typeof part.result === "object" &&
part.result !== null &&
"__interrupt__" in (part.result as Record<string, unknown>)
) {
part.result = {
...(part.result as Record<string, unknown>),
__decided__: decisionType,
};
}
}
}
/**
* When a streamed message is persisted, the backend returns the durable
* turn_id; merge it into assistant-ui metadata for turn-scoped actions.
*/
export function mergeChatTurnIdIntoMessage(
msg: ThreadMessageLike,
turnId: string | null | undefined
): ThreadMessageLike {
if (!turnId) return msg;
const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> };
const existingCustom = existingMeta.custom ?? {};
if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg;
return {
...msg,
metadata: {
...existingMeta,
custom: { ...existingCustom, chatTurnId: turnId },
},
};
}
export function readStreamedChatTurnId(data: unknown): string | null {
if (typeof data !== "object" || data === null) return null;
const value = (data as { chat_turn_id?: unknown }).chat_turn_id;
return typeof value === "string" && value.length > 0 ? value : null;
}
export function applyTurnIdToAssistantMessageList(
messages: ThreadMessageLike[],
assistantMsgId: string,
turnId: string
): ThreadMessageLike[] {
return messages.map((m) =>
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m
);
}
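
For completeness, a short sketch of chaining the turn-id helpers when a persistence payload arrives on the stream; assistantMsgId and the messages list are assumed to come from the caller's existing chat state:

// Illustrative handler; the payload shape is only what readStreamedChatTurnId itself expects.
function applyPersistedTurnId(
  messages: ThreadMessageLike[],
  assistantMsgId: string,
  eventData: unknown
): ThreadMessageLike[] {
  const turnId = readStreamedChatTurnId(eventData); // null unless chat_turn_id is a non-empty string
  if (!turnId) return messages;
  return applyTurnIdToAssistantMessageList(messages, assistantMsgId, turnId);
}

// e.g. applyPersistedTurnId(messages, "assistant-123", { chat_turn_id: "turn_42" })
// leaves every other message untouched and gives the matching assistant message
// metadata.custom.chatTurnId === "turn_42".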

View file

@ -528,6 +528,14 @@ export type SSEEvent =
}>;
};
}
| {
type: "data-turn-status";
data: {
status: "idle" | "busy" | "cancelling";
retry_after_ms?: number;
retry_after_at?: number;
};
}
| {
type: "data-token-usage";
data: {