From af66fbf106921822a895536c358f2b1a9b93b7a8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 01:47:52 +0530 Subject: [PATCH] refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling --- .../agents/new_chat/middleware/busy_mutex.py | 56 ++++- .../app/routes/new_chat_routes.py | 169 ++++++++++++++- surfsense_backend/app/schemas/new_chat.py | 18 ++ .../app/services/new_streaming_service.py | 11 +- .../app/tasks/chat/stream_new_chat.py | 75 ++++++- .../unit/agents/new_chat/test_busy_mutex.py | 30 +++ .../unit/test_stream_new_chat_contract.py | 139 ++++++++++--- .../new-chat/[[...chat_id]]/page.tsx | 194 +++++++++++++----- .../lib/chat/chat-error-classifier.ts | 18 +- surfsense_web/lib/chat/chat-request-errors.ts | 29 ++- surfsense_web/lib/chat/stream-pipeline.ts | 5 + surfsense_web/lib/chat/streaming-state.ts | 8 + 12 files changed, 671 insertions(+), 81 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py index c57d85004..d61a56533 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -33,6 +33,7 @@ from __future__ import annotations import asyncio import logging +import time import weakref from typing import Any @@ -58,6 +59,8 @@ class _ThreadLockManager: weakref.WeakValueDictionary() ) self._cancel_events: dict[str, asyncio.Event] = {} + self._cancel_requested_at_ms: dict[str, int] = {} + self._cancel_attempt_count: dict[str, int] = {} def lock_for(self, thread_id: str) -> asyncio.Lock: lock = self._locks.get(thread_id) @@ -76,14 +79,45 @@ class _ThreadLockManager: def request_cancel(self, thread_id: str) -> bool: event = self._cancel_events.get(thread_id) if event is None: - return False + event = asyncio.Event() + self._cancel_events[thread_id] = event event.set() + now_ms = int(time.time() * 1000) + self._cancel_requested_at_ms[thread_id] = now_ms + self._cancel_attempt_count[thread_id] = ( + self._cancel_attempt_count.get(thread_id, 0) + 1 + ) return True + def is_cancel_requested(self, thread_id: str) -> bool: + event = self._cancel_events.get(thread_id) + return bool(event and event.is_set()) + + def cancel_state(self, thread_id: str) -> tuple[int, int] | None: + if not self.is_cancel_requested(thread_id): + return None + attempts = self._cancel_attempt_count.get(thread_id, 1) + requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0) + return attempts, requested_at_ms + def reset(self, thread_id: str) -> None: event = self._cancel_events.get(thread_id) if event is not None: event.clear() + self._cancel_requested_at_ms.pop(thread_id, None) + self._cancel_attempt_count.pop(thread_id, None) + + def end_turn(self, thread_id: str) -> None: + """Best-effort terminal cleanup for a thread turn. + + This is intentionally idempotent and safe to call from outer stream + finally-blocks where middleware teardown might be skipped due to abort + or disconnect edge-cases. + """ + lock = self._locks.get(thread_id) + if lock is not None and lock.locked(): + lock.release() + self.reset(thread_id) # Module-level singleton — process-local but reused across all agent @@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event: def request_cancel(thread_id: str) -> bool: - """Trip the cancel event for ``thread_id``. Returns True if found.""" + """Trip the cancel event for ``thread_id``. Always returns True.""" return manager.request_cancel(thread_id) +def is_cancel_requested(thread_id: str) -> bool: + """Return whether ``thread_id`` currently has a pending cancel signal.""" + return manager.is_cancel_requested(thread_id) + + +def get_cancel_state(thread_id: str) -> tuple[int, int] | None: + """Return ``(attempt_count, requested_at_ms)`` for pending cancel state.""" + return manager.cancel_state(thread_id) + + def reset_cancel(thread_id: str) -> None: """Reset the cancel event for ``thread_id`` (called between turns).""" manager.reset(thread_id) +def end_turn(thread_id: str) -> None: + """Force end-of-turn cleanup for lock + cancel state.""" + manager.end_turn(thread_id) + + class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Block concurrent prompts on the same thread. @@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo __all__ = [ "BusyMutexMiddleware", + "end_turn", "get_cancel_event", + "get_cancel_state", + "is_cancel_requested", "manager", "request_cancel", "reset_cancel", diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index e04cce1b5..28b197ca2 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -15,7 +15,7 @@ import json import logging from datetime import UTC, datetime -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import StreamingResponse from sqlalchemy import func, or_ from sqlalchemy.exc import IntegrityError, OperationalError @@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import ( FilesystemSelection, LocalFilesystemMount, ) +from app.agents.new_chat.middleware.busy_mutex import ( + get_cancel_state, + is_cancel_requested, + manager, + request_cancel, +) from app.config import config from app.db import ( ChatComment, @@ -44,6 +50,7 @@ from app.db import ( ) from app.schemas.new_chat import ( AgentToolInfo, + CancelActiveTurnResponse, LocalFilesystemMountPayload, NewChatMessageRead, NewChatRequest, @@ -60,6 +67,7 @@ from app.schemas.new_chat import ( ThreadListItem, ThreadListResponse, TokenUsageSummary, + TurnStatusResponse, ) from app.services.token_tracking_service import record_token_usage from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat @@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import ( _logger = logging.getLogger(__name__) _background_tasks: set[asyncio.Task] = set() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 router = APIRouter() @@ -137,6 +148,72 @@ def _resolve_filesystem_selection( ) +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + """Bounded exponential delay for TURN_CANCELLING retry hints.""" + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def _build_turn_status_payload(thread_id: int) -> dict[str, object]: + lock = manager.lock_for(str(thread_id)) + if not lock.locked(): + return {"status": "idle"} + + if is_cancel_requested(str(thread_id)): + cancel_state = get_cancel_state(str(thread_id)) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms + return { + "status": "cancelling", + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + } + + return {"status": "busy"} + + +def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None: + response.headers["retry-after-ms"] = str(retry_after_ms) + response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000)) + + +def _raise_if_thread_busy_for_start(thread_id: int) -> None: + status_payload = _build_turn_status_payload(thread_id) + status = status_payload["status"] + if status == "idle": + return + if status == "cancelling": + retry_after_ms = int(status_payload.get("retry_after_ms") or 0) + detail = { + "errorCode": "TURN_CANCELLING", + "message": "A previous response is still stopping. Please try again in a moment.", + "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None, + "retry_after_at": status_payload.get("retry_after_at"), + } + headers = ( + { + "retry-after-ms": str(retry_after_ms), + "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)), + } + if retry_after_ms > 0 + else None + ) + raise HTTPException(status_code=409, detail=detail, headers=headers) + + raise HTTPException( + status_code=409, + detail={ + "errorCode": "THREAD_BUSY", + "message": "Another response is still finishing for this thread. Please try again in a moment.", + }, + ) + + def _find_pre_turn_checkpoint_id( checkpoint_tuples: list, *, @@ -1476,6 +1553,7 @@ async def handle_new_chat( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(request.chat_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -1550,6 +1628,93 @@ async def handle_new_chat( ) from None +@router.post( + "/threads/{thread_id}/cancel-active-turn", + response_model=CancelActiveTurnResponse, +) +async def cancel_active_turn( + thread_id: int, + response: Response, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Signal cancellation for the currently running turn on ``thread_id``.""" + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_UPDATE.value, + "You don't have permission to update chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + if status_payload["status"] == "idle": + return CancelActiveTurnResponse( + status="idle", + error_code="NO_ACTIVE_TURN", + ) + + request_cancel(str(thread_id)) + response.status_code = 202 + updated_payload = _build_turn_status_payload(thread_id) + retry_after_ms = int(updated_payload.get("retry_after_ms") or 0) + retry_after_at = ( + int(updated_payload["retry_after_at"]) + if "retry_after_at" in updated_payload + else None + ) + if retry_after_ms > 0: + _set_retry_after_headers(response, retry_after_ms) + return CancelActiveTurnResponse( + status="cancelling", + error_code="TURN_CANCELLING", + retry_after_ms=retry_after_ms if retry_after_ms > 0 else None, + retry_after_at=retry_after_at, + ) + + +@router.get( + "/threads/{thread_id}/turn-status", + response_model=TurnStatusResponse, +) +async def get_turn_status( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to view chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + return TurnStatusResponse( + status=status_payload["status"], # type: ignore[arg-type] + active_turn_id=None, + retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type] + retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type] + ) + + # ============================================================================= # Chat Regeneration Endpoint (Edit/Reload) # ============================================================================= @@ -1605,6 +1770,7 @@ async def regenerate_response( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -2012,6 +2178,7 @@ async def resume_chat( ) await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index c7284e901..ec5eefc07 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -335,6 +335,24 @@ class ResumeRequest(BaseModel): local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None +class CancelActiveTurnResponse(BaseModel): + """Response for canceling an active turn on a chat thread.""" + + status: Literal["cancelling", "idle"] + error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"] + retry_after_ms: int | None = None + retry_after_at: int | None = None + + +class TurnStatusResponse(BaseModel): + """Current turn execution status for a thread.""" + + status: Literal["idle", "busy", "cancelling"] + active_turn_id: str | None = None + retry_after_ms: int | None = None + retry_after_at: int | None = None + + # ============================================================================= # Public Chat Snapshot Schemas # ============================================================================= diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 842481f1c..55129668c 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -565,7 +565,12 @@ class VercelStreamingService: # Error Part # ========================================================================= - def format_error(self, error_text: str, error_code: str | None = None) -> str: + def format_error( + self, + error_text: str, + error_code: str | None = None, + extra: dict[str, object] | None = None, + ) -> str: """ Format an error message. @@ -579,9 +584,11 @@ class VercelStreamingService: Example output: data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"} """ - payload: dict[str, str] = {"type": "error", "errorText": error_text} + payload: dict[str, object] = {"type": "error", "errorText": error_text} if error_code: payload["errorCode"] = error_code + if extra: + payload.update(extra) return self._format_sse(payload) # ========================================================================= diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 2afa851b5..63c149771 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) +from app.agents.new_chat.middleware.busy_mutex import ( + end_turn, + get_cancel_state, + is_cancel_requested, +) from app.agents.new_chat.middleware.kb_persistence import ( commit_staged_filesystem_state, ) @@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 + + +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: @@ -401,15 +418,35 @@ def _classify_stream_exception( exc: Exception, *, flow_label: str, -) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +]: raw = str(exc) if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None + if busy_thread_id and is_cancel_requested(busy_thread_id): + cancel_state = get_cancel_state(busy_thread_id) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(time.time() * 1000) + retry_after_ms + return ( + "thread_busy", + "TURN_CANCELLING", + "info", + True, + "A previous response is still stopping. Please try again in a moment.", + { + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + }, + ) return ( "thread_busy", "THREAD_BUSY", "warn", True, "Another response is still finishing for this thread. Please try again in a moment.", + None, ) parsed = _parse_error_payload(raw) @@ -431,6 +468,7 @@ def _classify_stream_exception( "warn", True, "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + None, ) return ( @@ -439,6 +477,7 @@ def _classify_stream_exception( "error", False, f"Error during {flow_label}: {raw}", + None, ) @@ -470,7 +509,7 @@ def _emit_stream_terminal_error( message=message, extra=extra, ) - return streaming_service.format_error(message, error_code=error_code) + return streaming_service.format_error(message, error_code=error_code, extra=extra) def _legacy_match_lc_id( @@ -2497,6 +2536,7 @@ async def stream_new_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) # Initial thinking step - analyzing the request if mentioned_surfsense_docs: @@ -2805,6 +2845,7 @@ async def stream_new_chat( task.add_done_callback(_background_tasks.discard) # Finish the step and message + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2819,11 +2860,19 @@ async def stream_new_chat( severity, is_expected, user_message, + error_extra, ) = _classify_stream_exception(e, flow_label="chat") error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) yield _emit_stream_error( message=user_message, @@ -2831,7 +2880,9 @@ async def stream_new_chat( error_code=error_code, severity=severity, is_expected=is_expected, + extra=error_extra, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2847,6 +2898,10 @@ async def stream_new_chat( # (CancelledError is a BaseException), and the rest of the # finally block — including session.close() — would never run. with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized if _premium_request_id and _premium_reserved > 0 and user_id: try: @@ -3206,6 +3261,7 @@ async def stream_resume_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) _t_stream_start = time.perf_counter() _first_event_logged = False @@ -3305,6 +3361,7 @@ async def stream_resume_chat( }, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -3318,23 +3375,37 @@ async def stream_resume_chat( severity, is_expected, user_message, + error_extra, ) = _classify_stream_exception(e, flow_label="resume") error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) yield _emit_stream_error( message=user_message, error_kind=error_kind, error_code=error_code, severity=severity, is_expected=is_expected, + extra=error_extra, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() finally: with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: try: diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index 0c7bf17f6..c923dc499 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -7,7 +7,9 @@ import pytest from app.agents.new_chat.errors import BusyError from app.agents.new_chat.middleware.busy_mutex import ( BusyMutexMiddleware, + end_turn, get_cancel_event, + is_cancel_requested, manager, request_cancel, reset_cancel, @@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None: def test_reset_cancel_idempotent() -> None: # Should not raise even if event was never created reset_cancel("never-seen") + + +def test_request_cancel_creates_event_for_unseen_thread() -> None: + thread_id = "never-seen-cancel" + reset_cancel(thread_id) + + assert request_cancel(thread_id) is True + assert get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is True + + +@pytest.mark.asyncio +async def test_end_turn_force_clears_lock_and_cancel_state() -> None: + thread_id = "forced-end-turn" + mw = BusyMutexMiddleware() + runtime = _Runtime(thread_id) + + await mw.abefore_agent({}, runtime) + assert manager.lock_for(thread_id).locked() + + request_cancel(thread_id) + assert is_cancel_requested(thread_id) is True + + end_turn(thread_id) + + assert not manager.lock_for(thread_id).locked() + assert not get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is False diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 86ea7edd1..a1345c15c 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -8,6 +8,7 @@ import pytest import app.tasks.chat.stream_new_chat as stream_new_chat_module from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, @@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited(): exc = Exception( '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' ) - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "rate_limited" @@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited(): assert severity == "warn" assert is_expected is True assert "temporarily rate-limited" in user_message + assert extra is None def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "thread_busy" @@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy(): assert severity == "warn" assert is_expected is True assert "still finishing for this thread" in user_message + assert extra is None def test_stream_exception_classifies_thread_busy_from_message(): exc = Exception("Thread is busy with another request") - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "thread_busy" @@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message(): assert severity == "warn" assert is_expected is True assert "still finishing for this thread" in user_message + assert extra is None + + +def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): + thread_id = "thread-cancelling-1" + reset_cancel(thread_id) + request_cancel(thread_id) + exc = BusyError(request_id=thread_id) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "TURN_CANCELLING" + assert severity == "info" + assert is_expected is True + assert "stopping" in user_message + assert isinstance(extra, dict) + assert "retry_after_ms" in extra def test_premium_classification_is_error_code_driven(): @@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): def test_network_send_failures_use_unified_retry_toast_message(): classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" classifier_source = classifier_path.read_text(encoding="utf-8") - page_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + request_errors_path = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts" ) - page_source = page_path.read_text(encoding="utf-8") + request_errors_source = request_errors_path.read_text(encoding="utf-8") assert '"send_failed_pre_accept"' in classifier_source assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert 'errorCode === "TURN_CANCELLING"' in classifier_source assert "if (withCode.code) return withCode.code;" in classifier_source assert 'userMessage: "Message not sent. Please retry."' in classifier_source assert 'userMessage: "Connection issue. Please try again."' in classifier_source - assert "tagPreAcceptSendFailure(error)" in page_source - assert "const passthroughCodes = new Set([" in page_source - assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source - assert '"THREAD_BUSY"' in page_source - assert '"AUTH_EXPIRED"' in page_source - assert '"UNAUTHORIZED"' in page_source - assert '"RATE_LIMITED"' in page_source - assert '"NETWORK_ERROR"' in page_source - assert '"STREAM_PARSE_ERROR"' in page_source - assert '"TOOL_EXECUTION_ERROR"' in page_source - assert '"PERSIST_MESSAGE_FAILED"' in page_source - assert '"SERVER_ERROR"' in page_source - assert "passthroughCodes.has(existingCode)" in page_source - assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source - assert 'errorCode: "NETWORK_ERROR"' not in page_source - assert "Failed to start chat. Please try again." not in page_source + assert "const passthroughCodes = new Set([" in request_errors_source + assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source + assert '"THREAD_BUSY"' in request_errors_source + assert '"TURN_CANCELLING"' in request_errors_source + assert '"AUTH_EXPIRED"' in request_errors_source + assert '"UNAUTHORIZED"' in request_errors_source + assert '"RATE_LIMITED"' in request_errors_source + assert '"NETWORK_ERROR"' in request_errors_source + assert '"STREAM_PARSE_ERROR"' in request_errors_source + assert '"TOOL_EXECUTION_ERROR"' in request_errors_source + assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source + assert '"SERVER_ERROR"' in request_errors_source + assert "passthroughCodes.has(existingCode)" in request_errors_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source + assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source + assert "Failed to start chat. Please try again." not in classifier_source def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): @@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows() # New flow persists only when accepted and not already persisted. assert "if (newAccepted && !userPersisted) {" in source + assert "const fetchWithTurnCancellingRetry = useCallback(" in source + assert "computeFallbackTurnCancellingRetryDelay" in source + assert 'withMeta.errorCode === "TURN_CANCELLING"' in source + assert 'withMeta.errorCode === "THREAD_BUSY"' in source + assert "await fetchWithTurnCancellingRetry(() =>" in source + + +def test_cancel_active_turn_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source + assert "response_model=CancelActiveTurnResponse" in source + assert 'status="cancelling",' in source + assert 'error_code="TURN_CANCELLING",' in source + assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source + assert "retry_after_at=" in source + assert 'status="idle",' in source + assert 'error_code="NO_ACTIVE_TURN",' in source + + +def test_turn_status_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source + assert "response_model=TurnStatusResponse" in source + assert "_build_turn_status_payload(thread_id)" in source + assert "Permission.CHATS_READ.value" in source + assert "_raise_if_thread_busy_for_start(" in source + + +def test_turn_cancelling_retry_policy_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source + assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source + assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source + assert "def _compute_turn_cancelling_retry_delay(" in source + assert "retry-after-ms" in source + assert '"Retry-After"' in source + assert '"errorCode": "TURN_CANCELLING"' in source + + +def test_turn_status_sse_contract_exists(): + stream_source = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/tasks/chat/stream_new_chat.py" + ).read_text(encoding="utf-8") + state_source = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts" + ).read_text(encoding="utf-8") + pipeline_source = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts" + ).read_text(encoding="utf-8") + + assert '"turn-status"' in stream_source + assert '"status": "busy"' in stream_source + assert '"status": "idle"' in stream_source + assert "type: \"data-turn-status\"" in state_source + assert "case \"data-turn-status\":" in pipeline_source + assert "end_turn(str(chat_id))" in stream_source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 02c2914be..1b25ca431 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -182,6 +182,20 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { * ``stream_new_chat.py``) keep the JSON from ballooning. */ const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; +const TURN_CANCELLING_INITIAL_DELAY_MS = 200; +const TURN_CANCELLING_BACKOFF_FACTOR = 2; +const TURN_CANCELLING_MAX_DELAY_MS = 1500; +const RECENT_CANCEL_WINDOW_MS = 5_000; + +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +function computeFallbackTurnCancellingRetryDelay(attempt: number): number { + const safeAttempt = Math.max(1, attempt); + const raw = TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); + return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS); +} export default function NewChatPage() { const params = useParams(); @@ -193,6 +207,7 @@ export default function NewChatPage() { const [isRunning, setIsRunning] = useState(false); const [tokenUsageStore] = useState(() => createTokenUsageStore()); const abortControllerRef = useRef(null); + const recentCancelRequestedAtRef = useRef(0); const [pendingInterrupt, setPendingInterrupt] = useState<{ threadId: number; assistantMsgId: string; @@ -598,6 +613,36 @@ export default function NewChatPage() { [handleChatFailure] ); + const fetchWithTurnCancellingRetry = useCallback( + async (runFetch: () => Promise) => { + const maxAttempts = 4; + for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { + const response = await runFetch(); + if (response.ok) { + return response; + } + const error = await toHttpResponseError(response); + const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; + const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; + const isRecentThreadBusyAfterCancel = + withMeta.errorCode === "THREAD_BUSY" && + Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; + if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) { + const waitMs = + withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt); + await sleep(waitMs); + continue; + } + throw error; + } + + throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { + errorCode: "TURN_CANCELLING", + }); + }, + [] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -767,12 +812,39 @@ export default function NewChatPage() { // Cancel ongoing request const cancelRun = useCallback(async () => { + if (threadId) { + const token = getBearerToken(); + if (token) { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + try { + const response = await fetch( + `${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`, + { + method: "POST", + headers: { + Authorization: `Bearer ${token}`, + }, + } + ); + if (response.ok) { + const payload = (await response.json()) as { + error_code?: string; + }; + if (payload.error_code === "TURN_CANCELLING") { + recentCancelRequestedAtRef.current = Date.now(); + } + } + } catch (error) { + console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error); + } + } + } if (abortControllerRef.current) { abortControllerRef.current.abort(); abortControllerRef.current = null; } setIsRunning(false); - }, []); + }, [threadId]); // Handle new message from user const onNew = useCallback( @@ -971,29 +1043,33 @@ export default function NewChatPage() { setMentionedDocuments([]); } - const response = await fetch(`${backendUrl}/api/v1/new_chat`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - chat_id: currentThreadId, - user_query: userQuery.trim(), - search_space_id: searchSpaceId, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - messages: messageHistory, - mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined, - mentioned_surfsense_doc_ids: hasSurfsenseDocIds - ? mentionedDocumentIds.surfsense_doc_ids - : undefined, - disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, - ...(userImages.length > 0 ? { user_images: userImages } : {}), - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/new_chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + chat_id: currentThreadId, + user_query: userQuery.trim(), + search_space_id: searchSpaceId, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + messages: messageHistory, + mentioned_document_ids: hasDocumentIds + ? mentionedDocumentIds.document_ids + : undefined, + mentioned_surfsense_doc_ids: hasSurfsenseDocIds + ? mentionedDocumentIds.surfsense_doc_ids + : undefined, + disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, + ...(userImages.length > 0 ? { user_images: userImages } : {}), + }), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1033,6 +1109,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, onToolOutputAvailable: (event, sharedCtx) => { if (event.output?.status === "pending" && event.output?.podcast_id) { const idx = sharedCtx.toolCallIndices.get(event.toolCallId); @@ -1257,6 +1338,7 @@ export default function NewChatPage() { tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, + fetchWithTurnCancellingRetry, handleStreamTerminalError, handleChatFailure, persistAssistantTurn, @@ -1354,21 +1436,23 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; const selection = await getAgentFilesystemSelection(searchSpaceId); - const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - search_space_id: searchSpaceId, - decisions, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + search_space_id: searchSpaceId, + decisions, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + }), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1399,6 +1483,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, }) ) { return; @@ -1496,6 +1585,7 @@ export default function NewChatPage() { searchSpaceId, queryClient, tokenUsageStore, + fetchWithTurnCancellingRetry, handleStreamTerminalError, persistAssistantTurn, ] @@ -1700,15 +1790,17 @@ export default function NewChatPage() { requestBody.revert_actions = true; } } - const response = await fetch(getRegenerateUrl(threadId), { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify(requestBody), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(getRegenerateUrl(threadId), { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify(requestBody), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1774,6 +1866,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, onToolOutputAvailable: (event, sharedCtx) => { if (event.output?.status === "pending" && event.output?.podcast_id) { const idx = sharedCtx.toolCallIndices.get(event.toolCallId); @@ -1945,6 +2042,7 @@ export default function NewChatPage() { setMessageDocumentsMap, queryClient, tokenUsageStore, + fetchWithTurnCancellingRetry, handleStreamTerminalError, persistAssistantTurn, persistUserTurn, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 57341a4c3..7dfbfc1a1 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if ( + errorCode === "TURN_CANCELLING" + ) { + return { + kind: "thread_busy", + channel: "toast", + severity: "info", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "A previous response is still stopping. Please try again in a moment.", + rawMessage, + errorCode: errorCode ?? "TURN_CANCELLING", + details: { flow: input.flow }, + }; + } + if ( errorCode === "THREAD_BUSY" ) { @@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError severity: "warn", telemetryEvent: "chat_blocked", isExpected: true, - userMessage: "A previous response is still stopping. Please try again in a moment.", + userMessage: "Another response is still finishing for this thread. Please try again in a moment.", rawMessage, errorCode: errorCode ?? "THREAD_BUSY", details: { flow: input.flow }, diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index 3026e8203..708831354 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -1,6 +1,6 @@ export async function toHttpResponseError( response: Response -): Promise { +): Promise { const statusDefaultCode = response.status === 409 ? "THREAD_BUSY" @@ -52,13 +52,37 @@ export async function toHttpResponseError( : undefined; const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + + const detailRetryAfterMs = + typeof detailObject?.retry_after_ms === "number" + ? detailObject.retry_after_ms + : typeof detailObject?.retryAfterMs === "number" + ? detailObject.retryAfterMs + : undefined; + const topRetryAfterMs = + typeof parsedBody?.retry_after_ms === "number" + ? parsedBody.retry_after_ms + : typeof parsedBody?.retryAfterMs === "number" + ? parsedBody.retryAfterMs + : undefined; + const headerRetryAfterMsRaw = response.headers.get("retry-after-ms"); + const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN; + const retryAfterHeader = response.headers.get("retry-after"); + const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN; + const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs) + ? Math.max(0, Math.round(headerRetryAfterMs)) + : Number.isFinite(retryAfterSeconds) + ? Math.max(0, Math.round(retryAfterSeconds * 1000)) + : undefined; + const retryAfterMs = + detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined; const message = detailNestedMessage ?? detailMessage ?? topLevelMessage ?? `Backend error: ${response.status}`; - return Object.assign(new Error(message), { errorCode }); + return Object.assign(new Error(message), { errorCode, retryAfterMs }); } export function tagPreAcceptSendFailure(error: unknown): unknown { @@ -68,6 +92,7 @@ export function tagPreAcceptSendFailure(error: unknown): unknown { const passthroughCodes = new Set([ "PREMIUM_QUOTA_EXHAUSTED", "THREAD_BUSY", + "TURN_CANCELLING", "AUTH_EXPIRED", "UNAUTHORIZED", "RATE_LIMITED", diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts index 8957bdea3..c9118f949 100644 --- a/surfsense_web/lib/chat/stream-pipeline.ts +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -21,6 +21,7 @@ export type SharedStreamEventContext = { scheduleFlush: () => void; forceFlush: () => void; onTokenUsage?: (data: Extract["data"]) => void; + onTurnStatus?: (data: Extract["data"]) => void; onToolOutputAvailable?: ( event: Extract, context: { @@ -173,6 +174,10 @@ export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStream context.onTokenUsage?.(parsed.data); return true; + case "data-turn-status": + context.onTurnStatus?.(parsed.data); + return true; + case "error": throw toStreamTerminalError(parsed); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 445bbe83d..80e7bffbe 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -528,6 +528,14 @@ export type SSEEvent = }>; }; } + | { + type: "data-turn-status"; + data: { + status: "idle" | "busy" | "cancelling"; + retry_after_ms?: number; + retry_after_at?: number; + }; + } | { type: "data-token-usage"; data: {