refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling

This commit is contained in:
Anish Sarkar 2026-05-01 01:47:52 +05:30
parent 4056bd1d69
commit af66fbf106
12 changed files with 671 additions and 81 deletions

View file

@ -15,7 +15,7 @@ import json
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy.exc import IntegrityError, OperationalError
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
FilesystemSelection,
LocalFilesystemMount,
)
from app.agents.new_chat.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
manager,
request_cancel,
)
from app.config import config
from app.db import (
ChatComment,
@ -44,6 +50,7 @@ from app.db import (
)
from app.schemas.new_chat import (
AgentToolInfo,
CancelActiveTurnResponse,
LocalFilesystemMountPayload,
NewChatMessageRead,
NewChatRequest,
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
ThreadListItem,
ThreadListResponse,
TokenUsageSummary,
TurnStatusResponse,
)
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# NOTE(review): presumably holds references to in-flight background tasks so
# they are not garbage-collected; its usage is not visible here — confirm.
_background_tasks: set[asyncio.Task] = set()
# Tuning for the retry hints returned while a turn is still cancelling
# (consumed by _compute_turn_cancelling_retry_delay).
TURN_CANCELLING_INITIAL_DELAY_MS = 200  # hinted delay for the first attempt
TURN_CANCELLING_BACKOFF_FACTOR = 2  # multiplier applied for each further attempt
TURN_CANCELLING_MAX_DELAY_MS = 1500  # upper bound on the hinted delay
# Router that the route decorators in this module register on.
router = APIRouter()
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
)
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Return the retry hint, in milliseconds, for a TURN_CANCELLING reply.

    The hint starts at ``TURN_CANCELLING_INITIAL_DELAY_MS`` for the first
    attempt, is multiplied by ``TURN_CANCELLING_BACKOFF_FACTOR`` for every
    further attempt, and never exceeds ``TURN_CANCELLING_MAX_DELAY_MS``.
    Attempts below 1 are treated as the first attempt.
    """
    exponent = max(attempt, 1) - 1
    uncapped_ms = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    if uncapped_ms > TURN_CANCELLING_MAX_DELAY_MS:
        return TURN_CANCELLING_MAX_DELAY_MS
    return uncapped_ms
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Describe the current turn state of ``thread_id`` as a plain dict.

    Returns:
        ``{"status": "idle"}`` when the lock obtained from ``manager`` for
        this thread is not locked; ``{"status": "busy"}`` when it is locked
        and no cancel has been requested; otherwise a ``"cancelling"``
        payload that also carries ``retry_after_ms`` and ``retry_after_at``
        (current UTC time in epoch milliseconds plus the retry delay).
    """
    thread_key = str(thread_id)

    if not manager.lock_for(thread_key).locked():
        return {"status": "idle"}

    if not is_cancel_requested(thread_key):
        return {"status": "busy"}

    state = get_cancel_state(thread_key)
    # Fall back to the first attempt when no cancel state is returned.
    attempt = state[0] if state else 1
    delay_ms = _compute_turn_cancelling_retry_delay(attempt)
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    return {
        "status": "cancelling",
        "retry_after_ms": delay_ms,
        "retry_after_at": now_ms + delay_ms,
    }
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
    """Attach retry hints to ``response`` in milliseconds and whole seconds.

    ``retry-after-ms`` carries the exact value. ``Retry-After`` carries the
    same delay rounded up to whole seconds, with a floor of one second.
    """
    # Round up to whole seconds: any leftover milliseconds cost one more second.
    whole_seconds, leftover_ms = divmod(retry_after_ms, 1000)
    if leftover_ms:
        whole_seconds += 1

    response.headers["retry-after-ms"] = f"{retry_after_ms}"
    response.headers["Retry-After"] = f"{max(1, whole_seconds)}"
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Refuse to start a new turn while ``thread_id`` is not idle.

    Returns normally when the thread's status is ``"idle"``.

    Raises:
        HTTPException: 409 with ``errorCode`` ``TURN_CANCELLING`` (retry
            hints in the body and, when positive, in the headers) while the
            status is ``"cancelling"``; 409 with ``errorCode``
            ``THREAD_BUSY`` for any other non-idle status.
    """
    payload = _build_turn_status_payload(thread_id)
    state = payload["status"]

    if state == "idle":
        return

    if state != "cancelling":
        raise HTTPException(
            status_code=409,
            detail={
                "errorCode": "THREAD_BUSY",
                "message": "Another response is still finishing for this thread. Please try again in a moment.",
            },
        )

    delay_ms = int(payload.get("retry_after_ms") or 0)
    has_delay = delay_ms > 0

    # Only advertise retry headers when there is a positive delay to report.
    retry_headers = None
    if has_delay:
        retry_headers = {
            "retry-after-ms": str(delay_ms),
            "Retry-After": str(max(1, (delay_ms + 999) // 1000)),
        }

    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "TURN_CANCELLING",
            "message": "A previous response is still stopping. Please try again in a moment.",
            "retry_after_ms": delay_ms if has_delay else None,
            "retry_after_at": payload.get("retry_after_at"),
        },
        headers=retry_headers,
    )
def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list,
*,
@ -1476,6 +1553,7 @@ async def handle_new_chat(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(request.chat_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@ -1550,6 +1628,93 @@ async def handle_new_chat(
) from None
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``."""
    # Resolve the thread first; an unknown id is a 404.
    thread_query = select(NewChatThread).filter(NewChatThread.id == thread_id)
    thread = (await session.execute(thread_query)).scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    # The caller needs chat-update permission in the search space and access
    # to this specific thread.
    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    # Nothing is running: report that with the default status code and do
    # not record a cancel request.
    if _build_turn_status_payload(thread_id)["status"] == "idle":
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )

    request_cancel(str(thread_id))
    response.status_code = 202

    # Re-read the status so the retry hints reflect the cancel just requested.
    refreshed = _build_turn_status_payload(thread_id)
    delay_ms = int(refreshed.get("retry_after_ms") or 0)

    deadline_ms = None
    if "retry_after_at" in refreshed:
        deadline_ms = int(refreshed["retry_after_at"])

    hinted_delay_ms = None
    if delay_ms > 0:
        _set_retry_after_headers(response, delay_ms)
        hinted_delay_ms = delay_ms

    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=hinted_delay_ms,
        retry_after_at=deadline_ms,
    )
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    # Report the turn status of ``thread_id`` along with any retry hints
    # present in the status payload. Responds 404 when the thread does not
    # exist. Documented with comments rather than a docstring so the
    # endpoint's generated API description stays as it was.
    thread_query = select(NewChatThread).filter(NewChatThread.id == thread_id)
    thread = (await session.execute(thread_query)).scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Reading the status requires chat-read permission in the search space
    # and access to this specific thread.
    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    payload = _build_turn_status_payload(thread_id)
    # The retry hints are absent from the payload unless the status is
    # "cancelling", in which case .get() yields None for them.
    return TurnStatusResponse(
        status=payload["status"],  # type: ignore[arg-type]
        active_turn_id=None,
        retry_after_ms=payload.get("retry_after_ms"),  # type: ignore[arg-type]
        retry_after_at=payload.get("retry_after_at"),  # type: ignore[arg-type]
    )
# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================
@ -1605,6 +1770,7 @@ async def regenerate_response(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@ -2012,6 +2178,7 @@ async def resume_chat(
)
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,