mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 05:42:39 +02:00
refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling
This commit is contained in:
parent
4056bd1d69
commit
af66fbf106
12 changed files with 671 additions and 81 deletions
|
|
@ -15,7 +15,7 @@ import json
|
|||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
|
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
|
|||
FilesystemSelection,
|
||||
LocalFilesystemMount,
|
||||
)
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
manager,
|
||||
request_cancel,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
|
|
@ -44,6 +50,7 @@ from app.db import (
|
|||
)
|
||||
from app.schemas.new_chat import (
|
||||
AgentToolInfo,
|
||||
CancelActiveTurnResponse,
|
||||
LocalFilesystemMountPayload,
|
||||
NewChatMessageRead,
|
||||
NewChatRequest,
|
||||
|
|
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
TokenUsageSummary,
|
||||
TurnStatusResponse,
|
||||
)
|
||||
from app.services.token_tracking_service import record_token_usage
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
|
|
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
|
|||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
|
|||
)
|
||||
|
||||
|
||||
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
||||
"""Bounded exponential delay for TURN_CANCELLING retry hints."""
|
||||
if attempt < 1:
|
||||
attempt = 1
|
||||
delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
|
||||
TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
|
||||
)
|
||||
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||
|
||||
|
||||
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
|
||||
lock = manager.lock_for(str(thread_id))
|
||||
if not lock.locked():
|
||||
return {"status": "idle"}
|
||||
|
||||
if is_cancel_requested(str(thread_id)):
|
||||
cancel_state = get_cancel_state(str(thread_id))
|
||||
attempt = cancel_state[0] if cancel_state else 1
|
||||
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
|
||||
retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms
|
||||
return {
|
||||
"status": "cancelling",
|
||||
"retry_after_ms": retry_after_ms,
|
||||
"retry_after_at": retry_after_at,
|
||||
}
|
||||
|
||||
return {"status": "busy"}
|
||||
|
||||
|
||||
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
|
||||
response.headers["retry-after-ms"] = str(retry_after_ms)
|
||||
response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
|
||||
|
||||
|
||||
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
|
||||
status_payload = _build_turn_status_payload(thread_id)
|
||||
status = status_payload["status"]
|
||||
if status == "idle":
|
||||
return
|
||||
if status == "cancelling":
|
||||
retry_after_ms = int(status_payload.get("retry_after_ms") or 0)
|
||||
detail = {
|
||||
"errorCode": "TURN_CANCELLING",
|
||||
"message": "A previous response is still stopping. Please try again in a moment.",
|
||||
"retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
|
||||
"retry_after_at": status_payload.get("retry_after_at"),
|
||||
}
|
||||
headers = (
|
||||
{
|
||||
"retry-after-ms": str(retry_after_ms),
|
||||
"Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
|
||||
}
|
||||
if retry_after_ms > 0
|
||||
else None
|
||||
)
|
||||
raise HTTPException(status_code=409, detail=detail, headers=headers)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"errorCode": "THREAD_BUSY",
|
||||
"message": "Another response is still finishing for this thread. Please try again in a moment.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _find_pre_turn_checkpoint_id(
|
||||
checkpoint_tuples: list,
|
||||
*,
|
||||
|
|
@ -1476,6 +1553,7 @@ async def handle_new_chat(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(request.chat_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
@ -1550,6 +1628,93 @@ async def handle_new_chat(
|
|||
) from None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/threads/{thread_id}/cancel-active-turn",
|
||||
response_model=CancelActiveTurnResponse,
|
||||
)
|
||||
async def cancel_active_turn(
|
||||
thread_id: int,
|
||||
response: Response,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Signal cancellation for the currently running turn on ``thread_id``."""
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
await check_thread_access(session, thread, user)
|
||||
|
||||
status_payload = _build_turn_status_payload(thread_id)
|
||||
if status_payload["status"] == "idle":
|
||||
return CancelActiveTurnResponse(
|
||||
status="idle",
|
||||
error_code="NO_ACTIVE_TURN",
|
||||
)
|
||||
|
||||
request_cancel(str(thread_id))
|
||||
response.status_code = 202
|
||||
updated_payload = _build_turn_status_payload(thread_id)
|
||||
retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
|
||||
retry_after_at = (
|
||||
int(updated_payload["retry_after_at"])
|
||||
if "retry_after_at" in updated_payload
|
||||
else None
|
||||
)
|
||||
if retry_after_ms > 0:
|
||||
_set_retry_after_headers(response, retry_after_ms)
|
||||
return CancelActiveTurnResponse(
|
||||
status="cancelling",
|
||||
error_code="TURN_CANCELLING",
|
||||
retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
|
||||
retry_after_at=retry_after_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/threads/{thread_id}/turn-status",
|
||||
response_model=TurnStatusResponse,
|
||||
)
|
||||
async def get_turn_status(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to view chats in this search space",
|
||||
)
|
||||
await check_thread_access(session, thread, user)
|
||||
|
||||
status_payload = _build_turn_status_payload(thread_id)
|
||||
return TurnStatusResponse(
|
||||
status=status_payload["status"], # type: ignore[arg-type]
|
||||
active_turn_id=None,
|
||||
retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type]
|
||||
retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Regeneration Endpoint (Edit/Reload)
|
||||
# =============================================================================
|
||||
|
|
@ -1605,6 +1770,7 @@ async def regenerate_response(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(thread_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
@ -2012,6 +2178,7 @@ async def resume_chat(
|
|||
)
|
||||
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(thread_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue