feat(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling

This commit is contained in:
Anish Sarkar 2026-05-01 01:47:52 +05:30
parent 4056bd1d69
commit af66fbf106
12 changed files with 671 additions and 81 deletions

View file

@ -33,6 +33,7 @@ from __future__ import annotations
import asyncio
import logging
import time
import weakref
from typing import Any
@ -58,6 +59,8 @@ class _ThreadLockManager:
weakref.WeakValueDictionary()
)
self._cancel_events: dict[str, asyncio.Event] = {}
self._cancel_requested_at_ms: dict[str, int] = {}
self._cancel_attempt_count: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id)
@ -76,14 +79,45 @@ class _ThreadLockManager:
def request_cancel(self, thread_id: str) -> bool:
    """Set the cancel signal for *thread_id* and record attempt metadata.

    The event is created on demand so cancellation can be requested even
    before the running turn has registered an event for this thread.
    Always returns ``True``.
    """
    event = self._cancel_events.get(thread_id)
    if event is None:
        # Lazily create the event; previously a missing event caused an
        # early `return False`, contradicting the always-True contract.
        event = asyncio.Event()
        self._cancel_events[thread_id] = event
    event.set()
    now_ms = int(time.time() * 1000)
    self._cancel_requested_at_ms[thread_id] = now_ms
    # Track repeated cancel attempts so callers can back off progressively.
    self._cancel_attempt_count[thread_id] = (
        self._cancel_attempt_count.get(thread_id, 0) + 1
    )
    return True
def is_cancel_requested(self, thread_id: str) -> bool:
event = self._cancel_events.get(thread_id)
return bool(event and event.is_set())
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
if not self.is_cancel_requested(thread_id):
return None
attempts = self._cancel_attempt_count.get(thread_id, 1)
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
return attempts, requested_at_ms
def reset(self, thread_id: str) -> None:
    """Clear the cancel event and all cancel bookkeeping for *thread_id*."""
    if (event := self._cancel_events.get(thread_id)) is not None:
        event.clear()
    # Drop timestamp/attempt metadata; missing keys are fine.
    self._cancel_requested_at_ms.pop(thread_id, None)
    self._cancel_attempt_count.pop(thread_id, None)
def end_turn(self, thread_id: str) -> None:
    """Best-effort terminal cleanup for a thread turn.

    Intentionally idempotent and safe to invoke from outer stream
    ``finally`` blocks, where middleware teardown may have been skipped by
    abort/disconnect edge-cases.
    """
    held = self._locks.get(thread_id)
    if held is not None and held.locked():
        # NOTE(review): releases the lock regardless of which task acquired
        # it — deliberate here, since this is the fallback cleanup path.
        held.release()
    self.reset(thread_id)
# Module-level singleton — process-local but reused across all agent
@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
def request_cancel(thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id``. Always returns True."""
    return manager.request_cancel(thread_id)
def is_cancel_requested(thread_id: str) -> bool:
    """Proxy to the singleton: is a cancel signal pending for ``thread_id``?"""
    return manager.is_cancel_requested(thread_id)


def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
    """Proxy to the singleton: ``(attempt_count, requested_at_ms)`` or ``None``."""
    return manager.cancel_state(thread_id)


def reset_cancel(thread_id: str) -> None:
    """Proxy to the singleton: clear cancel state between turns."""
    manager.reset(thread_id)


def end_turn(thread_id: str) -> None:
    """Proxy to the singleton: force end-of-turn cleanup of lock + cancel state."""
    manager.end_turn(thread_id)
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Block concurrent prompts on the same thread.
@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
__all__ = [
"BusyMutexMiddleware",
"end_turn",
"get_cancel_event",
"get_cancel_state",
"is_cancel_requested",
"manager",
"request_cancel",
"reset_cancel",

View file

@ -15,7 +15,7 @@ import json
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy.exc import IntegrityError, OperationalError
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
FilesystemSelection,
LocalFilesystemMount,
)
from app.agents.new_chat.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
manager,
request_cancel,
)
from app.config import config
from app.db import (
ChatComment,
@ -44,6 +50,7 @@ from app.db import (
)
from app.schemas.new_chat import (
AgentToolInfo,
CancelActiveTurnResponse,
LocalFilesystemMountPayload,
NewChatMessageRead,
NewChatRequest,
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
ThreadListItem,
ThreadListResponse,
TokenUsageSummary,
TurnStatusResponse,
)
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
_logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task] = set()
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
router = APIRouter()
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
)
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential delay for TURN_CANCELLING retry hints."""
    # Clamp attempt to >= 1 so the first hint uses the initial delay.
    exponent = max(attempt, 1) - 1
    raw_delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    return min(raw_delay, TURN_CANCELLING_MAX_DELAY_MS)
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Snapshot the turn state for *thread_id*: idle, busy, or cancelling."""
    tid = str(thread_id)
    if not manager.lock_for(tid).locked():
        return {"status": "idle"}
    if not is_cancel_requested(tid):
        return {"status": "busy"}
    # Lock held and a cancel is pending: attach client back-off hints.
    state = get_cancel_state(tid)
    attempt = state[0] if state else 1
    retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
    retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms
    return {
        "status": "cancelling",
        "retry_after_ms": retry_after_ms,
        "retry_after_at": retry_after_at,
    }
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
    """Attach millisecond and standard (seconds, rounded up) retry headers."""
    seconds = max(1, (retry_after_ms + 999) // 1000)
    response.headers["retry-after-ms"] = str(retry_after_ms)
    response.headers["Retry-After"] = str(seconds)
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Raise HTTP 409 when *thread_id* already has an active turn.

    Emits ``TURN_CANCELLING`` (with retry hints) while a cancel is in
    flight, ``THREAD_BUSY`` otherwise; returns silently when idle.
    """
    payload = _build_turn_status_payload(thread_id)
    state = payload["status"]
    if state == "idle":
        return
    if state != "cancelling":
        raise HTTPException(
            status_code=409,
            detail={
                "errorCode": "THREAD_BUSY",
                "message": "Another response is still finishing for this thread. Please try again in a moment.",
            },
        )
    retry_after_ms = int(payload.get("retry_after_ms") or 0)
    headers = None
    if retry_after_ms > 0:
        headers = {
            "retry-after-ms": str(retry_after_ms),
            "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
        }
    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "TURN_CANCELLING",
            "message": "A previous response is still stopping. Please try again in a moment.",
            "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
            "retry_after_at": payload.get("retry_after_at"),
        },
        headers=headers,
    )
def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list,
*,
@ -1476,6 +1553,7 @@ async def handle_new_chat(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(request.chat_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@ -1550,6 +1628,93 @@ async def handle_new_chat(
) from None
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``."""
    query = select(NewChatThread).filter(NewChatThread.id == thread_id)
    thread = (await session.execute(query)).scalars().first()
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    if _build_turn_status_payload(thread_id)["status"] == "idle":
        # Nothing running: report a no-op rather than an error status.
        return CancelActiveTurnResponse(status="idle", error_code="NO_ACTIVE_TURN")

    request_cancel(str(thread_id))
    # 202: cancellation is asynchronous; the turn winds down on its own.
    response.status_code = 202

    refreshed = _build_turn_status_payload(thread_id)
    retry_after_ms = int(refreshed.get("retry_after_ms") or 0)
    retry_after_at = (
        int(refreshed["retry_after_at"]) if "retry_after_at" in refreshed else None
    )
    if retry_after_ms > 0:
        _set_retry_after_headers(response, retry_after_ms)
    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
        retry_after_at=retry_after_at,
    )
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Report the current execution status of the thread's active turn."""
    query = select(NewChatThread).filter(NewChatThread.id == thread_id)
    thread = (await session.execute(query)).scalars().first()
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    payload = _build_turn_status_payload(thread_id)
    return TurnStatusResponse(
        status=payload["status"],  # type: ignore[arg-type]
        active_turn_id=None,
        retry_after_ms=payload.get("retry_after_ms"),  # type: ignore[arg-type]
        retry_after_at=payload.get("retry_after_at"),  # type: ignore[arg-type]
    )
# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================
@ -1605,6 +1770,7 @@ async def regenerate_response(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@ -2012,6 +2178,7 @@ async def resume_chat(
)
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,

View file

@ -335,6 +335,24 @@ class ResumeRequest(BaseModel):
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
class CancelActiveTurnResponse(BaseModel):
    """Response for canceling an active turn on a chat thread."""

    # "cancelling" when a running turn was signalled; "idle" when none was active.
    status: Literal["cancelling", "idle"]
    # TURN_CANCELLING pairs with status="cancelling"; NO_ACTIVE_TURN with "idle".
    error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
    # Suggested client back-off in milliseconds before retrying; None when idle.
    retry_after_ms: int | None = None
    # Absolute epoch-milliseconds timestamp after which a retry is reasonable.
    retry_after_at: int | None = None
class TurnStatusResponse(BaseModel):
    """Current turn execution status for a thread."""

    # Lock-derived snapshot: no lock -> idle; lock held -> busy; cancel pending -> cancelling.
    status: Literal["idle", "busy", "cancelling"]
    # Reserved for future use; the status endpoint currently always returns None.
    active_turn_id: str | None = None
    # Back-off hints, populated only while status == "cancelling".
    retry_after_ms: int | None = None
    retry_after_at: int | None = None
# =============================================================================
# Public Chat Snapshot Schemas
# =============================================================================

View file

@ -565,7 +565,12 @@ class VercelStreamingService:
# Error Part
# =========================================================================
def format_error(self, error_text: str, error_code: str | None = None) -> str:
def format_error(
self,
error_text: str,
error_code: str | None = None,
extra: dict[str, object] | None = None,
) -> str:
"""
Format an error message.
@ -579,9 +584,11 @@ class VercelStreamingService:
Example output:
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
"""
payload: dict[str, str] = {"type": "error", "errorText": error_text}
payload: dict[str, object] = {"type": "error", "errorText": error_text}
if error_code:
payload["errorCode"] = error_code
if extra:
payload.update(extra)
return self._format_sse(payload)
# =========================================================================

View file

@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.agents.new_chat.middleware.busy_mutex import (
end_turn,
get_cancel_state,
is_cancel_requested,
)
from app.agents.new_chat.middleware.kb_persistence import (
commit_staged_filesystem_state,
)
@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content
_background_tasks: set[asyncio.Task] = set()
_perf_log = get_perf_logger()
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential delay (ms) for TURN_CANCELLING retry hints."""
    # Clamp so attempt 0/negative behaves like the first attempt.
    if attempt < 1:
        attempt = 1
    delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
    )
    return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
@ -401,15 +418,35 @@ def _classify_stream_exception(
exc: Exception,
*,
flow_label: str,
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]:
) -> tuple[
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
]:
raw = str(exc)
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
if busy_thread_id and is_cancel_requested(busy_thread_id):
cancel_state = get_cancel_state(busy_thread_id)
attempt = cancel_state[0] if cancel_state else 1
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
retry_after_at = int(time.time() * 1000) + retry_after_ms
return (
"thread_busy",
"TURN_CANCELLING",
"info",
True,
"A previous response is still stopping. Please try again in a moment.",
{
"retry_after_ms": retry_after_ms,
"retry_after_at": retry_after_at,
},
)
return (
"thread_busy",
"THREAD_BUSY",
"warn",
True,
"Another response is still finishing for this thread. Please try again in a moment.",
None,
)
parsed = _parse_error_payload(raw)
@ -431,6 +468,7 @@ def _classify_stream_exception(
"warn",
True,
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
None,
)
return (
@ -439,6 +477,7 @@ def _classify_stream_exception(
"error",
False,
f"Error during {flow_label}: {raw}",
None,
)
@ -470,7 +509,7 @@ def _emit_stream_terminal_error(
message=message,
extra=extra,
)
return streaming_service.format_error(message, error_code=error_code)
return streaming_service.format_error(message, error_code=error_code, extra=extra)
def _legacy_match_lc_id(
@ -2497,6 +2536,7 @@ async def stream_new_chat(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
# Initial thinking step - analyzing the request
if mentioned_surfsense_docs:
@ -2805,6 +2845,7 @@ async def stream_new_chat(
task.add_done_callback(_background_tasks.discard)
# Finish the step and message
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@ -2819,11 +2860,19 @@ async def stream_new_chat(
severity,
is_expected,
user_message,
error_extra,
) = _classify_stream_exception(e, flow_label="chat")
error_message = f"Error during chat: {e!s}"
print(f"[stream_new_chat] {error_message}")
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
if error_code == "TURN_CANCELLING":
status_payload: dict[str, Any] = {"status": "cancelling"}
if error_extra:
status_payload.update(error_extra)
yield streaming_service.format_data("turn-status", status_payload)
else:
yield streaming_service.format_data("turn-status", {"status": "busy"})
yield _emit_stream_error(
message=user_message,
@ -2831,7 +2880,9 @@ async def stream_new_chat(
error_code=error_code,
severity=severity,
is_expected=is_expected,
extra=error_extra,
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@ -2847,6 +2898,10 @@ async def stream_new_chat(
# (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
# Authoritative fallback cleanup for lock/cancel state. Middleware
# teardown can be skipped on some client-abort paths.
end_turn(str(chat_id))
# Release premium reservation if not finalized
if _premium_request_id and _premium_reserved > 0 and user_id:
try:
@ -3206,6 +3261,7 @@ async def stream_resume_chat(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
yield streaming_service.format_data("turn-status", {"status": "busy"})
_t_stream_start = time.perf_counter()
_first_event_logged = False
@ -3305,6 +3361,7 @@ async def stream_resume_chat(
},
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
@ -3318,23 +3375,37 @@ async def stream_resume_chat(
severity,
is_expected,
user_message,
error_extra,
) = _classify_stream_exception(e, flow_label="resume")
error_message = f"Error during resume: {e!s}"
print(f"[stream_resume_chat] {error_message}")
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
if error_code == "TURN_CANCELLING":
status_payload: dict[str, Any] = {"status": "cancelling"}
if error_extra:
status_payload.update(error_extra)
yield streaming_service.format_data("turn-status", status_payload)
else:
yield streaming_service.format_data("turn-status", {"status": "busy"})
yield _emit_stream_error(
message=user_message,
error_kind=error_kind,
error_code=error_code,
severity=severity,
is_expected=is_expected,
extra=error_extra,
)
yield streaming_service.format_data("turn-status", {"status": "idle"})
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
finally:
with anyio.CancelScope(shield=True):
# Authoritative fallback cleanup for lock/cancel state. Middleware
# teardown can be skipped on some client-abort paths.
end_turn(str(chat_id))
# Release premium reservation if not finalized
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
try: