mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling
This commit is contained in:
parent
4056bd1d69
commit
af66fbf106
12 changed files with 671 additions and 81 deletions
|
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -58,6 +59,8 @@ class _ThreadLockManager:
|
|||
weakref.WeakValueDictionary()
|
||||
)
|
||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
|
|
@ -76,14 +79,45 @@ class _ThreadLockManager:
|
|||
def request_cancel(self, thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id`` and record the attempt.

    The event is created on demand, so a cancel request for a thread that
    has no event yet is still recorded instead of being dropped. Each call
    also stores the request time (epoch milliseconds) and increments the
    per-thread attempt counter that ``cancel_state`` reports.

    Args:
        thread_id: Key identifying the thread whose turn should stop.

    Returns:
        Always ``True``.
    """
    event = self._cancel_events.get(thread_id)
    if event is None:
        # First cancel for this thread: register the event lazily.
        event = asyncio.Event()
        self._cancel_events[thread_id] = event
    event.set()

    self._cancel_requested_at_ms[thread_id] = int(time.time() * 1000)
    self._cancel_attempt_count[thread_id] = (
        self._cancel_attempt_count.get(thread_id, 0) + 1
    )
    return True
|
||||
|
||||
def is_cancel_requested(self, thread_id: str) -> bool:
    """Return ``True`` only when ``thread_id`` has a cancel event that is set."""
    pending = self._cancel_events.get(thread_id)
    if not pending:
        # No event was ever registered for this thread.
        return False
    return bool(pending.is_set())
|
||||
|
||||
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
    """Describe the pending cancel for ``thread_id``, if there is one.

    Returns:
        ``(attempt_count, requested_at_ms)`` while a cancel is pending,
        otherwise ``None``. Missing bookkeeping falls back to ``1``
        attempt and a ``0`` timestamp.
    """
    if self.is_cancel_requested(thread_id):
        return (
            self._cancel_attempt_count.get(thread_id, 1),
            self._cancel_requested_at_ms.get(thread_id, 0),
        )
    return None
|
||||
|
||||
def reset(self, thread_id: str) -> None:
    """Clear any cancel signal and its bookkeeping for ``thread_id``.

    Safe to call for a thread that was never seen: a missing event or
    missing bookkeeping entries are simply ignored.
    """
    stale_event = self._cancel_events.get(thread_id)
    if stale_event is not None:
        stale_event.clear()
    for bookkeeping in (self._cancel_requested_at_ms, self._cancel_attempt_count):
        bookkeeping.pop(thread_id, None)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
    """Force terminal cleanup of lock and cancel state for one thread turn.

    Intended for outer stream ``finally`` blocks, where the regular
    middleware teardown may have been skipped (abort or disconnect
    edge-cases). Calling it repeatedly, or for a thread with no lock,
    is harmless.
    """
    held = self._locks.get(thread_id)
    still_locked = held is not None and held.locked()
    if still_locked:
        held.release()
    self.reset(thread_id)
|
||||
|
||||
|
||||
# Module-level singleton — process-local but reused across all agent
|
||||
|
|
@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
|
|||
|
||||
|
||||
def request_cancel(thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id``. Always returns True.

    Thin module-level wrapper that delegates to the shared ``manager``.
    """
    return manager.request_cancel(thread_id)
|
||||
|
||||
|
||||
def is_cancel_requested(thread_id: str) -> bool:
    """Report whether a cancel signal is currently pending for ``thread_id``.

    Delegates to the shared ``manager``.
    """
    return manager.is_cancel_requested(thread_id)
|
||||
|
||||
|
||||
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
    """Return ``(attempt_count, requested_at_ms)`` for a pending cancel.

    Delegates to ``manager.cancel_state`` and passes its result through
    unchanged.
    """
    return manager.cancel_state(thread_id)
|
||||
|
||||
|
||||
def reset_cancel(thread_id: str) -> None:
    """Clear the cancel signal for ``thread_id`` (called between turns).

    Delegates to ``manager.reset``.
    """
    manager.reset(thread_id)
|
||||
|
||||
|
||||
def end_turn(thread_id: str) -> None:
    """Force end-of-turn cleanup of the lock and cancel state for ``thread_id``.

    Delegates to ``manager.end_turn``.
    """
    manager.end_turn(thread_id)
|
||||
|
||||
|
||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Block concurrent prompts on the same thread.
|
||||
|
||||
|
|
@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"end_turn",
|
||||
"get_cancel_event",
|
||||
"get_cancel_state",
|
||||
"is_cancel_requested",
|
||||
"manager",
|
||||
"request_cancel",
|
||||
"reset_cancel",
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import json
|
|||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
|
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
|
|||
FilesystemSelection,
|
||||
LocalFilesystemMount,
|
||||
)
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
manager,
|
||||
request_cancel,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
|
|
@ -44,6 +50,7 @@ from app.db import (
|
|||
)
|
||||
from app.schemas.new_chat import (
|
||||
AgentToolInfo,
|
||||
CancelActiveTurnResponse,
|
||||
LocalFilesystemMountPayload,
|
||||
NewChatMessageRead,
|
||||
NewChatRequest,
|
||||
|
|
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
TokenUsageSummary,
|
||||
TurnStatusResponse,
|
||||
)
|
||||
from app.services.token_tracking_service import record_token_usage
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
|
|
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
|
|||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
|
|||
)
|
||||
|
||||
|
||||
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential delay (ms) for TURN_CANCELLING retry hints.

    Attempts below 1 are treated as the first attempt. The delay grows by
    ``TURN_CANCELLING_BACKOFF_FACTOR`` per attempt, starting at
    ``TURN_CANCELLING_INITIAL_DELAY_MS`` and capped at
    ``TURN_CANCELLING_MAX_DELAY_MS``.
    """
    exponent = max(attempt, 1) - 1
    uncapped = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    return min(uncapped, TURN_CANCELLING_MAX_DELAY_MS)
|
||||
|
||||
|
||||
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Summarise the turn state of ``thread_id`` as a plain dict.

    Returns:
        ``{"status": "idle"}`` when the thread lock is free,
        ``{"status": "busy"}`` when it is held without a pending cancel, or
        ``{"status": "cancelling", "retry_after_ms": ..., "retry_after_at": ...}``
        when it is held and a cancel was requested. ``retry_after_at`` is
        the current time in epoch milliseconds plus ``retry_after_ms``.
    """
    thread_key = str(thread_id)
    if not manager.lock_for(thread_key).locked():
        return {"status": "idle"}

    if not is_cancel_requested(thread_key):
        return {"status": "busy"}

    cancel_state = get_cancel_state(thread_key)
    attempt = cancel_state[0] if cancel_state else 1
    retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    return {
        "status": "cancelling",
        "retry_after_ms": retry_after_ms,
        "retry_after_at": now_ms + retry_after_ms,
    }
|
||||
|
||||
|
||||
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
    """Attach retry hints to ``response`` in milliseconds and whole seconds.

    ``retry-after-ms`` carries the exact value; ``Retry-After`` carries it
    rounded up to whole seconds and never less than 1.
    """
    # Ceiling division to seconds, floored at one second.
    retry_after_seconds = max(1, (retry_after_ms + 999) // 1000)
    headers = response.headers
    headers["retry-after-ms"] = str(retry_after_ms)
    headers["Retry-After"] = str(retry_after_seconds)
|
||||
|
||||
|
||||
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Reject the start of a new turn unless ``thread_id`` is idle.

    Returns normally when the thread is idle.

    Raises:
        HTTPException: 409 with ``errorCode`` ``TURN_CANCELLING`` (plus
            retry hints in the detail and headers when a positive delay is
            available) while a previous turn is stopping, or 409 with
            ``errorCode`` ``THREAD_BUSY`` while a turn is still running.
    """
    turn_status = _build_turn_status_payload(thread_id)
    state = turn_status["status"]
    if state == "idle":
        return

    if state != "cancelling":
        raise HTTPException(
            status_code=409,
            detail={
                "errorCode": "THREAD_BUSY",
                "message": "Another response is still finishing for this thread. Please try again in a moment.",
            },
        )

    retry_after_ms = int(turn_status.get("retry_after_ms") or 0)
    has_retry_hint = retry_after_ms > 0
    retry_headers = None
    if has_retry_hint:
        retry_headers = {
            "retry-after-ms": str(retry_after_ms),
            "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
        }
    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "TURN_CANCELLING",
            "message": "A previous response is still stopping. Please try again in a moment.",
            "retry_after_ms": retry_after_ms if has_retry_hint else None,
            "retry_after_at": turn_status.get("retry_after_at"),
        },
        headers=retry_headers,
    )
|
||||
|
||||
|
||||
def _find_pre_turn_checkpoint_id(
|
||||
checkpoint_tuples: list,
|
||||
*,
|
||||
|
|
@ -1476,6 +1553,7 @@ async def handle_new_chat(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(request.chat_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
@ -1550,6 +1628,93 @@ async def handle_new_chat(
|
|||
) from None
|
||||
|
||||
|
||||
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``."""
    # NOTE: contract tests grep this file for several literal snippets
    # (decorator, keyword arguments below); keep their spelling stable.
    lookup = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = lookup.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    # Nothing is running: report it without touching cancel state.
    if _build_turn_status_payload(thread_id)["status"] == "idle":
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )

    request_cancel(str(thread_id))
    response.status_code = 202

    # Re-read the status so the retry hint reflects the cancel just recorded.
    updated_payload = _build_turn_status_payload(thread_id)
    retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
    retry_after_at = None
    if "retry_after_at" in updated_payload:
        retry_after_at = int(updated_payload["retry_after_at"])
    if retry_after_ms > 0:
        _set_retry_after_headers(response, retry_after_ms)
    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
        retry_after_at=retry_after_at,
    )
|
||||
|
||||
|
||||
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    # Load the thread first so a missing id yields 404 before any permission check.
    lookup = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = lookup.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    # Retry hints are only present in the payload while the turn is cancelling.
    status_payload = _build_turn_status_payload(thread_id)
    current_status = status_payload["status"]
    retry_ms = status_payload.get("retry_after_ms")
    retry_at = status_payload.get("retry_after_at")
    return TurnStatusResponse(
        status=current_status,  # type: ignore[arg-type]
        active_turn_id=None,
        retry_after_ms=retry_ms,  # type: ignore[arg-type]
        retry_after_at=retry_at,  # type: ignore[arg-type]
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Regeneration Endpoint (Edit/Reload)
|
||||
# =============================================================================
|
||||
|
|
@ -1605,6 +1770,7 @@ async def regenerate_response(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(thread_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
@ -2012,6 +2178,7 @@ async def resume_chat(
|
|||
)
|
||||
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(thread_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
|
|||
|
|
@ -335,6 +335,24 @@ class ResumeRequest(BaseModel):
|
|||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||
|
||||
|
||||
class CancelActiveTurnResponse(BaseModel):
    """Response for canceling an active turn on a chat thread."""

    # "cancelling" when a cancel was signalled, "idle" when nothing was running.
    status: Literal["cancelling", "idle"]
    # Machine-readable code paired with ``status``.
    error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
    # Suggested wait before retrying, in milliseconds (omitted when unknown).
    retry_after_ms: int | None = None
    # Epoch-millisecond timestamp corresponding to ``retry_after_ms``.
    retry_after_at: int | None = None
|
||||
|
||||
|
||||
class TurnStatusResponse(BaseModel):
    """Current turn execution status for a thread."""

    # One of the three turn states reported by the turn-status route.
    status: Literal["idle", "busy", "cancelling"]
    # Identifier of the running turn, when one is reported.
    active_turn_id: str | None = None
    # Suggested wait before retrying, in milliseconds (omitted when unknown).
    retry_after_ms: int | None = None
    # Epoch-millisecond timestamp corresponding to ``retry_after_ms``.
    retry_after_at: int | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Public Chat Snapshot Schemas
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -565,7 +565,12 @@ class VercelStreamingService:
|
|||
# Error Part
|
||||
# =========================================================================
|
||||
|
||||
def format_error(self, error_text: str, error_code: str | None = None) -> str:
|
||||
def format_error(
|
||||
self,
|
||||
error_text: str,
|
||||
error_code: str | None = None,
|
||||
extra: dict[str, object] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format an error message.
|
||||
|
||||
|
|
@ -579,9 +584,11 @@ class VercelStreamingService:
|
|||
Example output:
|
||||
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
||||
"""
|
||||
payload: dict[str, str] = {"type": "error", "errorText": error_text}
|
||||
payload: dict[str, object] = {"type": "error", "errorText": error_text}
|
||||
if error_code:
|
||||
payload["errorCode"] = error_code
|
||||
if extra:
|
||||
payload.update(extra)
|
||||
return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import (
|
|||
extract_and_save_memory,
|
||||
extract_and_save_team_memory,
|
||||
)
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
end_turn,
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
)
|
||||
from app.agents.new_chat.middleware.kb_persistence import (
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
|
|
@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content
|
|||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
_perf_log = get_perf_logger()
|
||||
# Retry-hint backoff policy for TURN_CANCELLING responses (milliseconds).
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500


def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Return the bounded exponential retry delay (ms) for ``attempt``.

    Attempts below 1 count as the first attempt. The delay starts at
    ``TURN_CANCELLING_INITIAL_DELAY_MS``, is multiplied by
    ``TURN_CANCELLING_BACKOFF_FACTOR`` for each further attempt, and is
    capped at ``TURN_CANCELLING_MAX_DELAY_MS``.
    """
    exponent = max(attempt, 1) - 1
    uncapped = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    return min(uncapped, TURN_CANCELLING_MAX_DELAY_MS)
|
||||
|
||||
|
||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||
|
|
@ -401,15 +418,35 @@ def _classify_stream_exception(
|
|||
exc: Exception,
|
||||
*,
|
||||
flow_label: str,
|
||||
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]:
|
||||
) -> tuple[
|
||||
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
|
||||
]:
|
||||
raw = str(exc)
|
||||
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
|
||||
busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
|
||||
if busy_thread_id and is_cancel_requested(busy_thread_id):
|
||||
cancel_state = get_cancel_state(busy_thread_id)
|
||||
attempt = cancel_state[0] if cancel_state else 1
|
||||
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
|
||||
retry_after_at = int(time.time() * 1000) + retry_after_ms
|
||||
return (
|
||||
"thread_busy",
|
||||
"TURN_CANCELLING",
|
||||
"info",
|
||||
True,
|
||||
"A previous response is still stopping. Please try again in a moment.",
|
||||
{
|
||||
"retry_after_ms": retry_after_ms,
|
||||
"retry_after_at": retry_after_at,
|
||||
},
|
||||
)
|
||||
return (
|
||||
"thread_busy",
|
||||
"THREAD_BUSY",
|
||||
"warn",
|
||||
True,
|
||||
"Another response is still finishing for this thread. Please try again in a moment.",
|
||||
None,
|
||||
)
|
||||
|
||||
parsed = _parse_error_payload(raw)
|
||||
|
|
@ -431,6 +468,7 @@ def _classify_stream_exception(
|
|||
"warn",
|
||||
True,
|
||||
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
||||
None,
|
||||
)
|
||||
|
||||
return (
|
||||
|
|
@ -439,6 +477,7 @@ def _classify_stream_exception(
|
|||
"error",
|
||||
False,
|
||||
f"Error during {flow_label}: {raw}",
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -470,7 +509,7 @@ def _emit_stream_terminal_error(
|
|||
message=message,
|
||||
extra=extra,
|
||||
)
|
||||
return streaming_service.format_error(message, error_code=error_code)
|
||||
return streaming_service.format_error(message, error_code=error_code, extra=extra)
|
||||
|
||||
|
||||
def _legacy_match_lc_id(
|
||||
|
|
@ -2497,6 +2536,7 @@ async def stream_new_chat(
|
|||
"turn-info",
|
||||
{"chat_turn_id": stream_result.turn_id},
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
# Initial thinking step - analyzing the request
|
||||
if mentioned_surfsense_docs:
|
||||
|
|
@ -2805,6 +2845,7 @@ async def stream_new_chat(
|
|||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -2819,11 +2860,19 @@ async def stream_new_chat(
|
|||
severity,
|
||||
is_expected,
|
||||
user_message,
|
||||
error_extra,
|
||||
) = _classify_stream_exception(e, flow_label="chat")
|
||||
error_message = f"Error during chat: {e!s}"
|
||||
print(f"[stream_new_chat] {error_message}")
|
||||
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
||||
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
||||
if error_code == "TURN_CANCELLING":
|
||||
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||
if error_extra:
|
||||
status_payload.update(error_extra)
|
||||
yield streaming_service.format_data("turn-status", status_payload)
|
||||
else:
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
yield _emit_stream_error(
|
||||
message=user_message,
|
||||
|
|
@ -2831,7 +2880,9 @@ async def stream_new_chat(
|
|||
error_code=error_code,
|
||||
severity=severity,
|
||||
is_expected=is_expected,
|
||||
extra=error_extra,
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -2847,6 +2898,10 @@ async def stream_new_chat(
|
|||
# (CancelledError is a BaseException), and the rest of the
|
||||
# finally block — including session.close() — would never run.
|
||||
with anyio.CancelScope(shield=True):
|
||||
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||
# teardown can be skipped on some client-abort paths.
|
||||
end_turn(str(chat_id))
|
||||
|
||||
# Release premium reservation if not finalized
|
||||
if _premium_request_id and _premium_reserved > 0 and user_id:
|
||||
try:
|
||||
|
|
@ -3206,6 +3261,7 @@ async def stream_resume_chat(
|
|||
"turn-info",
|
||||
{"chat_turn_id": stream_result.turn_id},
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
|
|
@ -3305,6 +3361,7 @@ async def stream_resume_chat(
|
|||
},
|
||||
)
|
||||
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -3318,23 +3375,37 @@ async def stream_resume_chat(
|
|||
severity,
|
||||
is_expected,
|
||||
user_message,
|
||||
error_extra,
|
||||
) = _classify_stream_exception(e, flow_label="resume")
|
||||
error_message = f"Error during resume: {e!s}"
|
||||
print(f"[stream_resume_chat] {error_message}")
|
||||
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
||||
if error_code == "TURN_CANCELLING":
|
||||
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||
if error_extra:
|
||||
status_payload.update(error_extra)
|
||||
yield streaming_service.format_data("turn-status", status_payload)
|
||||
else:
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
yield _emit_stream_error(
|
||||
message=user_message,
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
severity=severity,
|
||||
is_expected=is_expected,
|
||||
extra=error_extra,
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||
# teardown can be skipped on some client-abort paths.
|
||||
end_turn(str(chat_id))
|
||||
|
||||
# Release premium reservation if not finalized
|
||||
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ import pytest
|
|||
from app.agents.new_chat.errors import BusyError
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
BusyMutexMiddleware,
|
||||
end_turn,
|
||||
get_cancel_event,
|
||||
is_cancel_requested,
|
||||
manager,
|
||||
request_cancel,
|
||||
reset_cancel,
|
||||
|
|
@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
|
|||
def test_reset_cancel_idempotent() -> None:
    """Resetting a thread that never had a cancel event must not raise."""
    unseen_thread_id = "never-seen"
    reset_cancel(unseen_thread_id)
|
||||
|
||||
|
||||
def test_request_cancel_creates_event_for_unseen_thread() -> None:
    """A cancel request for a thread with no event yet is still recorded."""
    thread_id = "never-seen-cancel"
    reset_cancel(thread_id)

    try:
        assert request_cancel(thread_id) is True
        assert get_cancel_event(thread_id).is_set()
        assert is_cancel_requested(thread_id) is True
    finally:
        # The manager is a process-wide singleton: do not leave a tripped
        # cancel event behind for other tests.
        reset_cancel(thread_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
    """``end_turn`` releases a held thread lock and clears a pending cancel."""
    thread_id = "forced-end-turn"
    middleware = BusyMutexMiddleware()
    runtime = _Runtime(thread_id)

    # Entering the agent takes the per-thread lock.
    await middleware.abefore_agent({}, runtime)
    assert manager.lock_for(thread_id).locked()

    request_cancel(thread_id)
    assert is_cancel_requested(thread_id) is True

    end_turn(thread_id)

    # Both the lock and the cancel signal must be gone afterwards.
    assert not manager.lock_for(thread_id).locked()
    assert not get_cancel_event(thread_id).is_set()
    assert is_cancel_requested(thread_id) is False
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
|
||||
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
_classify_stream_exception,
|
||||
|
|
@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited():
|
|||
exc = Exception(
|
||||
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
||||
)
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "rate_limited"
|
||||
|
|
@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited():
|
|||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "temporarily rate-limited" in user_message
|
||||
assert extra is None
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy():
|
||||
exc = BusyError(request_id="thread-123")
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "thread_busy"
|
||||
|
|
@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy():
|
|||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "still finishing for this thread" in user_message
|
||||
assert extra is None
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy_from_message():
|
||||
exc = Exception("Thread is busy with another request")
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "thread_busy"
|
||||
|
|
@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message():
|
|||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "still finishing for this thread" in user_message
|
||||
assert extra is None
|
||||
|
||||
|
||||
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
    """A BusyError on a thread with a pending cancel maps to TURN_CANCELLING."""
    thread_id = "thread-cancelling-1"
    reset_cancel(thread_id)
    request_cancel(thread_id)
    try:
        exc = BusyError(request_id=thread_id)
        kind, code, severity, is_expected, user_message, extra = (
            _classify_stream_exception(exc, flow_label="chat")
        )
    finally:
        # Cancel state lives in a process-wide manager; clear it so the
        # pending cancel does not leak into other tests.
        reset_cancel(thread_id)

    assert kind == "thread_busy"
    assert code == "TURN_CANCELLING"
    assert severity == "info"
    assert is_expected is True
    assert "stopping" in user_message
    assert isinstance(extra, dict)
    assert "retry_after_ms" in extra
    assert "retry_after_at" in extra
|
||||
|
||||
|
||||
def test_premium_classification_is_error_code_driven():
|
||||
|
|
@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
|
|||
def test_network_send_failures_use_unified_retry_toast_message():
    """Pin the frontend error-classification contract by grepping its sources.

    The classifier must know the pre-accept and TURN_CANCELLING codes, and
    the request-error helper must pass known server codes through while
    tagging everything else as ``SEND_FAILED_PRE_ACCEPT``.
    """
    repo_root = Path(__file__).resolve().parents[3]
    classifier_path = repo_root / "surfsense_web/lib/chat/chat-error-classifier.ts"
    classifier_source = classifier_path.read_text(encoding="utf-8")
    request_errors_path = repo_root / "surfsense_web/lib/chat/chat-request-errors.ts"
    request_errors_source = request_errors_path.read_text(encoding="utf-8")

    assert '"send_failed_pre_accept"' in classifier_source
    assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
    assert 'errorCode === "TURN_CANCELLING"' in classifier_source
    assert "if (withCode.code) return withCode.code;" in classifier_source
    assert 'userMessage: "Message not sent. Please retry."' in classifier_source
    assert 'userMessage: "Connection issue. Please try again."' in classifier_source

    assert "const passthroughCodes = new Set([" in request_errors_source
    assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
    assert '"THREAD_BUSY"' in request_errors_source
    assert '"TURN_CANCELLING"' in request_errors_source
    assert '"AUTH_EXPIRED"' in request_errors_source
    assert '"UNAUTHORIZED"' in request_errors_source
    assert '"RATE_LIMITED"' in request_errors_source
    assert '"NETWORK_ERROR"' in request_errors_source
    assert '"STREAM_PARSE_ERROR"' in request_errors_source
    assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
    assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
    assert '"SERVER_ERROR"' in request_errors_source
    assert "passthroughCodes.has(existingCode)" in request_errors_source
    assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
    assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
    assert "Failed to start chat. Please try again." not in classifier_source
|
||||
|
||||
|
||||
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
||||
|
|
@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
|
|||
|
||||
# New flow persists only when accepted and not already persisted.
|
||||
assert "if (newAccepted && !userPersisted) {" in source
|
||||
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
|
||||
assert "computeFallbackTurnCancellingRetryDelay" in source
|
||||
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
|
||||
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
|
||||
assert "await fetchWithTurnCancellingRetry(() =>" in source
|
||||
|
||||
|
||||
def test_cancel_active_turn_route_contract_exists():
    """Pin the cancel-active-turn route contract by grepping the routes file."""
    import re

    routes_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_backend/app/routes/new_chat_routes.py"
    )
    source = routes_path.read_text(encoding="utf-8")

    # Match the decorator regardless of how the path argument is indented,
    # so a pure reformat of the routes file cannot break this check.
    assert re.search(
        r'@router\.post\(\s*"/threads/\{thread_id\}/cancel-active-turn",', source
    )
    assert "response_model=CancelActiveTurnResponse" in source
    assert 'status="cancelling",' in source
    assert 'error_code="TURN_CANCELLING",' in source
    assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
    assert "retry_after_at=" in source
    assert 'status="idle",' in source
    assert 'error_code="NO_ACTIVE_TURN",' in source
|
||||
|
||||
|
||||
def test_turn_status_route_contract_exists():
    """Pin the turn-status route contract by grepping the routes file."""
    import re

    routes_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_backend/app/routes/new_chat_routes.py"
    )
    source = routes_path.read_text(encoding="utf-8")

    # Match the decorator regardless of how the path argument is indented,
    # so a pure reformat of the routes file cannot break this check.
    assert re.search(
        r'@router\.get\(\s*"/threads/\{thread_id\}/turn-status",', source
    )
    assert "response_model=TurnStatusResponse" in source
    assert "_build_turn_status_payload(thread_id)" in source
    assert "Permission.CHATS_READ.value" in source
    assert "_raise_if_thread_busy_for_start(" in source
|
||||
|
||||
|
||||
def test_turn_cancelling_retry_policy_contract_exists():
    """Pin the TURN_CANCELLING retry policy by grepping the routes file."""
    repo_root = Path(__file__).resolve().parents[3]
    source = (
        repo_root / "surfsense_backend/app/routes/new_chat_routes.py"
    ).read_text(encoding="utf-8")

    expected_snippets = (
        "TURN_CANCELLING_INITIAL_DELAY_MS = 200",
        "TURN_CANCELLING_BACKOFF_FACTOR = 2",
        "TURN_CANCELLING_MAX_DELAY_MS = 1500",
        "def _compute_turn_cancelling_retry_delay(",
        "retry-after-ms",
        '"Retry-After"',
        '"errorCode": "TURN_CANCELLING"',
    )
    for snippet in expected_snippets:
        assert snippet in source
|
||||
|
||||
|
||||
def test_turn_status_sse_contract_exists():
    """Pin the turn-status SSE contract across backend and frontend sources."""
    repo_root = Path(__file__).resolve().parents[3]
    stream_source, state_source, pipeline_source = (
        (repo_root / relative_path).read_text(encoding="utf-8")
        for relative_path in (
            "surfsense_backend/app/tasks/chat/stream_new_chat.py",
            "surfsense_web/lib/chat/streaming-state.ts",
            "surfsense_web/lib/chat/stream-pipeline.ts",
        )
    )

    # Backend emits the event with both terminal states.
    assert '"turn-status"' in stream_source
    assert '"status": "busy"' in stream_source
    assert '"status": "idle"' in stream_source
    # Frontend declares and handles the event type.
    assert 'type: "data-turn-status"' in state_source
    assert 'case "data-turn-status":' in pipeline_source
    # Stream teardown performs the fallback lock/cancel cleanup.
    assert "end_turn(str(chat_id))" in stream_source
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue