refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling

This commit is contained in:
Anish Sarkar 2026-05-01 01:47:52 +05:30
parent 4056bd1d69
commit af66fbf106
12 changed files with 671 additions and 81 deletions

View file

@ -33,6 +33,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time
import weakref import weakref
from typing import Any from typing import Any
@ -58,6 +59,8 @@ class _ThreadLockManager:
weakref.WeakValueDictionary() weakref.WeakValueDictionary()
) )
self._cancel_events: dict[str, asyncio.Event] = {} self._cancel_events: dict[str, asyncio.Event] = {}
self._cancel_requested_at_ms: dict[str, int] = {}
self._cancel_attempt_count: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock: def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id) lock = self._locks.get(thread_id)
@ -76,14 +79,45 @@ class _ThreadLockManager:
def request_cancel(self, thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id``. Always returns True.

    The event is created lazily so cancellation can be requested even for a
    thread that has not yet registered one. Each call records the request
    timestamp (epoch ms) and bumps the attempt counter so callers can derive
    backoff hints from repeated cancel attempts.
    """
    event = self._cancel_events.get(thread_id)
    if event is None:
        # Lazily materialize the event for never-before-seen threads.
        event = asyncio.Event()
        self._cancel_events[thread_id] = event
    event.set()
    now_ms = int(time.time() * 1000)
    self._cancel_requested_at_ms[thread_id] = now_ms
    self._cancel_attempt_count[thread_id] = (
        self._cancel_attempt_count.get(thread_id, 0) + 1
    )
    return True
def is_cancel_requested(self, thread_id: str) -> bool:
    """Return True when a cancel signal is currently pending for ``thread_id``."""
    event = self._cancel_events.get(thread_id)
    # bool(...) collapses the "no event registered" (None) case to False.
    return bool(event and event.is_set())
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
    """Return ``(attempt_count, requested_at_ms)`` for a pending cancel.

    Returns None when no cancel is currently requested. Defaults of
    ``1`` / ``0`` guard against the bookkeeping dicts being out of sync
    with the event state.
    """
    if not self.is_cancel_requested(thread_id):
        return None
    attempts = self._cancel_attempt_count.get(thread_id, 1)
    requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
    return attempts, requested_at_ms
def reset(self, thread_id: str) -> None:
    """Clear the cancel event and bookkeeping for ``thread_id``.

    Called between turns; safe to call when no event was ever created
    (idempotent — ``pop`` with a default never raises).
    """
    event = self._cancel_events.get(thread_id)
    if event is not None:
        event.clear()
    self._cancel_requested_at_ms.pop(thread_id, None)
    self._cancel_attempt_count.pop(thread_id, None)
def end_turn(self, thread_id: str) -> None:
    """Best-effort terminal cleanup for a thread turn.

    This is intentionally idempotent and safe to call from outer stream
    finally-blocks where middleware teardown might be skipped due to abort
    or disconnect edge-cases.
    """
    lock = self._locks.get(thread_id)
    # Only release when actually held: asyncio.Lock.release() raises
    # RuntimeError on an unlocked lock.
    if lock is not None and lock.locked():
        # NOTE(review): this force-releases a lock likely acquired by the
        # turn's own task — asyncio locks are not owner-checked, so this is
        # permitted and appears intentional per the docstring above.
        lock.release()
    self.reset(thread_id)
# Module-level singleton — process-local but reused across all agent # Module-level singleton — process-local but reused across all agent
@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
def request_cancel(thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id``. Always returns True."""
    return manager.request_cancel(thread_id)


def is_cancel_requested(thread_id: str) -> bool:
    """Return whether ``thread_id`` currently has a pending cancel signal."""
    return manager.is_cancel_requested(thread_id)


def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
    """Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
    return manager.cancel_state(thread_id)


def reset_cancel(thread_id: str) -> None:
    """Reset the cancel event for ``thread_id`` (called between turns)."""
    manager.reset(thread_id)


def end_turn(thread_id: str) -> None:
    """Force end-of-turn cleanup for lock + cancel state."""
    manager.end_turn(thread_id)
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Block concurrent prompts on the same thread. """Block concurrent prompts on the same thread.
@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
__all__ = [ __all__ = [
"BusyMutexMiddleware", "BusyMutexMiddleware",
"end_turn",
"get_cancel_event", "get_cancel_event",
"get_cancel_state",
"is_cancel_requested",
"manager", "manager",
"request_cancel", "request_cancel",
"reset_cancel", "reset_cancel",

View file

@ -15,7 +15,7 @@ import json
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.exc import IntegrityError, OperationalError
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
FilesystemSelection, FilesystemSelection,
LocalFilesystemMount, LocalFilesystemMount,
) )
from app.agents.new_chat.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
manager,
request_cancel,
)
from app.config import config from app.config import config
from app.db import ( from app.db import (
ChatComment, ChatComment,
@ -44,6 +50,7 @@ from app.db import (
) )
from app.schemas.new_chat import ( from app.schemas.new_chat import (
AgentToolInfo, AgentToolInfo,
CancelActiveTurnResponse,
LocalFilesystemMountPayload, LocalFilesystemMountPayload,
NewChatMessageRead, NewChatMessageRead,
NewChatRequest, NewChatRequest,
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
ThreadListItem, ThreadListItem,
ThreadListResponse, ThreadListResponse,
TokenUsageSummary, TokenUsageSummary,
TurnStatusResponse,
) )
from app.services.token_tracking_service import record_token_usage from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task] = set() _background_tasks: set[asyncio.Task] = set()
# Retry-hint policy for TURN_CANCELLING responses: exponential backoff
# starting at 200 ms, doubling per cancel attempt, capped at 1.5 s.
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
router = APIRouter() router = APIRouter()
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
) )
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential delay for TURN_CANCELLING retry hints.

    ``attempt`` is the 1-based cancel attempt count; values below 1 are
    clamped to 1 so the first hint is always the initial delay. The result
    doubles per attempt and is capped at TURN_CANCELLING_MAX_DELAY_MS.
    """
    if attempt < 1:
        attempt = 1
    delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
    )
    return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Derive the thread's turn status from the process-local busy mutex.

    Returns ``{"status": "idle"}`` when no turn holds the lock,
    ``{"status": "cancelling", "retry_after_ms": ..., "retry_after_at": ...}``
    when a cancel is pending, and ``{"status": "busy"}`` otherwise.
    """
    lock = manager.lock_for(str(thread_id))
    if not lock.locked():
        return {"status": "idle"}
    if is_cancel_requested(str(thread_id)):
        cancel_state = get_cancel_state(str(thread_id))
        # Attempt count drives the exponential backoff hint; fall back to 1
        # if the state vanished between the two lookups.
        attempt = cancel_state[0] if cancel_state else 1
        retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
        # Absolute wall-clock deadline (epoch ms) for the client's retry.
        retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms
        return {
            "status": "cancelling",
            "retry_after_ms": retry_after_ms,
            "retry_after_at": retry_after_at,
        }
    return {"status": "busy"}
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
    """Attach both ms-precision and standard second-precision retry headers."""
    response.headers["retry-after-ms"] = str(retry_after_ms)
    # Ceiling-divide so a sub-second hint never becomes "Retry-After: 0".
    response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Raise HTTP 409 when ``thread_id`` already has a running turn.

    Emits ``TURN_CANCELLING`` (with retry hints in both the detail body and
    response headers) while a cancel is in flight, otherwise ``THREAD_BUSY``.
    Returns normally when the thread is idle.
    """
    status_payload = _build_turn_status_payload(thread_id)
    status = status_payload["status"]
    if status == "idle":
        return
    if status == "cancelling":
        # "or 0" also covers a missing key, not just a None value.
        retry_after_ms = int(status_payload.get("retry_after_ms") or 0)
        detail = {
            "errorCode": "TURN_CANCELLING",
            "message": "A previous response is still stopping. Please try again in a moment.",
            "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
            "retry_after_at": status_payload.get("retry_after_at"),
        }
        # Mirror the retry hint in headers so generic HTTP clients honor it.
        headers = (
            {
                "retry-after-ms": str(retry_after_ms),
                "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
            }
            if retry_after_ms > 0
            else None
        )
        raise HTTPException(status_code=409, detail=detail, headers=headers)
    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "THREAD_BUSY",
            "message": "Another response is still finishing for this thread. Please try again in a moment.",
        },
    )
def _find_pre_turn_checkpoint_id( def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list, checkpoint_tuples: list,
*, *,
@ -1476,6 +1553,7 @@ async def handle_new_chat(
# Check thread-level access based on visibility # Check thread-level access based on visibility
await check_thread_access(session, thread, user) await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(request.chat_id)
filesystem_selection = _resolve_filesystem_selection( filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode, mode=request.filesystem_mode,
client_platform=request.client_platform, client_platform=request.client_platform,
@ -1550,6 +1628,93 @@ async def handle_new_chat(
) from None ) from None
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``.

    Returns 202 with ``status="cancelling"`` plus retry hints while a turn is
    active; returns ``status="idle"`` (default 200) when there is nothing to
    cancel, making the endpoint idempotent.
    """
    result = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    status_payload = _build_turn_status_payload(thread_id)
    if status_payload["status"] == "idle":
        # No active turn: report idle rather than erroring.
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )

    request_cancel(str(thread_id))
    # 202 Accepted: cancellation is asynchronous — the running turn stops on
    # its next cancel-event check, not immediately.
    response.status_code = 202

    # Re-read state after signalling so the retry hints reflect this attempt.
    updated_payload = _build_turn_status_payload(thread_id)
    retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
    retry_after_at = (
        int(updated_payload["retry_after_at"])
        if "retry_after_at" in updated_payload
        else None
    )
    if retry_after_ms > 0:
        _set_retry_after_headers(response, retry_after_ms)
    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
        retry_after_at=retry_after_at,
    )
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return the live turn-execution status (idle/busy/cancelling) for a thread."""
    result = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    status_payload = _build_turn_status_payload(thread_id)
    return TurnStatusResponse(
        status=status_payload["status"],  # type: ignore[arg-type]
        active_turn_id=None,  # turn-id tracking not wired to this endpoint yet
        retry_after_ms=status_payload.get("retry_after_ms"),  # type: ignore[arg-type]
        retry_after_at=status_payload.get("retry_after_at"),  # type: ignore[arg-type]
    )
# ============================================================================= # =============================================================================
# Chat Regeneration Endpoint (Edit/Reload) # Chat Regeneration Endpoint (Edit/Reload)
# ============================================================================= # =============================================================================
@ -1605,6 +1770,7 @@ async def regenerate_response(
# Check thread-level access based on visibility # Check thread-level access based on visibility
await check_thread_access(session, thread, user) await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection( filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode, mode=request.filesystem_mode,
client_platform=request.client_platform, client_platform=request.client_platform,
@ -2012,6 +2178,7 @@ async def resume_chat(
) )
await check_thread_access(session, thread, user) await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection( filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode, mode=request.filesystem_mode,
client_platform=request.client_platform, client_platform=request.client_platform,

View file

@ -335,6 +335,24 @@ class ResumeRequest(BaseModel):
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
class CancelActiveTurnResponse(BaseModel):
    """Response for canceling an active turn on a chat thread."""

    # "cancelling" when a cancel was signalled; "idle" when nothing was running.
    status: Literal["cancelling", "idle"]
    # Machine-readable companion to ``status``.
    error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
    # Suggested client retry delay in milliseconds (only while cancelling).
    retry_after_ms: int | None = None
    # Absolute epoch-ms timestamp after which the client should retry.
    retry_after_at: int | None = None
class TurnStatusResponse(BaseModel):
    """Current turn execution status for a thread."""

    # Lock-derived state: no lock held / held / held with pending cancel.
    status: Literal["idle", "busy", "cancelling"]
    # Identifier of the running turn, when available (currently unset).
    active_turn_id: str | None = None
    # Suggested retry delay in milliseconds (only while cancelling).
    retry_after_ms: int | None = None
    # Absolute epoch-ms timestamp after which the client should retry.
    retry_after_at: int | None = None
# ============================================================================= # =============================================================================
# Public Chat Snapshot Schemas # Public Chat Snapshot Schemas
# ============================================================================= # =============================================================================

View file

@ -565,7 +565,12 @@ class VercelStreamingService:
# Error Part # Error Part
# ========================================================================= # =========================================================================
def format_error(
    self,
    error_text: str,
    error_code: str | None = None,
    extra: dict[str, object] | None = None,
) -> str:
    """Format an error message as an SSE ``error`` part.

    Args:
        error_text: Human-readable error description.
        error_code: Optional machine-readable code; omitted when falsy.
        extra: Optional additional key/value pairs merged into the payload
            (e.g. ``retry_after_ms`` hints). May overwrite earlier keys.

    Example output:
        data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
    """
    payload: dict[str, object] = {"type": "error", "errorText": error_text}
    if error_code:
        payload["errorCode"] = error_code
    if extra:
        payload.update(extra)
    return self._format_sse(payload)
# ========================================================================= # =========================================================================

View file

@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import (
extract_and_save_memory, extract_and_save_memory,
extract_and_save_team_memory, extract_and_save_team_memory,
) )
from app.agents.new_chat.middleware.busy_mutex import (
end_turn,
get_cancel_state,
is_cancel_requested,
)
from app.agents.new_chat.middleware.kb_persistence import ( from app.agents.new_chat.middleware.kb_persistence import (
commit_staged_filesystem_state, commit_staged_filesystem_state,
) )
@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content
_background_tasks: set[asyncio.Task] = set() _background_tasks: set[asyncio.Task] = set()
_perf_log = get_perf_logger() _perf_log = get_perf_logger()
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
if attempt < 1:
attempt = 1
delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
)
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
@ -401,15 +418,35 @@ def _classify_stream_exception(
exc: Exception, exc: Exception,
*, *,
flow_label: str, flow_label: str,
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: ) -> tuple[
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
]:
raw = str(exc) raw = str(exc)
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
if busy_thread_id and is_cancel_requested(busy_thread_id):
cancel_state = get_cancel_state(busy_thread_id)
attempt = cancel_state[0] if cancel_state else 1
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
retry_after_at = int(time.time() * 1000) + retry_after_ms
return (
"thread_busy",
"TURN_CANCELLING",
"info",
True,
"A previous response is still stopping. Please try again in a moment.",
{
"retry_after_ms": retry_after_ms,
"retry_after_at": retry_after_at,
},
)
return ( return (
"thread_busy", "thread_busy",
"THREAD_BUSY", "THREAD_BUSY",
"warn", "warn",
True, True,
"Another response is still finishing for this thread. Please try again in a moment.", "Another response is still finishing for this thread. Please try again in a moment.",
None,
) )
parsed = _parse_error_payload(raw) parsed = _parse_error_payload(raw)
@ -431,6 +468,7 @@ def _classify_stream_exception(
"warn", "warn",
True, True,
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.", "This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
None,
) )
return ( return (
@ -439,6 +477,7 @@ def _classify_stream_exception(
"error", "error",
False, False,
f"Error during {flow_label}: {raw}", f"Error during {flow_label}: {raw}",
None,
) )
@ -470,7 +509,7 @@ def _emit_stream_terminal_error(
message=message, message=message,
extra=extra, extra=extra,
) )
return streaming_service.format_error(message, error_code=error_code) return streaming_service.format_error(message, error_code=error_code, extra=extra)
def _legacy_match_lc_id( def _legacy_match_lc_id(
@ -2497,6 +2536,7 @@ async def stream_new_chat(
"turn-info", "turn-info",
{"chat_turn_id": stream_result.turn_id}, {"chat_turn_id": stream_result.turn_id},
) )
yield streaming_service.format_data("turn-status", {"status": "busy"})
# Initial thinking step - analyzing the request # Initial thinking step - analyzing the request
if mentioned_surfsense_docs: if mentioned_surfsense_docs:
@ -2805,6 +2845,7 @@ async def stream_new_chat(
task.add_done_callback(_background_tasks.discard) task.add_done_callback(_background_tasks.discard)
# Finish the step and message # Finish the step and message
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step() yield streaming_service.format_finish_step()
yield streaming_service.format_finish() yield streaming_service.format_finish()
yield streaming_service.format_done() yield streaming_service.format_done()
@ -2819,11 +2860,19 @@ async def stream_new_chat(
severity, severity,
is_expected, is_expected,
user_message, user_message,
error_extra,
) = _classify_stream_exception(e, flow_label="chat") ) = _classify_stream_exception(e, flow_label="chat")
error_message = f"Error during chat: {e!s}" error_message = f"Error during chat: {e!s}"
print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] {error_message}")
print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Exception type: {type(e).__name__}")
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
if error_code == "TURN_CANCELLING":
status_payload: dict[str, Any] = {"status": "cancelling"}
if error_extra:
status_payload.update(error_extra)
yield streaming_service.format_data("turn-status", status_payload)
else:
yield streaming_service.format_data("turn-status", {"status": "busy"})
yield _emit_stream_error( yield _emit_stream_error(
message=user_message, message=user_message,
@ -2831,7 +2880,9 @@ async def stream_new_chat(
error_code=error_code, error_code=error_code,
severity=severity, severity=severity,
is_expected=is_expected, is_expected=is_expected,
extra=error_extra,
) )
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step() yield streaming_service.format_finish_step()
yield streaming_service.format_finish() yield streaming_service.format_finish()
yield streaming_service.format_done() yield streaming_service.format_done()
@ -2847,6 +2898,10 @@ async def stream_new_chat(
# (CancelledError is a BaseException), and the rest of the # (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run. # finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True): with anyio.CancelScope(shield=True):
# Authoritative fallback cleanup for lock/cancel state. Middleware
# teardown can be skipped on some client-abort paths.
end_turn(str(chat_id))
# Release premium reservation if not finalized # Release premium reservation if not finalized
if _premium_request_id and _premium_reserved > 0 and user_id: if _premium_request_id and _premium_reserved > 0 and user_id:
try: try:
@ -3206,6 +3261,7 @@ async def stream_resume_chat(
"turn-info", "turn-info",
{"chat_turn_id": stream_result.turn_id}, {"chat_turn_id": stream_result.turn_id},
) )
yield streaming_service.format_data("turn-status", {"status": "busy"})
_t_stream_start = time.perf_counter() _t_stream_start = time.perf_counter()
_first_event_logged = False _first_event_logged = False
@ -3305,6 +3361,7 @@ async def stream_resume_chat(
}, },
) )
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step() yield streaming_service.format_finish_step()
yield streaming_service.format_finish() yield streaming_service.format_finish()
yield streaming_service.format_done() yield streaming_service.format_done()
@ -3318,23 +3375,37 @@ async def stream_resume_chat(
severity, severity,
is_expected, is_expected,
user_message, user_message,
error_extra,
) = _classify_stream_exception(e, flow_label="resume") ) = _classify_stream_exception(e, flow_label="resume")
error_message = f"Error during resume: {e!s}" error_message = f"Error during resume: {e!s}"
print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] {error_message}")
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
if error_code == "TURN_CANCELLING":
status_payload: dict[str, Any] = {"status": "cancelling"}
if error_extra:
status_payload.update(error_extra)
yield streaming_service.format_data("turn-status", status_payload)
else:
yield streaming_service.format_data("turn-status", {"status": "busy"})
yield _emit_stream_error( yield _emit_stream_error(
message=user_message, message=user_message,
error_kind=error_kind, error_kind=error_kind,
error_code=error_code, error_code=error_code,
severity=severity, severity=severity,
is_expected=is_expected, is_expected=is_expected,
extra=error_extra,
) )
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step() yield streaming_service.format_finish_step()
yield streaming_service.format_finish() yield streaming_service.format_finish()
yield streaming_service.format_done() yield streaming_service.format_done()
finally: finally:
with anyio.CancelScope(shield=True): with anyio.CancelScope(shield=True):
# Authoritative fallback cleanup for lock/cancel state. Middleware
# teardown can be skipped on some client-abort paths.
end_turn(str(chat_id))
# Release premium reservation if not finalized # Release premium reservation if not finalized
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
try: try:

View file

@ -7,7 +7,9 @@ import pytest
from app.agents.new_chat.errors import BusyError from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import ( from app.agents.new_chat.middleware.busy_mutex import (
BusyMutexMiddleware, BusyMutexMiddleware,
end_turn,
get_cancel_event, get_cancel_event,
is_cancel_requested,
manager, manager,
request_cancel, request_cancel,
reset_cancel, reset_cancel,
@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
def test_reset_cancel_idempotent() -> None:
    """reset_cancel tolerates threads that never had a cancel event."""
    # Should not raise even if event was never created
    reset_cancel("never-seen")
def test_request_cancel_creates_event_for_unseen_thread() -> None:
    """request_cancel lazily creates and trips the event for unknown threads."""
    unseen_thread = "never-seen-cancel"
    reset_cancel(unseen_thread)

    cancelled = request_cancel(unseen_thread)

    assert cancelled is True
    assert get_cancel_event(unseen_thread).is_set()
    assert is_cancel_requested(unseen_thread) is True
@pytest.mark.asyncio
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
    """end_turn releases a held lock and wipes any pending cancel state."""
    tid = "forced-end-turn"
    middleware = BusyMutexMiddleware()
    await middleware.abefore_agent({}, _Runtime(tid))
    assert manager.lock_for(tid).locked()

    request_cancel(tid)
    assert is_cancel_requested(tid) is True

    end_turn(tid)

    assert not manager.lock_for(tid).locked()
    assert not get_cancel_event(tid).is_set()
    assert is_cancel_requested(tid) is False

View file

@ -8,6 +8,7 @@ import pytest
import app.tasks.chat.stream_new_chat as stream_new_chat_module import app.tasks.chat.stream_new_chat as stream_new_chat_module
from app.agents.new_chat.errors import BusyError from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import ( from app.tasks.chat.stream_new_chat import (
StreamResult, StreamResult,
_classify_stream_exception, _classify_stream_exception,
@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited():
exc = Exception( exc = Exception(
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
) )
kind, code, severity, is_expected, user_message = _classify_stream_exception( kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat" exc, flow_label="chat"
) )
assert kind == "rate_limited" assert kind == "rate_limited"
@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited():
assert severity == "warn" assert severity == "warn"
assert is_expected is True assert is_expected is True
assert "temporarily rate-limited" in user_message assert "temporarily rate-limited" in user_message
assert extra is None
def test_stream_exception_classifies_thread_busy(): def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123") exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message = _classify_stream_exception( kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat" exc, flow_label="chat"
) )
assert kind == "thread_busy" assert kind == "thread_busy"
@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy():
assert severity == "warn" assert severity == "warn"
assert is_expected is True assert is_expected is True
assert "still finishing for this thread" in user_message assert "still finishing for this thread" in user_message
assert extra is None
def test_stream_exception_classifies_thread_busy_from_message(): def test_stream_exception_classifies_thread_busy_from_message():
exc = Exception("Thread is busy with another request") exc = Exception("Thread is busy with another request")
kind, code, severity, is_expected, user_message = _classify_stream_exception( kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat" exc, flow_label="chat"
) )
assert kind == "thread_busy" assert kind == "thread_busy"
@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message():
assert severity == "warn" assert severity == "warn"
assert is_expected is True assert is_expected is True
assert "still finishing for this thread" in user_message assert "still finishing for this thread" in user_message
assert extra is None
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
    """A BusyError on a thread with a pending cancel maps to TURN_CANCELLING."""
    thread_id = "thread-cancelling-1"
    # Reset first so a leftover signal from another test cannot mask the path.
    reset_cancel(thread_id)
    request_cancel(thread_id)
    exc = BusyError(request_id=thread_id)
    kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
        exc, flow_label="chat"
    )
    assert kind == "thread_busy"
    assert code == "TURN_CANCELLING"
    assert severity == "info"
    assert is_expected is True
    assert "stopping" in user_message
    # The cancelling variant must carry retry hints for the client.
    assert isinstance(extra, dict)
    assert "retry_after_ms" in extra
def test_premium_classification_is_error_code_driven(): def test_premium_classification_is_error_code_driven():
@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
def test_network_send_failures_use_unified_retry_toast_message(): def test_network_send_failures_use_unified_retry_toast_message():
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
classifier_source = classifier_path.read_text(encoding="utf-8") classifier_source = classifier_path.read_text(encoding="utf-8")
page_path = ( request_errors_path = (
Path(__file__).resolve().parents[3] Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts"
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
) )
page_source = page_path.read_text(encoding="utf-8") request_errors_source = request_errors_path.read_text(encoding="utf-8")
assert '"send_failed_pre_accept"' in classifier_source assert '"send_failed_pre_accept"' in classifier_source
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
assert 'errorCode === "TURN_CANCELLING"' in classifier_source
assert "if (withCode.code) return withCode.code;" in classifier_source assert "if (withCode.code) return withCode.code;" in classifier_source
assert 'userMessage: "Message not sent. Please retry."' in classifier_source assert 'userMessage: "Message not sent. Please retry."' in classifier_source
assert 'userMessage: "Connection issue. Please try again."' in classifier_source assert 'userMessage: "Connection issue. Please try again."' in classifier_source
assert "tagPreAcceptSendFailure(error)" in page_source assert "const passthroughCodes = new Set([" in request_errors_source
assert "const passthroughCodes = new Set([" in page_source assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source assert '"THREAD_BUSY"' in request_errors_source
assert '"THREAD_BUSY"' in page_source assert '"TURN_CANCELLING"' in request_errors_source
assert '"AUTH_EXPIRED"' in page_source assert '"AUTH_EXPIRED"' in request_errors_source
assert '"UNAUTHORIZED"' in page_source assert '"UNAUTHORIZED"' in request_errors_source
assert '"RATE_LIMITED"' in page_source assert '"RATE_LIMITED"' in request_errors_source
assert '"NETWORK_ERROR"' in page_source assert '"NETWORK_ERROR"' in request_errors_source
assert '"STREAM_PARSE_ERROR"' in page_source assert '"STREAM_PARSE_ERROR"' in request_errors_source
assert '"TOOL_EXECUTION_ERROR"' in page_source assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
assert '"PERSIST_MESSAGE_FAILED"' in page_source assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
assert '"SERVER_ERROR"' in page_source assert '"SERVER_ERROR"' in request_errors_source
assert "passthroughCodes.has(existingCode)" in page_source assert "passthroughCodes.has(existingCode)" in request_errors_source
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
assert 'errorCode: "NETWORK_ERROR"' not in page_source assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
assert "Failed to start chat. Please try again." not in page_source assert "Failed to start chat. Please try again." not in classifier_source
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
# New flow persists only when accepted and not already persisted. # New flow persists only when accepted and not already persisted.
assert "if (newAccepted && !userPersisted) {" in source assert "if (newAccepted && !userPersisted) {" in source
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
assert "computeFallbackTurnCancellingRetryDelay" in source
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
assert "await fetchWithTurnCancellingRetry(() =>" in source
def test_cancel_active_turn_route_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
assert "response_model=CancelActiveTurnResponse" in source
assert 'status="cancelling",' in source
assert 'error_code="TURN_CANCELLING",' in source
assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
assert "retry_after_at=" in source
assert 'status="idle",' in source
assert 'error_code="NO_ACTIVE_TURN",' in source
def test_turn_status_route_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
assert "response_model=TurnStatusResponse" in source
assert "_build_turn_status_payload(thread_id)" in source
assert "Permission.CHATS_READ.value" in source
assert "_raise_if_thread_busy_for_start(" in source
def test_turn_cancelling_retry_policy_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
assert "def _compute_turn_cancelling_retry_delay(" in source
assert "retry-after-ms" in source
assert '"Retry-After"' in source
assert '"errorCode": "TURN_CANCELLING"' in source
def test_turn_status_sse_contract_exists():
stream_source = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/tasks/chat/stream_new_chat.py"
).read_text(encoding="utf-8")
state_source = (
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts"
).read_text(encoding="utf-8")
pipeline_source = (
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts"
).read_text(encoding="utf-8")
assert '"turn-status"' in stream_source
assert '"status": "busy"' in stream_source
assert '"status": "idle"' in stream_source
assert "type: \"data-turn-status\"" in state_source
assert "case \"data-turn-status\":" in pipeline_source
assert "end_turn(str(chat_id))" in stream_source

View file

@ -182,6 +182,20 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
* ``stream_new_chat.py``) keep the JSON from ballooning. * ``stream_new_chat.py``) keep the JSON from ballooning.
*/ */
const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; const TOOLS_WITH_UI_ALL: ToolUIGate = "all";
const TURN_CANCELLING_INITIAL_DELAY_MS = 200;
const TURN_CANCELLING_BACKOFF_FACTOR = 2;
const TURN_CANCELLING_MAX_DELAY_MS = 1500;
const RECENT_CANCEL_WINDOW_MS = 5_000;
function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
function computeFallbackTurnCancellingRetryDelay(attempt: number): number {
const safeAttempt = Math.max(1, attempt);
const raw = TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1);
return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS);
}
export default function NewChatPage() { export default function NewChatPage() {
const params = useParams(); const params = useParams();
@ -193,6 +207,7 @@ export default function NewChatPage() {
const [isRunning, setIsRunning] = useState(false); const [isRunning, setIsRunning] = useState(false);
const [tokenUsageStore] = useState(() => createTokenUsageStore()); const [tokenUsageStore] = useState(() => createTokenUsageStore());
const abortControllerRef = useRef<AbortController | null>(null); const abortControllerRef = useRef<AbortController | null>(null);
const recentCancelRequestedAtRef = useRef(0);
const [pendingInterrupt, setPendingInterrupt] = useState<{ const [pendingInterrupt, setPendingInterrupt] = useState<{
threadId: number; threadId: number;
assistantMsgId: string; assistantMsgId: string;
@ -598,6 +613,36 @@ export default function NewChatPage() {
[handleChatFailure] [handleChatFailure]
); );
const fetchWithTurnCancellingRetry = useCallback(
async (runFetch: () => Promise<Response>) => {
const maxAttempts = 4;
for (let attempt = 1; attempt <= maxAttempts; attempt += 1) {
const response = await runFetch();
if (response.ok) {
return response;
}
const error = await toHttpResponseError(response);
const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number };
const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING";
const isRecentThreadBusyAfterCancel =
withMeta.errorCode === "THREAD_BUSY" &&
Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS;
if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) {
const waitMs =
withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt);
await sleep(waitMs);
continue;
}
throw error;
}
throw Object.assign(new Error("Turn cancellation retry limit exceeded"), {
errorCode: "TURN_CANCELLING",
});
},
[]
);
// Initialize thread and load messages // Initialize thread and load messages
// For new chats (no urlChatId), we use lazy creation - thread is created on first message // For new chats (no urlChatId), we use lazy creation - thread is created on first message
const initializeThread = useCallback(async () => { const initializeThread = useCallback(async () => {
@ -767,12 +812,39 @@ export default function NewChatPage() {
// Cancel ongoing request // Cancel ongoing request
const cancelRun = useCallback(async () => { const cancelRun = useCallback(async () => {
if (threadId) {
const token = getBearerToken();
if (token) {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
try {
const response = await fetch(
`${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`,
{
method: "POST",
headers: {
Authorization: `Bearer ${token}`,
},
}
);
if (response.ok) {
const payload = (await response.json()) as {
error_code?: string;
};
if (payload.error_code === "TURN_CANCELLING") {
recentCancelRequestedAtRef.current = Date.now();
}
}
} catch (error) {
console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error);
}
}
}
if (abortControllerRef.current) { if (abortControllerRef.current) {
abortControllerRef.current.abort(); abortControllerRef.current.abort();
abortControllerRef.current = null; abortControllerRef.current = null;
} }
setIsRunning(false); setIsRunning(false);
}, []); }, [threadId]);
// Handle new message from user // Handle new message from user
const onNew = useCallback( const onNew = useCallback(
@ -971,29 +1043,33 @@ export default function NewChatPage() {
setMentionedDocuments([]); setMentionedDocuments([]);
} }
const response = await fetch(`${backendUrl}/api/v1/new_chat`, { const response = await fetchWithTurnCancellingRetry(() =>
method: "POST", fetch(`${backendUrl}/api/v1/new_chat`, {
headers: { method: "POST",
"Content-Type": "application/json", headers: {
Authorization: `Bearer ${token}`, "Content-Type": "application/json",
}, Authorization: `Bearer ${token}`,
body: JSON.stringify({ },
chat_id: currentThreadId, body: JSON.stringify({
user_query: userQuery.trim(), chat_id: currentThreadId,
search_space_id: searchSpaceId, user_query: userQuery.trim(),
filesystem_mode: selection.filesystem_mode, search_space_id: searchSpaceId,
client_platform: selection.client_platform, filesystem_mode: selection.filesystem_mode,
local_filesystem_mounts: selection.local_filesystem_mounts, client_platform: selection.client_platform,
messages: messageHistory, local_filesystem_mounts: selection.local_filesystem_mounts,
mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined, messages: messageHistory,
mentioned_surfsense_doc_ids: hasSurfsenseDocIds mentioned_document_ids: hasDocumentIds
? mentionedDocumentIds.surfsense_doc_ids ? mentionedDocumentIds.document_ids
: undefined, : undefined,
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, mentioned_surfsense_doc_ids: hasSurfsenseDocIds
...(userImages.length > 0 ? { user_images: userImages } : {}), ? mentionedDocumentIds.surfsense_doc_ids
}), : undefined,
signal: controller.signal, disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
}); ...(userImages.length > 0 ? { user_images: userImages } : {}),
}),
signal: controller.signal,
})
);
if (!response.ok) { if (!response.ok) {
throw await toHttpResponseError(response); throw await toHttpResponseError(response);
@ -1033,6 +1109,11 @@ export default function NewChatPage() {
tokenUsageData = data; tokenUsageData = data;
tokenUsageStore.set(assistantMsgId, data); tokenUsageStore.set(assistantMsgId, data);
}, },
onTurnStatus: (data) => {
if (data.status === "cancelling") {
recentCancelRequestedAtRef.current = Date.now();
}
},
onToolOutputAvailable: (event, sharedCtx) => { onToolOutputAvailable: (event, sharedCtx) => {
if (event.output?.status === "pending" && event.output?.podcast_id) { if (event.output?.status === "pending" && event.output?.podcast_id) {
const idx = sharedCtx.toolCallIndices.get(event.toolCallId); const idx = sharedCtx.toolCallIndices.get(event.toolCallId);
@ -1257,6 +1338,7 @@ export default function NewChatPage() {
tokenUsageStore, tokenUsageStore,
pendingUserImageUrls, pendingUserImageUrls,
setPendingUserImageUrls, setPendingUserImageUrls,
fetchWithTurnCancellingRetry,
handleStreamTerminalError, handleStreamTerminalError,
handleChatFailure, handleChatFailure,
persistAssistantTurn, persistAssistantTurn,
@ -1354,21 +1436,23 @@ export default function NewChatPage() {
try { try {
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
const selection = await getAgentFilesystemSelection(searchSpaceId); const selection = await getAgentFilesystemSelection(searchSpaceId);
const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { const response = await fetchWithTurnCancellingRetry(() =>
method: "POST", fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
headers: { method: "POST",
"Content-Type": "application/json", headers: {
Authorization: `Bearer ${token}`, "Content-Type": "application/json",
}, Authorization: `Bearer ${token}`,
body: JSON.stringify({ },
search_space_id: searchSpaceId, body: JSON.stringify({
decisions, search_space_id: searchSpaceId,
filesystem_mode: selection.filesystem_mode, decisions,
client_platform: selection.client_platform, filesystem_mode: selection.filesystem_mode,
local_filesystem_mounts: selection.local_filesystem_mounts, client_platform: selection.client_platform,
}), local_filesystem_mounts: selection.local_filesystem_mounts,
signal: controller.signal, }),
}); signal: controller.signal,
})
);
if (!response.ok) { if (!response.ok) {
throw await toHttpResponseError(response); throw await toHttpResponseError(response);
@ -1399,6 +1483,11 @@ export default function NewChatPage() {
tokenUsageData = data; tokenUsageData = data;
tokenUsageStore.set(assistantMsgId, data); tokenUsageStore.set(assistantMsgId, data);
}, },
onTurnStatus: (data) => {
if (data.status === "cancelling") {
recentCancelRequestedAtRef.current = Date.now();
}
},
}) })
) { ) {
return; return;
@ -1496,6 +1585,7 @@ export default function NewChatPage() {
searchSpaceId, searchSpaceId,
queryClient, queryClient,
tokenUsageStore, tokenUsageStore,
fetchWithTurnCancellingRetry,
handleStreamTerminalError, handleStreamTerminalError,
persistAssistantTurn, persistAssistantTurn,
] ]
@ -1700,15 +1790,17 @@ export default function NewChatPage() {
requestBody.revert_actions = true; requestBody.revert_actions = true;
} }
} }
const response = await fetch(getRegenerateUrl(threadId), { const response = await fetchWithTurnCancellingRetry(() =>
method: "POST", fetch(getRegenerateUrl(threadId), {
headers: { method: "POST",
"Content-Type": "application/json", headers: {
Authorization: `Bearer ${token}`, "Content-Type": "application/json",
}, Authorization: `Bearer ${token}`,
body: JSON.stringify(requestBody), },
signal: controller.signal, body: JSON.stringify(requestBody),
}); signal: controller.signal,
})
);
if (!response.ok) { if (!response.ok) {
throw await toHttpResponseError(response); throw await toHttpResponseError(response);
@ -1774,6 +1866,11 @@ export default function NewChatPage() {
tokenUsageData = data; tokenUsageData = data;
tokenUsageStore.set(assistantMsgId, data); tokenUsageStore.set(assistantMsgId, data);
}, },
onTurnStatus: (data) => {
if (data.status === "cancelling") {
recentCancelRequestedAtRef.current = Date.now();
}
},
onToolOutputAvailable: (event, sharedCtx) => { onToolOutputAvailable: (event, sharedCtx) => {
if (event.output?.status === "pending" && event.output?.podcast_id) { if (event.output?.status === "pending" && event.output?.podcast_id) {
const idx = sharedCtx.toolCallIndices.get(event.toolCallId); const idx = sharedCtx.toolCallIndices.get(event.toolCallId);
@ -1945,6 +2042,7 @@ export default function NewChatPage() {
setMessageDocumentsMap, setMessageDocumentsMap,
queryClient, queryClient,
tokenUsageStore, tokenUsageStore,
fetchWithTurnCancellingRetry,
handleStreamTerminalError, handleStreamTerminalError,
persistAssistantTurn, persistAssistantTurn,
persistUserTurn, persistUserTurn,

View file

@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
}; };
} }
if (
errorCode === "TURN_CANCELLING"
) {
return {
kind: "thread_busy",
channel: "toast",
severity: "info",
telemetryEvent: "chat_blocked",
isExpected: true,
userMessage: "A previous response is still stopping. Please try again in a moment.",
rawMessage,
errorCode: errorCode ?? "TURN_CANCELLING",
details: { flow: input.flow },
};
}
if ( if (
errorCode === "THREAD_BUSY" errorCode === "THREAD_BUSY"
) { ) {
@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
severity: "warn", severity: "warn",
telemetryEvent: "chat_blocked", telemetryEvent: "chat_blocked",
isExpected: true, isExpected: true,
userMessage: "A previous response is still stopping. Please try again in a moment.", userMessage: "Another response is still finishing for this thread. Please try again in a moment.",
rawMessage, rawMessage,
errorCode: errorCode ?? "THREAD_BUSY", errorCode: errorCode ?? "THREAD_BUSY",
details: { flow: input.flow }, details: { flow: input.flow },

View file

@ -1,6 +1,6 @@
export async function toHttpResponseError( export async function toHttpResponseError(
response: Response response: Response
): Promise<Error & { errorCode?: string }> { ): Promise<Error & { errorCode?: string; retryAfterMs?: number }> {
const statusDefaultCode = const statusDefaultCode =
response.status === 409 response.status === 409
? "THREAD_BUSY" ? "THREAD_BUSY"
@ -52,13 +52,37 @@ export async function toHttpResponseError(
: undefined; : undefined;
const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode;
const detailRetryAfterMs =
typeof detailObject?.retry_after_ms === "number"
? detailObject.retry_after_ms
: typeof detailObject?.retryAfterMs === "number"
? detailObject.retryAfterMs
: undefined;
const topRetryAfterMs =
typeof parsedBody?.retry_after_ms === "number"
? parsedBody.retry_after_ms
: typeof parsedBody?.retryAfterMs === "number"
? parsedBody.retryAfterMs
: undefined;
const headerRetryAfterMsRaw = response.headers.get("retry-after-ms");
const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN;
const retryAfterHeader = response.headers.get("retry-after");
const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN;
const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs)
? Math.max(0, Math.round(headerRetryAfterMs))
: Number.isFinite(retryAfterSeconds)
? Math.max(0, Math.round(retryAfterSeconds * 1000))
: undefined;
const retryAfterMs =
detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined;
const message = const message =
detailNestedMessage ?? detailNestedMessage ??
detailMessage ?? detailMessage ??
topLevelMessage ?? topLevelMessage ??
`Backend error: ${response.status}`; `Backend error: ${response.status}`;
return Object.assign(new Error(message), { errorCode }); return Object.assign(new Error(message), { errorCode, retryAfterMs });
} }
export function tagPreAcceptSendFailure(error: unknown): unknown { export function tagPreAcceptSendFailure(error: unknown): unknown {
@ -68,6 +92,7 @@ export function tagPreAcceptSendFailure(error: unknown): unknown {
const passthroughCodes = new Set([ const passthroughCodes = new Set([
"PREMIUM_QUOTA_EXHAUSTED", "PREMIUM_QUOTA_EXHAUSTED",
"THREAD_BUSY", "THREAD_BUSY",
"TURN_CANCELLING",
"AUTH_EXPIRED", "AUTH_EXPIRED",
"UNAUTHORIZED", "UNAUTHORIZED",
"RATE_LIMITED", "RATE_LIMITED",

View file

@ -21,6 +21,7 @@ export type SharedStreamEventContext = {
scheduleFlush: () => void; scheduleFlush: () => void;
forceFlush: () => void; forceFlush: () => void;
onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void; onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void;
onTurnStatus?: (data: Extract<SSEEvent, { type: "data-turn-status" }>["data"]) => void;
onToolOutputAvailable?: ( onToolOutputAvailable?: (
event: Extract<SSEEvent, { type: "tool-output-available" }>, event: Extract<SSEEvent, { type: "tool-output-available" }>,
context: { context: {
@ -173,6 +174,10 @@ export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStream
context.onTokenUsage?.(parsed.data); context.onTokenUsage?.(parsed.data);
return true; return true;
case "data-turn-status":
context.onTurnStatus?.(parsed.data);
return true;
case "error": case "error":
throw toStreamTerminalError(parsed); throw toStreamTerminalError(parsed);

View file

@ -528,6 +528,14 @@ export type SSEEvent =
}>; }>;
}; };
} }
| {
type: "data-turn-status";
data: {
status: "idle" | "busy" | "cancelling";
retry_after_ms?: number;
retry_after_at?: number;
};
}
| { | {
type: "data-token-usage"; type: "data-token-usage";
data: { data: {