mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling
This commit is contained in:
parent
4056bd1d69
commit
af66fbf106
12 changed files with 671 additions and 81 deletions
|
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -58,6 +59,8 @@ class _ThreadLockManager:
|
||||||
weakref.WeakValueDictionary()
|
weakref.WeakValueDictionary()
|
||||||
)
|
)
|
||||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||||
|
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||||
|
self._cancel_attempt_count: dict[str, int] = {}
|
||||||
|
|
||||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||||
lock = self._locks.get(thread_id)
|
lock = self._locks.get(thread_id)
|
||||||
|
|
@ -76,14 +79,45 @@ class _ThreadLockManager:
|
||||||
def request_cancel(self, thread_id: str) -> bool:
|
def request_cancel(self, thread_id: str) -> bool:
|
||||||
event = self._cancel_events.get(thread_id)
|
event = self._cancel_events.get(thread_id)
|
||||||
if event is None:
|
if event is None:
|
||||||
return False
|
event = asyncio.Event()
|
||||||
|
self._cancel_events[thread_id] = event
|
||||||
event.set()
|
event.set()
|
||||||
|
now_ms = int(time.time() * 1000)
|
||||||
|
self._cancel_requested_at_ms[thread_id] = now_ms
|
||||||
|
self._cancel_attempt_count[thread_id] = (
|
||||||
|
self._cancel_attempt_count.get(thread_id, 0) + 1
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def is_cancel_requested(self, thread_id: str) -> bool:
|
||||||
|
event = self._cancel_events.get(thread_id)
|
||||||
|
return bool(event and event.is_set())
|
||||||
|
|
||||||
|
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
|
||||||
|
if not self.is_cancel_requested(thread_id):
|
||||||
|
return None
|
||||||
|
attempts = self._cancel_attempt_count.get(thread_id, 1)
|
||||||
|
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
|
||||||
|
return attempts, requested_at_ms
|
||||||
|
|
||||||
def reset(self, thread_id: str) -> None:
|
def reset(self, thread_id: str) -> None:
|
||||||
event = self._cancel_events.get(thread_id)
|
event = self._cancel_events.get(thread_id)
|
||||||
if event is not None:
|
if event is not None:
|
||||||
event.clear()
|
event.clear()
|
||||||
|
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||||
|
self._cancel_attempt_count.pop(thread_id, None)
|
||||||
|
|
||||||
|
def end_turn(self, thread_id: str) -> None:
|
||||||
|
"""Best-effort terminal cleanup for a thread turn.
|
||||||
|
|
||||||
|
This is intentionally idempotent and safe to call from outer stream
|
||||||
|
finally-blocks where middleware teardown might be skipped due to abort
|
||||||
|
or disconnect edge-cases.
|
||||||
|
"""
|
||||||
|
lock = self._locks.get(thread_id)
|
||||||
|
if lock is not None and lock.locked():
|
||||||
|
lock.release()
|
||||||
|
self.reset(thread_id)
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — process-local but reused across all agent
|
# Module-level singleton — process-local but reused across all agent
|
||||||
|
|
@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
|
||||||
|
|
||||||
|
|
||||||
def request_cancel(thread_id: str) -> bool:
|
def request_cancel(thread_id: str) -> bool:
|
||||||
"""Trip the cancel event for ``thread_id``. Returns True if found."""
|
"""Trip the cancel event for ``thread_id``. Always returns True."""
|
||||||
return manager.request_cancel(thread_id)
|
return manager.request_cancel(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cancel_requested(thread_id: str) -> bool:
|
||||||
|
"""Return whether ``thread_id`` currently has a pending cancel signal."""
|
||||||
|
return manager.is_cancel_requested(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
|
||||||
|
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
|
||||||
|
return manager.cancel_state(thread_id)
|
||||||
|
|
||||||
|
|
||||||
def reset_cancel(thread_id: str) -> None:
|
def reset_cancel(thread_id: str) -> None:
|
||||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||||
manager.reset(thread_id)
|
manager.reset(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def end_turn(thread_id: str) -> None:
|
||||||
|
"""Force end-of-turn cleanup for lock + cancel state."""
|
||||||
|
manager.end_turn(thread_id)
|
||||||
|
|
||||||
|
|
||||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||||
"""Block concurrent prompts on the same thread.
|
"""Block concurrent prompts on the same thread.
|
||||||
|
|
||||||
|
|
@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BusyMutexMiddleware",
|
"BusyMutexMiddleware",
|
||||||
|
"end_turn",
|
||||||
"get_cancel_event",
|
"get_cancel_event",
|
||||||
|
"get_cancel_state",
|
||||||
|
"is_cancel_requested",
|
||||||
"manager",
|
"manager",
|
||||||
"request_cancel",
|
"request_cancel",
|
||||||
"reset_cancel",
|
"reset_cancel",
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func, or_
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
|
||||||
FilesystemSelection,
|
FilesystemSelection,
|
||||||
LocalFilesystemMount,
|
LocalFilesystemMount,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
|
get_cancel_state,
|
||||||
|
is_cancel_requested,
|
||||||
|
manager,
|
||||||
|
request_cancel,
|
||||||
|
)
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatComment,
|
ChatComment,
|
||||||
|
|
@ -44,6 +50,7 @@ from app.db import (
|
||||||
)
|
)
|
||||||
from app.schemas.new_chat import (
|
from app.schemas.new_chat import (
|
||||||
AgentToolInfo,
|
AgentToolInfo,
|
||||||
|
CancelActiveTurnResponse,
|
||||||
LocalFilesystemMountPayload,
|
LocalFilesystemMountPayload,
|
||||||
NewChatMessageRead,
|
NewChatMessageRead,
|
||||||
NewChatRequest,
|
NewChatRequest,
|
||||||
|
|
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
|
||||||
ThreadListItem,
|
ThreadListItem,
|
||||||
ThreadListResponse,
|
ThreadListResponse,
|
||||||
TokenUsageSummary,
|
TokenUsageSummary,
|
||||||
|
TurnStatusResponse,
|
||||||
)
|
)
|
||||||
from app.services.token_tracking_service import record_token_usage
|
from app.services.token_tracking_service import record_token_usage
|
||||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||||
|
|
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
_background_tasks: set[asyncio.Task] = set()
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
|
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||||
|
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||||
|
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
||||||
|
"""Bounded exponential delay for TURN_CANCELLING retry hints."""
|
||||||
|
if attempt < 1:
|
||||||
|
attempt = 1
|
||||||
|
delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
|
||||||
|
TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
|
||||||
|
)
|
||||||
|
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
|
||||||
|
lock = manager.lock_for(str(thread_id))
|
||||||
|
if not lock.locked():
|
||||||
|
return {"status": "idle"}
|
||||||
|
|
||||||
|
if is_cancel_requested(str(thread_id)):
|
||||||
|
cancel_state = get_cancel_state(str(thread_id))
|
||||||
|
attempt = cancel_state[0] if cancel_state else 1
|
||||||
|
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
|
||||||
|
retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms
|
||||||
|
return {
|
||||||
|
"status": "cancelling",
|
||||||
|
"retry_after_ms": retry_after_ms,
|
||||||
|
"retry_after_at": retry_after_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"status": "busy"}
|
||||||
|
|
||||||
|
|
||||||
|
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
|
||||||
|
response.headers["retry-after-ms"] = str(retry_after_ms)
|
||||||
|
response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
|
||||||
|
status_payload = _build_turn_status_payload(thread_id)
|
||||||
|
status = status_payload["status"]
|
||||||
|
if status == "idle":
|
||||||
|
return
|
||||||
|
if status == "cancelling":
|
||||||
|
retry_after_ms = int(status_payload.get("retry_after_ms") or 0)
|
||||||
|
detail = {
|
||||||
|
"errorCode": "TURN_CANCELLING",
|
||||||
|
"message": "A previous response is still stopping. Please try again in a moment.",
|
||||||
|
"retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
|
||||||
|
"retry_after_at": status_payload.get("retry_after_at"),
|
||||||
|
}
|
||||||
|
headers = (
|
||||||
|
{
|
||||||
|
"retry-after-ms": str(retry_after_ms),
|
||||||
|
"Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
|
||||||
|
}
|
||||||
|
if retry_after_ms > 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=409, detail=detail, headers=headers)
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail={
|
||||||
|
"errorCode": "THREAD_BUSY",
|
||||||
|
"message": "Another response is still finishing for this thread. Please try again in a moment.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _find_pre_turn_checkpoint_id(
|
def _find_pre_turn_checkpoint_id(
|
||||||
checkpoint_tuples: list,
|
checkpoint_tuples: list,
|
||||||
*,
|
*,
|
||||||
|
|
@ -1476,6 +1553,7 @@ async def handle_new_chat(
|
||||||
|
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(request.chat_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
@ -1550,6 +1628,93 @@ async def handle_new_chat(
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/threads/{thread_id}/cancel-active-turn",
|
||||||
|
response_model=CancelActiveTurnResponse,
|
||||||
|
)
|
||||||
|
async def cancel_active_turn(
|
||||||
|
thread_id: int,
|
||||||
|
response: Response,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""Signal cancellation for the currently running turn on ``thread_id``."""
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_UPDATE.value,
|
||||||
|
"You don't have permission to update chats in this search space",
|
||||||
|
)
|
||||||
|
await check_thread_access(session, thread, user)
|
||||||
|
|
||||||
|
status_payload = _build_turn_status_payload(thread_id)
|
||||||
|
if status_payload["status"] == "idle":
|
||||||
|
return CancelActiveTurnResponse(
|
||||||
|
status="idle",
|
||||||
|
error_code="NO_ACTIVE_TURN",
|
||||||
|
)
|
||||||
|
|
||||||
|
request_cancel(str(thread_id))
|
||||||
|
response.status_code = 202
|
||||||
|
updated_payload = _build_turn_status_payload(thread_id)
|
||||||
|
retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
|
||||||
|
retry_after_at = (
|
||||||
|
int(updated_payload["retry_after_at"])
|
||||||
|
if "retry_after_at" in updated_payload
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if retry_after_ms > 0:
|
||||||
|
_set_retry_after_headers(response, retry_after_ms)
|
||||||
|
return CancelActiveTurnResponse(
|
||||||
|
status="cancelling",
|
||||||
|
error_code="TURN_CANCELLING",
|
||||||
|
retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
|
||||||
|
retry_after_at=retry_after_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/threads/{thread_id}/turn-status",
|
||||||
|
response_model=TurnStatusResponse,
|
||||||
|
)
|
||||||
|
async def get_turn_status(
|
||||||
|
thread_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_READ.value,
|
||||||
|
"You don't have permission to view chats in this search space",
|
||||||
|
)
|
||||||
|
await check_thread_access(session, thread, user)
|
||||||
|
|
||||||
|
status_payload = _build_turn_status_payload(thread_id)
|
||||||
|
return TurnStatusResponse(
|
||||||
|
status=status_payload["status"], # type: ignore[arg-type]
|
||||||
|
active_turn_id=None,
|
||||||
|
retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type]
|
||||||
|
retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Chat Regeneration Endpoint (Edit/Reload)
|
# Chat Regeneration Endpoint (Edit/Reload)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -1605,6 +1770,7 @@ async def regenerate_response(
|
||||||
|
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(thread_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
@ -2012,6 +2178,7 @@ async def resume_chat(
|
||||||
)
|
)
|
||||||
|
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(thread_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
|
||||||
|
|
@ -335,6 +335,24 @@ class ResumeRequest(BaseModel):
|
||||||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelActiveTurnResponse(BaseModel):
|
||||||
|
"""Response for canceling an active turn on a chat thread."""
|
||||||
|
|
||||||
|
status: Literal["cancelling", "idle"]
|
||||||
|
error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
|
||||||
|
retry_after_ms: int | None = None
|
||||||
|
retry_after_at: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TurnStatusResponse(BaseModel):
|
||||||
|
"""Current turn execution status for a thread."""
|
||||||
|
|
||||||
|
status: Literal["idle", "busy", "cancelling"]
|
||||||
|
active_turn_id: str | None = None
|
||||||
|
retry_after_ms: int | None = None
|
||||||
|
retry_after_at: int | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Public Chat Snapshot Schemas
|
# Public Chat Snapshot Schemas
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -565,7 +565,12 @@ class VercelStreamingService:
|
||||||
# Error Part
|
# Error Part
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
||||||
def format_error(self, error_text: str, error_code: str | None = None) -> str:
|
def format_error(
|
||||||
|
self,
|
||||||
|
error_text: str,
|
||||||
|
error_code: str | None = None,
|
||||||
|
extra: dict[str, object] | None = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format an error message.
|
Format an error message.
|
||||||
|
|
||||||
|
|
@ -579,9 +584,11 @@ class VercelStreamingService:
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
||||||
"""
|
"""
|
||||||
payload: dict[str, str] = {"type": "error", "errorText": error_text}
|
payload: dict[str, object] = {"type": "error", "errorText": error_text}
|
||||||
if error_code:
|
if error_code:
|
||||||
payload["errorCode"] = error_code
|
payload["errorCode"] = error_code
|
||||||
|
if extra:
|
||||||
|
payload.update(extra)
|
||||||
return self._format_sse(payload)
|
return self._format_sse(payload)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import (
|
||||||
extract_and_save_memory,
|
extract_and_save_memory,
|
||||||
extract_and_save_team_memory,
|
extract_and_save_team_memory,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
|
end_turn,
|
||||||
|
get_cancel_state,
|
||||||
|
is_cancel_requested,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.kb_persistence import (
|
from app.agents.new_chat.middleware.kb_persistence import (
|
||||||
commit_staged_filesystem_state,
|
commit_staged_filesystem_state,
|
||||||
)
|
)
|
||||||
|
|
@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content
|
||||||
|
|
||||||
_background_tasks: set[asyncio.Task] = set()
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||||
|
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||||
|
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
||||||
|
if attempt < 1:
|
||||||
|
attempt = 1
|
||||||
|
delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
|
||||||
|
TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
|
||||||
|
)
|
||||||
|
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||||
|
|
||||||
|
|
||||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||||
|
|
@ -401,15 +418,35 @@ def _classify_stream_exception(
|
||||||
exc: Exception,
|
exc: Exception,
|
||||||
*,
|
*,
|
||||||
flow_label: str,
|
flow_label: str,
|
||||||
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]:
|
) -> tuple[
|
||||||
|
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
|
||||||
|
]:
|
||||||
raw = str(exc)
|
raw = str(exc)
|
||||||
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
|
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
|
||||||
|
busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
|
||||||
|
if busy_thread_id and is_cancel_requested(busy_thread_id):
|
||||||
|
cancel_state = get_cancel_state(busy_thread_id)
|
||||||
|
attempt = cancel_state[0] if cancel_state else 1
|
||||||
|
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
|
||||||
|
retry_after_at = int(time.time() * 1000) + retry_after_ms
|
||||||
|
return (
|
||||||
|
"thread_busy",
|
||||||
|
"TURN_CANCELLING",
|
||||||
|
"info",
|
||||||
|
True,
|
||||||
|
"A previous response is still stopping. Please try again in a moment.",
|
||||||
|
{
|
||||||
|
"retry_after_ms": retry_after_ms,
|
||||||
|
"retry_after_at": retry_after_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
"thread_busy",
|
"thread_busy",
|
||||||
"THREAD_BUSY",
|
"THREAD_BUSY",
|
||||||
"warn",
|
"warn",
|
||||||
True,
|
True,
|
||||||
"Another response is still finishing for this thread. Please try again in a moment.",
|
"Another response is still finishing for this thread. Please try again in a moment.",
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
parsed = _parse_error_payload(raw)
|
parsed = _parse_error_payload(raw)
|
||||||
|
|
@ -431,6 +468,7 @@ def _classify_stream_exception(
|
||||||
"warn",
|
"warn",
|
||||||
True,
|
True,
|
||||||
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
@ -439,6 +477,7 @@ def _classify_stream_exception(
|
||||||
"error",
|
"error",
|
||||||
False,
|
False,
|
||||||
f"Error during {flow_label}: {raw}",
|
f"Error during {flow_label}: {raw}",
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -470,7 +509,7 @@ def _emit_stream_terminal_error(
|
||||||
message=message,
|
message=message,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
)
|
)
|
||||||
return streaming_service.format_error(message, error_code=error_code)
|
return streaming_service.format_error(message, error_code=error_code, extra=extra)
|
||||||
|
|
||||||
|
|
||||||
def _legacy_match_lc_id(
|
def _legacy_match_lc_id(
|
||||||
|
|
@ -2497,6 +2536,7 @@ async def stream_new_chat(
|
||||||
"turn-info",
|
"turn-info",
|
||||||
{"chat_turn_id": stream_result.turn_id},
|
{"chat_turn_id": stream_result.turn_id},
|
||||||
)
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
# Initial thinking step - analyzing the request
|
# Initial thinking step - analyzing the request
|
||||||
if mentioned_surfsense_docs:
|
if mentioned_surfsense_docs:
|
||||||
|
|
@ -2805,6 +2845,7 @@ async def stream_new_chat(
|
||||||
task.add_done_callback(_background_tasks.discard)
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
# Finish the step and message
|
# Finish the step and message
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
@ -2819,11 +2860,19 @@ async def stream_new_chat(
|
||||||
severity,
|
severity,
|
||||||
is_expected,
|
is_expected,
|
||||||
user_message,
|
user_message,
|
||||||
|
error_extra,
|
||||||
) = _classify_stream_exception(e, flow_label="chat")
|
) = _classify_stream_exception(e, flow_label="chat")
|
||||||
error_message = f"Error during chat: {e!s}"
|
error_message = f"Error during chat: {e!s}"
|
||||||
print(f"[stream_new_chat] {error_message}")
|
print(f"[stream_new_chat] {error_message}")
|
||||||
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
||||||
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
||||||
|
if error_code == "TURN_CANCELLING":
|
||||||
|
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||||
|
if error_extra:
|
||||||
|
status_payload.update(error_extra)
|
||||||
|
yield streaming_service.format_data("turn-status", status_payload)
|
||||||
|
else:
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=user_message,
|
message=user_message,
|
||||||
|
|
@ -2831,7 +2880,9 @@ async def stream_new_chat(
|
||||||
error_code=error_code,
|
error_code=error_code,
|
||||||
severity=severity,
|
severity=severity,
|
||||||
is_expected=is_expected,
|
is_expected=is_expected,
|
||||||
|
extra=error_extra,
|
||||||
)
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
@ -2847,6 +2898,10 @@ async def stream_new_chat(
|
||||||
# (CancelledError is a BaseException), and the rest of the
|
# (CancelledError is a BaseException), and the rest of the
|
||||||
# finally block — including session.close() — would never run.
|
# finally block — including session.close() — would never run.
|
||||||
with anyio.CancelScope(shield=True):
|
with anyio.CancelScope(shield=True):
|
||||||
|
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||||
|
# teardown can be skipped on some client-abort paths.
|
||||||
|
end_turn(str(chat_id))
|
||||||
|
|
||||||
# Release premium reservation if not finalized
|
# Release premium reservation if not finalized
|
||||||
if _premium_request_id and _premium_reserved > 0 and user_id:
|
if _premium_request_id and _premium_reserved > 0 and user_id:
|
||||||
try:
|
try:
|
||||||
|
|
@ -3206,6 +3261,7 @@ async def stream_resume_chat(
|
||||||
"turn-info",
|
"turn-info",
|
||||||
{"chat_turn_id": stream_result.turn_id},
|
{"chat_turn_id": stream_result.turn_id},
|
||||||
)
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
_t_stream_start = time.perf_counter()
|
_t_stream_start = time.perf_counter()
|
||||||
_first_event_logged = False
|
_first_event_logged = False
|
||||||
|
|
@ -3305,6 +3361,7 @@ async def stream_resume_chat(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
@ -3318,23 +3375,37 @@ async def stream_resume_chat(
|
||||||
severity,
|
severity,
|
||||||
is_expected,
|
is_expected,
|
||||||
user_message,
|
user_message,
|
||||||
|
error_extra,
|
||||||
) = _classify_stream_exception(e, flow_label="resume")
|
) = _classify_stream_exception(e, flow_label="resume")
|
||||||
error_message = f"Error during resume: {e!s}"
|
error_message = f"Error during resume: {e!s}"
|
||||||
print(f"[stream_resume_chat] {error_message}")
|
print(f"[stream_resume_chat] {error_message}")
|
||||||
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
||||||
|
if error_code == "TURN_CANCELLING":
|
||||||
|
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||||
|
if error_extra:
|
||||||
|
status_payload.update(error_extra)
|
||||||
|
yield streaming_service.format_data("turn-status", status_payload)
|
||||||
|
else:
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=user_message,
|
message=user_message,
|
||||||
error_kind=error_kind,
|
error_kind=error_kind,
|
||||||
error_code=error_code,
|
error_code=error_code,
|
||||||
severity=severity,
|
severity=severity,
|
||||||
is_expected=is_expected,
|
is_expected=is_expected,
|
||||||
|
extra=error_extra,
|
||||||
)
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
with anyio.CancelScope(shield=True):
|
with anyio.CancelScope(shield=True):
|
||||||
|
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||||
|
# teardown can be skipped on some client-abort paths.
|
||||||
|
end_turn(str(chat_id))
|
||||||
|
|
||||||
# Release premium reservation if not finalized
|
# Release premium reservation if not finalized
|
||||||
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,9 @@ import pytest
|
||||||
from app.agents.new_chat.errors import BusyError
|
from app.agents.new_chat.errors import BusyError
|
||||||
from app.agents.new_chat.middleware.busy_mutex import (
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
BusyMutexMiddleware,
|
BusyMutexMiddleware,
|
||||||
|
end_turn,
|
||||||
get_cancel_event,
|
get_cancel_event,
|
||||||
|
is_cancel_requested,
|
||||||
manager,
|
manager,
|
||||||
request_cancel,
|
request_cancel,
|
||||||
reset_cancel,
|
reset_cancel,
|
||||||
|
|
@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
|
||||||
def test_reset_cancel_idempotent() -> None:
|
def test_reset_cancel_idempotent() -> None:
|
||||||
# Should not raise even if event was never created
|
# Should not raise even if event was never created
|
||||||
reset_cancel("never-seen")
|
reset_cancel("never-seen")
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_cancel_creates_event_for_unseen_thread() -> None:
|
||||||
|
thread_id = "never-seen-cancel"
|
||||||
|
reset_cancel(thread_id)
|
||||||
|
|
||||||
|
assert request_cancel(thread_id) is True
|
||||||
|
assert get_cancel_event(thread_id).is_set()
|
||||||
|
assert is_cancel_requested(thread_id) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
|
||||||
|
thread_id = "forced-end-turn"
|
||||||
|
mw = BusyMutexMiddleware()
|
||||||
|
runtime = _Runtime(thread_id)
|
||||||
|
|
||||||
|
await mw.abefore_agent({}, runtime)
|
||||||
|
assert manager.lock_for(thread_id).locked()
|
||||||
|
|
||||||
|
request_cancel(thread_id)
|
||||||
|
assert is_cancel_requested(thread_id) is True
|
||||||
|
|
||||||
|
end_turn(thread_id)
|
||||||
|
|
||||||
|
assert not manager.lock_for(thread_id).locked()
|
||||||
|
assert not get_cancel_event(thread_id).is_set()
|
||||||
|
assert is_cancel_requested(thread_id) is False
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import pytest
|
||||||
|
|
||||||
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
||||||
from app.agents.new_chat.errors import BusyError
|
from app.agents.new_chat.errors import BusyError
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
|
||||||
from app.tasks.chat.stream_new_chat import (
|
from app.tasks.chat.stream_new_chat import (
|
||||||
StreamResult,
|
StreamResult,
|
||||||
_classify_stream_exception,
|
_classify_stream_exception,
|
||||||
|
|
@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited():
|
||||||
exc = Exception(
|
exc = Exception(
|
||||||
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
||||||
)
|
)
|
||||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
exc, flow_label="chat"
|
exc, flow_label="chat"
|
||||||
)
|
)
|
||||||
assert kind == "rate_limited"
|
assert kind == "rate_limited"
|
||||||
|
|
@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited():
|
||||||
assert severity == "warn"
|
assert severity == "warn"
|
||||||
assert is_expected is True
|
assert is_expected is True
|
||||||
assert "temporarily rate-limited" in user_message
|
assert "temporarily rate-limited" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy():
|
def test_stream_exception_classifies_thread_busy():
|
||||||
exc = BusyError(request_id="thread-123")
|
exc = BusyError(request_id="thread-123")
|
||||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
exc, flow_label="chat"
|
exc, flow_label="chat"
|
||||||
)
|
)
|
||||||
assert kind == "thread_busy"
|
assert kind == "thread_busy"
|
||||||
|
|
@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy():
|
||||||
assert severity == "warn"
|
assert severity == "warn"
|
||||||
assert is_expected is True
|
assert is_expected is True
|
||||||
assert "still finishing for this thread" in user_message
|
assert "still finishing for this thread" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy_from_message():
|
def test_stream_exception_classifies_thread_busy_from_message():
|
||||||
exc = Exception("Thread is busy with another request")
|
exc = Exception("Thread is busy with another request")
|
||||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
exc, flow_label="chat"
|
exc, flow_label="chat"
|
||||||
)
|
)
|
||||||
assert kind == "thread_busy"
|
assert kind == "thread_busy"
|
||||||
|
|
@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message():
|
||||||
assert severity == "warn"
|
assert severity == "warn"
|
||||||
assert is_expected is True
|
assert is_expected is True
|
||||||
assert "still finishing for this thread" in user_message
|
assert "still finishing for this thread" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
|
||||||
|
thread_id = "thread-cancelling-1"
|
||||||
|
reset_cancel(thread_id)
|
||||||
|
request_cancel(thread_id)
|
||||||
|
exc = BusyError(request_id=thread_id)
|
||||||
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
exc, flow_label="chat"
|
||||||
|
)
|
||||||
|
assert kind == "thread_busy"
|
||||||
|
assert code == "TURN_CANCELLING"
|
||||||
|
assert severity == "info"
|
||||||
|
assert is_expected is True
|
||||||
|
assert "stopping" in user_message
|
||||||
|
assert isinstance(extra, dict)
|
||||||
|
assert "retry_after_ms" in extra
|
||||||
|
|
||||||
|
|
||||||
def test_premium_classification_is_error_code_driven():
|
def test_premium_classification_is_error_code_driven():
|
||||||
|
|
@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
|
||||||
def test_network_send_failures_use_unified_retry_toast_message():
|
def test_network_send_failures_use_unified_retry_toast_message():
|
||||||
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
|
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
|
||||||
classifier_source = classifier_path.read_text(encoding="utf-8")
|
classifier_source = classifier_path.read_text(encoding="utf-8")
|
||||||
page_path = (
|
request_errors_path = (
|
||||||
Path(__file__).resolve().parents[3]
|
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts"
|
||||||
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
|
||||||
)
|
)
|
||||||
page_source = page_path.read_text(encoding="utf-8")
|
request_errors_source = request_errors_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
assert '"send_failed_pre_accept"' in classifier_source
|
assert '"send_failed_pre_accept"' in classifier_source
|
||||||
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
|
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
|
||||||
|
assert 'errorCode === "TURN_CANCELLING"' in classifier_source
|
||||||
assert "if (withCode.code) return withCode.code;" in classifier_source
|
assert "if (withCode.code) return withCode.code;" in classifier_source
|
||||||
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
|
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
|
||||||
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
|
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
|
||||||
assert "tagPreAcceptSendFailure(error)" in page_source
|
assert "const passthroughCodes = new Set([" in request_errors_source
|
||||||
assert "const passthroughCodes = new Set([" in page_source
|
assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
|
||||||
assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source
|
assert '"THREAD_BUSY"' in request_errors_source
|
||||||
assert '"THREAD_BUSY"' in page_source
|
assert '"TURN_CANCELLING"' in request_errors_source
|
||||||
assert '"AUTH_EXPIRED"' in page_source
|
assert '"AUTH_EXPIRED"' in request_errors_source
|
||||||
assert '"UNAUTHORIZED"' in page_source
|
assert '"UNAUTHORIZED"' in request_errors_source
|
||||||
assert '"RATE_LIMITED"' in page_source
|
assert '"RATE_LIMITED"' in request_errors_source
|
||||||
assert '"NETWORK_ERROR"' in page_source
|
assert '"NETWORK_ERROR"' in request_errors_source
|
||||||
assert '"STREAM_PARSE_ERROR"' in page_source
|
assert '"STREAM_PARSE_ERROR"' in request_errors_source
|
||||||
assert '"TOOL_EXECUTION_ERROR"' in page_source
|
assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
|
||||||
assert '"PERSIST_MESSAGE_FAILED"' in page_source
|
assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
|
||||||
assert '"SERVER_ERROR"' in page_source
|
assert '"SERVER_ERROR"' in request_errors_source
|
||||||
assert "passthroughCodes.has(existingCode)" in page_source
|
assert "passthroughCodes.has(existingCode)" in request_errors_source
|
||||||
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source
|
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
|
||||||
assert 'errorCode: "NETWORK_ERROR"' not in page_source
|
assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
|
||||||
assert "Failed to start chat. Please try again." not in page_source
|
assert "Failed to start chat. Please try again." not in classifier_source
|
||||||
|
|
||||||
|
|
||||||
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
||||||
|
|
@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
|
||||||
|
|
||||||
# New flow persists only when accepted and not already persisted.
|
# New flow persists only when accepted and not already persisted.
|
||||||
assert "if (newAccepted && !userPersisted) {" in source
|
assert "if (newAccepted && !userPersisted) {" in source
|
||||||
|
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
|
||||||
|
assert "computeFallbackTurnCancellingRetryDelay" in source
|
||||||
|
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
|
||||||
|
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
|
||||||
|
assert "await fetchWithTurnCancellingRetry(() =>" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_active_turn_route_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
|
||||||
|
assert "response_model=CancelActiveTurnResponse" in source
|
||||||
|
assert 'status="cancelling",' in source
|
||||||
|
assert 'error_code="TURN_CANCELLING",' in source
|
||||||
|
assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
|
||||||
|
assert "retry_after_at=" in source
|
||||||
|
assert 'status="idle",' in source
|
||||||
|
assert 'error_code="NO_ACTIVE_TURN",' in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_status_route_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
|
||||||
|
assert "response_model=TurnStatusResponse" in source
|
||||||
|
assert "_build_turn_status_payload(thread_id)" in source
|
||||||
|
assert "Permission.CHATS_READ.value" in source
|
||||||
|
assert "_raise_if_thread_busy_for_start(" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_cancelling_retry_policy_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
|
||||||
|
assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
|
||||||
|
assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
|
||||||
|
assert "def _compute_turn_cancelling_retry_delay(" in source
|
||||||
|
assert "retry-after-ms" in source
|
||||||
|
assert '"Retry-After"' in source
|
||||||
|
assert '"errorCode": "TURN_CANCELLING"' in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_status_sse_contract_exists():
|
||||||
|
stream_source = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/tasks/chat/stream_new_chat.py"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
state_source = (
|
||||||
|
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
pipeline_source = (
|
||||||
|
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '"turn-status"' in stream_source
|
||||||
|
assert '"status": "busy"' in stream_source
|
||||||
|
assert '"status": "idle"' in stream_source
|
||||||
|
assert "type: \"data-turn-status\"" in state_source
|
||||||
|
assert "case \"data-turn-status\":" in pipeline_source
|
||||||
|
assert "end_turn(str(chat_id))" in stream_source
|
||||||
|
|
|
||||||
|
|
@ -182,6 +182,20 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
|
||||||
* ``stream_new_chat.py``) keep the JSON from ballooning.
|
* ``stream_new_chat.py``) keep the JSON from ballooning.
|
||||||
*/
|
*/
|
||||||
const TOOLS_WITH_UI_ALL: ToolUIGate = "all";
|
const TOOLS_WITH_UI_ALL: ToolUIGate = "all";
|
||||||
|
const TURN_CANCELLING_INITIAL_DELAY_MS = 200;
|
||||||
|
const TURN_CANCELLING_BACKOFF_FACTOR = 2;
|
||||||
|
const TURN_CANCELLING_MAX_DELAY_MS = 1500;
|
||||||
|
const RECENT_CANCEL_WINDOW_MS = 5_000;
|
||||||
|
|
||||||
|
function sleep(ms: number): Promise<void> {
|
||||||
|
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||||
|
}
|
||||||
|
|
||||||
|
function computeFallbackTurnCancellingRetryDelay(attempt: number): number {
|
||||||
|
const safeAttempt = Math.max(1, attempt);
|
||||||
|
const raw = TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1);
|
||||||
|
return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS);
|
||||||
|
}
|
||||||
|
|
||||||
export default function NewChatPage() {
|
export default function NewChatPage() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
|
|
@ -193,6 +207,7 @@ export default function NewChatPage() {
|
||||||
const [isRunning, setIsRunning] = useState(false);
|
const [isRunning, setIsRunning] = useState(false);
|
||||||
const [tokenUsageStore] = useState(() => createTokenUsageStore());
|
const [tokenUsageStore] = useState(() => createTokenUsageStore());
|
||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
|
const recentCancelRequestedAtRef = useRef(0);
|
||||||
const [pendingInterrupt, setPendingInterrupt] = useState<{
|
const [pendingInterrupt, setPendingInterrupt] = useState<{
|
||||||
threadId: number;
|
threadId: number;
|
||||||
assistantMsgId: string;
|
assistantMsgId: string;
|
||||||
|
|
@ -598,6 +613,36 @@ export default function NewChatPage() {
|
||||||
[handleChatFailure]
|
[handleChatFailure]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const fetchWithTurnCancellingRetry = useCallback(
|
||||||
|
async (runFetch: () => Promise<Response>) => {
|
||||||
|
const maxAttempts = 4;
|
||||||
|
for (let attempt = 1; attempt <= maxAttempts; attempt += 1) {
|
||||||
|
const response = await runFetch();
|
||||||
|
if (response.ok) {
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
const error = await toHttpResponseError(response);
|
||||||
|
const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number };
|
||||||
|
const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING";
|
||||||
|
const isRecentThreadBusyAfterCancel =
|
||||||
|
withMeta.errorCode === "THREAD_BUSY" &&
|
||||||
|
Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS;
|
||||||
|
if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) {
|
||||||
|
const waitMs =
|
||||||
|
withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt);
|
||||||
|
await sleep(waitMs);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw Object.assign(new Error("Turn cancellation retry limit exceeded"), {
|
||||||
|
errorCode: "TURN_CANCELLING",
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
// Initialize thread and load messages
|
// Initialize thread and load messages
|
||||||
// For new chats (no urlChatId), we use lazy creation - thread is created on first message
|
// For new chats (no urlChatId), we use lazy creation - thread is created on first message
|
||||||
const initializeThread = useCallback(async () => {
|
const initializeThread = useCallback(async () => {
|
||||||
|
|
@ -767,12 +812,39 @@ export default function NewChatPage() {
|
||||||
|
|
||||||
// Cancel ongoing request
|
// Cancel ongoing request
|
||||||
const cancelRun = useCallback(async () => {
|
const cancelRun = useCallback(async () => {
|
||||||
|
if (threadId) {
|
||||||
|
const token = getBearerToken();
|
||||||
|
if (token) {
|
||||||
|
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
|
try {
|
||||||
|
const response = await fetch(
|
||||||
|
`${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`,
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${token}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (response.ok) {
|
||||||
|
const payload = (await response.json()) as {
|
||||||
|
error_code?: string;
|
||||||
|
};
|
||||||
|
if (payload.error_code === "TURN_CANCELLING") {
|
||||||
|
recentCancelRequestedAtRef.current = Date.now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if (abortControllerRef.current) {
|
if (abortControllerRef.current) {
|
||||||
abortControllerRef.current.abort();
|
abortControllerRef.current.abort();
|
||||||
abortControllerRef.current = null;
|
abortControllerRef.current = null;
|
||||||
}
|
}
|
||||||
setIsRunning(false);
|
setIsRunning(false);
|
||||||
}, []);
|
}, [threadId]);
|
||||||
|
|
||||||
// Handle new message from user
|
// Handle new message from user
|
||||||
const onNew = useCallback(
|
const onNew = useCallback(
|
||||||
|
|
@ -971,29 +1043,33 @@ export default function NewChatPage() {
|
||||||
setMentionedDocuments([]);
|
setMentionedDocuments([]);
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(`${backendUrl}/api/v1/new_chat`, {
|
const response = await fetchWithTurnCancellingRetry(() =>
|
||||||
method: "POST",
|
fetch(`${backendUrl}/api/v1/new_chat`, {
|
||||||
headers: {
|
method: "POST",
|
||||||
"Content-Type": "application/json",
|
headers: {
|
||||||
Authorization: `Bearer ${token}`,
|
"Content-Type": "application/json",
|
||||||
},
|
Authorization: `Bearer ${token}`,
|
||||||
body: JSON.stringify({
|
},
|
||||||
chat_id: currentThreadId,
|
body: JSON.stringify({
|
||||||
user_query: userQuery.trim(),
|
chat_id: currentThreadId,
|
||||||
search_space_id: searchSpaceId,
|
user_query: userQuery.trim(),
|
||||||
filesystem_mode: selection.filesystem_mode,
|
search_space_id: searchSpaceId,
|
||||||
client_platform: selection.client_platform,
|
filesystem_mode: selection.filesystem_mode,
|
||||||
local_filesystem_mounts: selection.local_filesystem_mounts,
|
client_platform: selection.client_platform,
|
||||||
messages: messageHistory,
|
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||||
mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined,
|
messages: messageHistory,
|
||||||
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
|
mentioned_document_ids: hasDocumentIds
|
||||||
? mentionedDocumentIds.surfsense_doc_ids
|
? mentionedDocumentIds.document_ids
|
||||||
: undefined,
|
: undefined,
|
||||||
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
|
mentioned_surfsense_doc_ids: hasSurfsenseDocIds
|
||||||
...(userImages.length > 0 ? { user_images: userImages } : {}),
|
? mentionedDocumentIds.surfsense_doc_ids
|
||||||
}),
|
: undefined,
|
||||||
signal: controller.signal,
|
disabled_tools: disabledTools.length > 0 ? disabledTools : undefined,
|
||||||
});
|
...(userImages.length > 0 ? { user_images: userImages } : {}),
|
||||||
|
}),
|
||||||
|
signal: controller.signal,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw await toHttpResponseError(response);
|
throw await toHttpResponseError(response);
|
||||||
|
|
@ -1033,6 +1109,11 @@ export default function NewChatPage() {
|
||||||
tokenUsageData = data;
|
tokenUsageData = data;
|
||||||
tokenUsageStore.set(assistantMsgId, data);
|
tokenUsageStore.set(assistantMsgId, data);
|
||||||
},
|
},
|
||||||
|
onTurnStatus: (data) => {
|
||||||
|
if (data.status === "cancelling") {
|
||||||
|
recentCancelRequestedAtRef.current = Date.now();
|
||||||
|
}
|
||||||
|
},
|
||||||
onToolOutputAvailable: (event, sharedCtx) => {
|
onToolOutputAvailable: (event, sharedCtx) => {
|
||||||
if (event.output?.status === "pending" && event.output?.podcast_id) {
|
if (event.output?.status === "pending" && event.output?.podcast_id) {
|
||||||
const idx = sharedCtx.toolCallIndices.get(event.toolCallId);
|
const idx = sharedCtx.toolCallIndices.get(event.toolCallId);
|
||||||
|
|
@ -1257,6 +1338,7 @@ export default function NewChatPage() {
|
||||||
tokenUsageStore,
|
tokenUsageStore,
|
||||||
pendingUserImageUrls,
|
pendingUserImageUrls,
|
||||||
setPendingUserImageUrls,
|
setPendingUserImageUrls,
|
||||||
|
fetchWithTurnCancellingRetry,
|
||||||
handleStreamTerminalError,
|
handleStreamTerminalError,
|
||||||
handleChatFailure,
|
handleChatFailure,
|
||||||
persistAssistantTurn,
|
persistAssistantTurn,
|
||||||
|
|
@ -1354,21 +1436,23 @@ export default function NewChatPage() {
|
||||||
try {
|
try {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
const selection = await getAgentFilesystemSelection(searchSpaceId);
|
const selection = await getAgentFilesystemSelection(searchSpaceId);
|
||||||
const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
|
const response = await fetchWithTurnCancellingRetry(() =>
|
||||||
method: "POST",
|
fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, {
|
||||||
headers: {
|
method: "POST",
|
||||||
"Content-Type": "application/json",
|
headers: {
|
||||||
Authorization: `Bearer ${token}`,
|
"Content-Type": "application/json",
|
||||||
},
|
Authorization: `Bearer ${token}`,
|
||||||
body: JSON.stringify({
|
},
|
||||||
search_space_id: searchSpaceId,
|
body: JSON.stringify({
|
||||||
decisions,
|
search_space_id: searchSpaceId,
|
||||||
filesystem_mode: selection.filesystem_mode,
|
decisions,
|
||||||
client_platform: selection.client_platform,
|
filesystem_mode: selection.filesystem_mode,
|
||||||
local_filesystem_mounts: selection.local_filesystem_mounts,
|
client_platform: selection.client_platform,
|
||||||
}),
|
local_filesystem_mounts: selection.local_filesystem_mounts,
|
||||||
signal: controller.signal,
|
}),
|
||||||
});
|
signal: controller.signal,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw await toHttpResponseError(response);
|
throw await toHttpResponseError(response);
|
||||||
|
|
@ -1399,6 +1483,11 @@ export default function NewChatPage() {
|
||||||
tokenUsageData = data;
|
tokenUsageData = data;
|
||||||
tokenUsageStore.set(assistantMsgId, data);
|
tokenUsageStore.set(assistantMsgId, data);
|
||||||
},
|
},
|
||||||
|
onTurnStatus: (data) => {
|
||||||
|
if (data.status === "cancelling") {
|
||||||
|
recentCancelRequestedAtRef.current = Date.now();
|
||||||
|
}
|
||||||
|
},
|
||||||
})
|
})
|
||||||
) {
|
) {
|
||||||
return;
|
return;
|
||||||
|
|
@ -1496,6 +1585,7 @@ export default function NewChatPage() {
|
||||||
searchSpaceId,
|
searchSpaceId,
|
||||||
queryClient,
|
queryClient,
|
||||||
tokenUsageStore,
|
tokenUsageStore,
|
||||||
|
fetchWithTurnCancellingRetry,
|
||||||
handleStreamTerminalError,
|
handleStreamTerminalError,
|
||||||
persistAssistantTurn,
|
persistAssistantTurn,
|
||||||
]
|
]
|
||||||
|
|
@ -1700,15 +1790,17 @@ export default function NewChatPage() {
|
||||||
requestBody.revert_actions = true;
|
requestBody.revert_actions = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const response = await fetch(getRegenerateUrl(threadId), {
|
const response = await fetchWithTurnCancellingRetry(() =>
|
||||||
method: "POST",
|
fetch(getRegenerateUrl(threadId), {
|
||||||
headers: {
|
method: "POST",
|
||||||
"Content-Type": "application/json",
|
headers: {
|
||||||
Authorization: `Bearer ${token}`,
|
"Content-Type": "application/json",
|
||||||
},
|
Authorization: `Bearer ${token}`,
|
||||||
body: JSON.stringify(requestBody),
|
},
|
||||||
signal: controller.signal,
|
body: JSON.stringify(requestBody),
|
||||||
});
|
signal: controller.signal,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw await toHttpResponseError(response);
|
throw await toHttpResponseError(response);
|
||||||
|
|
@ -1774,6 +1866,11 @@ export default function NewChatPage() {
|
||||||
tokenUsageData = data;
|
tokenUsageData = data;
|
||||||
tokenUsageStore.set(assistantMsgId, data);
|
tokenUsageStore.set(assistantMsgId, data);
|
||||||
},
|
},
|
||||||
|
onTurnStatus: (data) => {
|
||||||
|
if (data.status === "cancelling") {
|
||||||
|
recentCancelRequestedAtRef.current = Date.now();
|
||||||
|
}
|
||||||
|
},
|
||||||
onToolOutputAvailable: (event, sharedCtx) => {
|
onToolOutputAvailable: (event, sharedCtx) => {
|
||||||
if (event.output?.status === "pending" && event.output?.podcast_id) {
|
if (event.output?.status === "pending" && event.output?.podcast_id) {
|
||||||
const idx = sharedCtx.toolCallIndices.get(event.toolCallId);
|
const idx = sharedCtx.toolCallIndices.get(event.toolCallId);
|
||||||
|
|
@ -1945,6 +2042,7 @@ export default function NewChatPage() {
|
||||||
setMessageDocumentsMap,
|
setMessageDocumentsMap,
|
||||||
queryClient,
|
queryClient,
|
||||||
tokenUsageStore,
|
tokenUsageStore,
|
||||||
|
fetchWithTurnCancellingRetry,
|
||||||
handleStreamTerminalError,
|
handleStreamTerminalError,
|
||||||
persistAssistantTurn,
|
persistAssistantTurn,
|
||||||
persistUserTurn,
|
persistUserTurn,
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
errorCode === "TURN_CANCELLING"
|
||||||
|
) {
|
||||||
|
return {
|
||||||
|
kind: "thread_busy",
|
||||||
|
channel: "toast",
|
||||||
|
severity: "info",
|
||||||
|
telemetryEvent: "chat_blocked",
|
||||||
|
isExpected: true,
|
||||||
|
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
||||||
|
rawMessage,
|
||||||
|
errorCode: errorCode ?? "TURN_CANCELLING",
|
||||||
|
details: { flow: input.flow },
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
errorCode === "THREAD_BUSY"
|
errorCode === "THREAD_BUSY"
|
||||||
) {
|
) {
|
||||||
|
|
@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
|
||||||
severity: "warn",
|
severity: "warn",
|
||||||
telemetryEvent: "chat_blocked",
|
telemetryEvent: "chat_blocked",
|
||||||
isExpected: true,
|
isExpected: true,
|
||||||
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
userMessage: "Another response is still finishing for this thread. Please try again in a moment.",
|
||||||
rawMessage,
|
rawMessage,
|
||||||
errorCode: errorCode ?? "THREAD_BUSY",
|
errorCode: errorCode ?? "THREAD_BUSY",
|
||||||
details: { flow: input.flow },
|
details: { flow: input.flow },
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
export async function toHttpResponseError(
|
export async function toHttpResponseError(
|
||||||
response: Response
|
response: Response
|
||||||
): Promise<Error & { errorCode?: string }> {
|
): Promise<Error & { errorCode?: string; retryAfterMs?: number }> {
|
||||||
const statusDefaultCode =
|
const statusDefaultCode =
|
||||||
response.status === 409
|
response.status === 409
|
||||||
? "THREAD_BUSY"
|
? "THREAD_BUSY"
|
||||||
|
|
@ -52,13 +52,37 @@ export async function toHttpResponseError(
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode;
|
const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode;
|
||||||
|
|
||||||
|
const detailRetryAfterMs =
|
||||||
|
typeof detailObject?.retry_after_ms === "number"
|
||||||
|
? detailObject.retry_after_ms
|
||||||
|
: typeof detailObject?.retryAfterMs === "number"
|
||||||
|
? detailObject.retryAfterMs
|
||||||
|
: undefined;
|
||||||
|
const topRetryAfterMs =
|
||||||
|
typeof parsedBody?.retry_after_ms === "number"
|
||||||
|
? parsedBody.retry_after_ms
|
||||||
|
: typeof parsedBody?.retryAfterMs === "number"
|
||||||
|
? parsedBody.retryAfterMs
|
||||||
|
: undefined;
|
||||||
|
const headerRetryAfterMsRaw = response.headers.get("retry-after-ms");
|
||||||
|
const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN;
|
||||||
|
const retryAfterHeader = response.headers.get("retry-after");
|
||||||
|
const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN;
|
||||||
|
const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs)
|
||||||
|
? Math.max(0, Math.round(headerRetryAfterMs))
|
||||||
|
: Number.isFinite(retryAfterSeconds)
|
||||||
|
? Math.max(0, Math.round(retryAfterSeconds * 1000))
|
||||||
|
: undefined;
|
||||||
|
const retryAfterMs =
|
||||||
|
detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined;
|
||||||
const message =
|
const message =
|
||||||
detailNestedMessage ??
|
detailNestedMessage ??
|
||||||
detailMessage ??
|
detailMessage ??
|
||||||
topLevelMessage ??
|
topLevelMessage ??
|
||||||
`Backend error: ${response.status}`;
|
`Backend error: ${response.status}`;
|
||||||
|
|
||||||
return Object.assign(new Error(message), { errorCode });
|
return Object.assign(new Error(message), { errorCode, retryAfterMs });
|
||||||
}
|
}
|
||||||
|
|
||||||
export function tagPreAcceptSendFailure(error: unknown): unknown {
|
export function tagPreAcceptSendFailure(error: unknown): unknown {
|
||||||
|
|
@ -68,6 +92,7 @@ export function tagPreAcceptSendFailure(error: unknown): unknown {
|
||||||
const passthroughCodes = new Set([
|
const passthroughCodes = new Set([
|
||||||
"PREMIUM_QUOTA_EXHAUSTED",
|
"PREMIUM_QUOTA_EXHAUSTED",
|
||||||
"THREAD_BUSY",
|
"THREAD_BUSY",
|
||||||
|
"TURN_CANCELLING",
|
||||||
"AUTH_EXPIRED",
|
"AUTH_EXPIRED",
|
||||||
"UNAUTHORIZED",
|
"UNAUTHORIZED",
|
||||||
"RATE_LIMITED",
|
"RATE_LIMITED",
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ export type SharedStreamEventContext = {
|
||||||
scheduleFlush: () => void;
|
scheduleFlush: () => void;
|
||||||
forceFlush: () => void;
|
forceFlush: () => void;
|
||||||
onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void;
|
onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void;
|
||||||
|
onTurnStatus?: (data: Extract<SSEEvent, { type: "data-turn-status" }>["data"]) => void;
|
||||||
onToolOutputAvailable?: (
|
onToolOutputAvailable?: (
|
||||||
event: Extract<SSEEvent, { type: "tool-output-available" }>,
|
event: Extract<SSEEvent, { type: "tool-output-available" }>,
|
||||||
context: {
|
context: {
|
||||||
|
|
@ -173,6 +174,10 @@ export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStream
|
||||||
context.onTokenUsage?.(parsed.data);
|
context.onTokenUsage?.(parsed.data);
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
|
case "data-turn-status":
|
||||||
|
context.onTurnStatus?.(parsed.data);
|
||||||
|
return true;
|
||||||
|
|
||||||
case "error":
|
case "error":
|
||||||
throw toStreamTerminalError(parsed);
|
throw toStreamTerminalError(parsed);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -528,6 +528,14 @@ export type SSEEvent =
|
||||||
}>;
|
}>;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
| {
|
||||||
|
type: "data-turn-status";
|
||||||
|
data: {
|
||||||
|
status: "idle" | "busy" | "cancelling";
|
||||||
|
retry_after_ms?: number;
|
||||||
|
retry_after_at?: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
| {
|
| {
|
||||||
type: "data-token-usage";
|
type: "data-token-usage";
|
||||||
data: {
|
data: {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue