mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
Merge pull request #1327 from AnishSarkar22/feat/chat-state-unification
refactor(chat): unify streaming state flow and improve chat viewport + mention UX
This commit is contained in:
commit
d335e96ec2
27 changed files with 1953 additions and 1647 deletions
|
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -58,6 +59,8 @@ class _ThreadLockManager:
|
||||||
weakref.WeakValueDictionary()
|
weakref.WeakValueDictionary()
|
||||||
)
|
)
|
||||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||||
|
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||||
|
self._cancel_attempt_count: dict[str, int] = {}
|
||||||
|
|
||||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||||
lock = self._locks.get(thread_id)
|
lock = self._locks.get(thread_id)
|
||||||
|
|
@ -76,14 +79,45 @@ class _ThreadLockManager:
|
||||||
def request_cancel(self, thread_id: str) -> bool:
    """Record a cancel request for ``thread_id`` and trip its cancel event.

    A cancel event is created on demand when none is registered for the
    thread yet. Every call overwrites the stored request timestamp
    (wall-clock milliseconds) and increments the attempt counter.

    Returns:
        Always ``True``.
    """
    cancel_event = self._cancel_events.get(thread_id)
    if cancel_event is None:
        # Unseen thread: register a fresh event so the signal is not lost.
        cancel_event = self._cancel_events[thread_id] = asyncio.Event()
    cancel_event.set()

    self._cancel_requested_at_ms[thread_id] = int(time.time() * 1000)
    previous_attempts = self._cancel_attempt_count.get(thread_id, 0)
    self._cancel_attempt_count[thread_id] = previous_attempts + 1
    return True
|
||||||
|
|
||||||
|
def is_cancel_requested(self, thread_id: str) -> bool:
    """Return ``True`` when ``thread_id`` has a registered cancel event that is set."""
    cancel_event = self._cancel_events.get(thread_id)
    if not cancel_event:
        # No event registered for this thread, so nothing can be pending.
        return False
    return bool(cancel_event.is_set())
|
||||||
|
|
||||||
|
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
    """Return ``(attempts, requested_at_ms)`` for a pending cancel, else ``None``.

    When a cancel is pending but no bookkeeping is stored, the attempt
    count falls back to ``1`` and the timestamp to ``0``.
    """
    if self.is_cancel_requested(thread_id):
        return (
            self._cancel_attempt_count.get(thread_id, 1),
            self._cancel_requested_at_ms.get(thread_id, 0),
        )
    return None
|
||||||
|
|
||||||
def reset(self, thread_id: str) -> None:
    """Clear the cancel signal and its bookkeeping for ``thread_id``.

    The event object itself stays registered (only its flag is cleared);
    the timestamp and attempt-count entries are removed. Unknown thread
    ids are tolerated.
    """
    cancel_event = self._cancel_events.get(thread_id)
    if cancel_event is not None:
        cancel_event.clear()
    # ``pop`` with a default makes both removals no-ops for unseen threads.
    for bookkeeping in (self._cancel_requested_at_ms, self._cancel_attempt_count):
        bookkeeping.pop(thread_id, None)
|
||||||
|
|
||||||
|
def end_turn(self, thread_id: str) -> None:
    """Best-effort terminal cleanup for a thread turn.

    Releases the thread's lock if one is registered and currently held,
    then resets the cancel state. Calling it repeatedly, or for a thread
    with no registered lock, is harmless, which makes it usable from outer
    stream ``finally`` blocks where middleware teardown might be skipped
    due to abort or disconnect edge-cases.
    """
    turn_lock = self._locks.get(thread_id)
    still_held = turn_lock is not None and turn_lock.locked()
    if still_held:
        # NOTE(review): asyncio.Lock does not track an owner, so this releases
        # the lock regardless of which turn acquired it -- confirm callers
        # only reach this on behalf of the turn that holds the lock.
        turn_lock.release()
    self.reset(thread_id)
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — process-local but reused across all agent
|
# Module-level singleton — process-local but reused across all agent
|
||||||
|
|
@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
|
||||||
|
|
||||||
|
|
||||||
def request_cancel(thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id``. Always returns True.

    Thin module-level convenience over ``manager.request_cancel``.
    """
    outcome = manager.request_cancel(thread_id)
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
def is_cancel_requested(thread_id: str) -> bool:
    """Report whether a cancel signal is currently pending for ``thread_id``.

    Delegates to ``manager.is_cancel_requested``.
    """
    pending = manager.is_cancel_requested(thread_id)
    return pending
|
||||||
|
|
||||||
|
|
||||||
|
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
    """Return ``(attempt_count, requested_at_ms)`` while a cancel is pending.

    Delegates to ``manager.cancel_state``, which yields ``None`` when no
    cancel is pending for ``thread_id``.
    """
    state = manager.cancel_state(thread_id)
    return state
|
||||||
|
|
||||||
|
|
||||||
def reset_cancel(thread_id: str) -> None:
    """Clear the cancel signal for ``thread_id`` (called between turns).

    Delegates to ``manager.reset``.
    """
    manager.reset(thread_id)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def end_turn(thread_id: str) -> None:
    """Force end-of-turn cleanup of the lock and cancel state for ``thread_id``.

    Delegates to ``manager.end_turn``.
    """
    manager.end_turn(thread_id)
    return None
|
||||||
|
|
||||||
|
|
||||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||||
"""Block concurrent prompts on the same thread.
|
"""Block concurrent prompts on the same thread.
|
||||||
|
|
||||||
|
|
@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BusyMutexMiddleware",
|
"BusyMutexMiddleware",
|
||||||
|
"end_turn",
|
||||||
"get_cancel_event",
|
"get_cancel_event",
|
||||||
|
"get_cancel_state",
|
||||||
|
"is_cancel_requested",
|
||||||
"manager",
|
"manager",
|
||||||
"request_cancel",
|
"request_cancel",
|
||||||
"reset_cancel",
|
"reset_cancel",
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func, or_
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
|
||||||
FilesystemSelection,
|
FilesystemSelection,
|
||||||
LocalFilesystemMount,
|
LocalFilesystemMount,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
|
get_cancel_state,
|
||||||
|
is_cancel_requested,
|
||||||
|
manager,
|
||||||
|
request_cancel,
|
||||||
|
)
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatComment,
|
ChatComment,
|
||||||
|
|
@ -44,6 +50,7 @@ from app.db import (
|
||||||
)
|
)
|
||||||
from app.schemas.new_chat import (
|
from app.schemas.new_chat import (
|
||||||
AgentToolInfo,
|
AgentToolInfo,
|
||||||
|
CancelActiveTurnResponse,
|
||||||
LocalFilesystemMountPayload,
|
LocalFilesystemMountPayload,
|
||||||
NewChatMessageRead,
|
NewChatMessageRead,
|
||||||
NewChatRequest,
|
NewChatRequest,
|
||||||
|
|
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
|
||||||
ThreadListItem,
|
ThreadListItem,
|
||||||
ThreadListResponse,
|
ThreadListResponse,
|
||||||
TokenUsageSummary,
|
TokenUsageSummary,
|
||||||
|
TurnStatusResponse,
|
||||||
)
|
)
|
||||||
from app.services.token_tracking_service import record_token_usage
|
from app.services.token_tracking_service import record_token_usage
|
||||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||||
|
|
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
_background_tasks: set[asyncio.Task] = set()
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
|
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||||
|
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||||
|
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential delay for TURN_CANCELLING retry hints.

    Attempts below 1 are treated as the first attempt. The delay starts at
    ``TURN_CANCELLING_INITIAL_DELAY_MS``, is multiplied by
    ``TURN_CANCELLING_BACKOFF_FACTOR`` per additional attempt, and never
    exceeds ``TURN_CANCELLING_MAX_DELAY_MS``.
    """
    exponent = max(attempt, 1) - 1
    uncapped = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    if uncapped > TURN_CANCELLING_MAX_DELAY_MS:
        return TURN_CANCELLING_MAX_DELAY_MS
    return uncapped
|
||||||
|
|
||||||
|
|
||||||
|
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Describe the current turn state of ``thread_id`` as a plain dict.

    Returns:
        ``{"status": "idle"}`` when the thread's lock is not held,
        ``{"status": "busy"}`` when it is held without a pending cancel,
        and a ``"cancelling"`` payload carrying ``retry_after_ms`` plus an
        absolute ``retry_after_at`` (epoch milliseconds) when a cancel has
        been requested.
    """
    thread_key = str(thread_id)

    if not manager.lock_for(thread_key).locked():
        return {"status": "idle"}

    if not is_cancel_requested(thread_key):
        return {"status": "busy"}

    cancel_state = get_cancel_state(thread_key)
    # Fall back to the first attempt if the bookkeeping is unavailable.
    attempt = cancel_state[0] if cancel_state else 1
    retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    return {
        "status": "cancelling",
        "retry_after_ms": retry_after_ms,
        "retry_after_at": now_ms + retry_after_ms,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
|
||||||
|
response.headers["retry-after-ms"] = str(retry_after_ms)
|
||||||
|
response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Reject starting a new turn while ``thread_id`` is not idle.

    Returns silently when the thread is idle.

    Raises:
        HTTPException: status 409 with ``errorCode`` ``TURN_CANCELLING``
            (plus retry hints in the detail and headers when available)
            while a cancel is in flight, or ``THREAD_BUSY`` for any other
            non-idle status.
    """
    status_payload = _build_turn_status_payload(thread_id)
    status = status_payload["status"]

    if status == "idle":
        return

    if status != "cancelling":
        raise HTTPException(
            status_code=409,
            detail={
                "errorCode": "THREAD_BUSY",
                "message": "Another response is still finishing for this thread. Please try again in a moment.",
            },
        )

    retry_after_ms = int(status_payload.get("retry_after_ms") or 0)
    has_retry_hint = retry_after_ms > 0

    # Only advertise retry headers when there is a positive delay to report.
    headers = None
    if has_retry_hint:
        headers = {
            "retry-after-ms": str(retry_after_ms),
            "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
        }

    detail = {
        "errorCode": "TURN_CANCELLING",
        "message": "A previous response is still stopping. Please try again in a moment.",
        "retry_after_ms": retry_after_ms if has_retry_hint else None,
        "retry_after_at": status_payload.get("retry_after_at"),
    }
    raise HTTPException(status_code=409, detail=detail, headers=headers)
|
||||||
|
|
||||||
|
|
||||||
def _find_pre_turn_checkpoint_id(
|
def _find_pre_turn_checkpoint_id(
|
||||||
checkpoint_tuples: list,
|
checkpoint_tuples: list,
|
||||||
*,
|
*,
|
||||||
|
|
@ -1476,6 +1553,7 @@ async def handle_new_chat(
|
||||||
|
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(request.chat_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
@ -1550,6 +1628,93 @@ async def handle_new_chat(
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``.

    Responds 404 when the thread does not exist. After the permission and
    thread-access checks it returns an ``idle`` / ``NO_ACTIVE_TURN`` body
    if nothing is running; otherwise it requests cancellation, sets the
    response status to 202 and returns a ``cancelling`` body with retry
    hints (also mirrored into response headers when a positive delay is
    available).
    """
    lookup = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = lookup.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    if _build_turn_status_payload(thread_id)["status"] == "idle":
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )

    request_cancel(str(thread_id))
    response.status_code = 202

    # Re-read the status so the retry hints reflect the request just made.
    refreshed = _build_turn_status_payload(thread_id)
    retry_after_ms = int(refreshed.get("retry_after_ms") or 0)
    retry_after_at = None
    if "retry_after_at" in refreshed:
        retry_after_at = int(refreshed["retry_after_at"])

    retry_hint_ms = None
    if retry_after_ms > 0:
        _set_retry_after_headers(response, retry_after_ms)
        retry_hint_ms = retry_after_ms

    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=retry_hint_ms,
        retry_after_at=retry_after_at,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Report whether ``thread_id`` is idle, busy, or cancelling.

    Responds 404 when the thread does not exist; otherwise runs the read
    permission and thread-access checks before building the status.
    ``active_turn_id`` is always returned as ``None`` here.
    """
    lookup = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = lookup.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    payload = _build_turn_status_payload(thread_id)
    # The retry hints are only present in the "cancelling" payload.
    return TurnStatusResponse(
        status=payload["status"],  # type: ignore[arg-type]
        active_turn_id=None,
        retry_after_ms=payload.get("retry_after_ms"),  # type: ignore[arg-type]
        retry_after_at=payload.get("retry_after_at"),  # type: ignore[arg-type]
    )
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Chat Regeneration Endpoint (Edit/Reload)
|
# Chat Regeneration Endpoint (Edit/Reload)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -1605,6 +1770,7 @@ async def regenerate_response(
|
||||||
|
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(thread_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
@ -2012,6 +2178,7 @@ async def resume_chat(
|
||||||
)
|
)
|
||||||
|
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(thread_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
|
||||||
|
|
@ -335,6 +335,24 @@ class ResumeRequest(BaseModel):
|
||||||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelActiveTurnResponse(BaseModel):
    """Response body returned when a client asks to cancel a thread's active turn."""

    # "cancelling" once a cancel was signalled; "idle" when nothing was running.
    status: Literal["cancelling", "idle"]
    # Machine-readable code paired with ``status``.
    error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
    # Suggested wait before retrying, in milliseconds, when a hint is available.
    retry_after_ms: int | None = None
    # Absolute retry time -- presumably epoch milliseconds; confirm against the producer.
    retry_after_at: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TurnStatusResponse(BaseModel):
    """Snapshot of the turn execution status for a chat thread."""

    # One of the three states a thread's turn can be reported in.
    status: Literal["idle", "busy", "cancelling"]
    # Identifier of the running turn, when the producer supplies one.
    active_turn_id: str | None = None
    # Suggested wait before retrying, in milliseconds, when a hint is available.
    retry_after_ms: int | None = None
    # Absolute retry time -- presumably epoch milliseconds; confirm against the producer.
    retry_after_at: int | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Public Chat Snapshot Schemas
|
# Public Chat Snapshot Schemas
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -565,7 +565,12 @@ class VercelStreamingService:
|
||||||
# Error Part
|
# Error Part
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
||||||
def format_error(self, error_text: str, error_code: str | None = None) -> str:
|
def format_error(
|
||||||
|
self,
|
||||||
|
error_text: str,
|
||||||
|
error_code: str | None = None,
|
||||||
|
extra: dict[str, object] | None = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format an error message.
|
Format an error message.
|
||||||
|
|
||||||
|
|
@ -579,9 +584,11 @@ class VercelStreamingService:
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
||||||
"""
|
"""
|
||||||
payload: dict[str, str] = {"type": "error", "errorText": error_text}
|
payload: dict[str, object] = {"type": "error", "errorText": error_text}
|
||||||
if error_code:
|
if error_code:
|
||||||
payload["errorCode"] = error_code
|
payload["errorCode"] = error_code
|
||||||
|
if extra:
|
||||||
|
payload.update(extra)
|
||||||
return self._format_sse(payload)
|
return self._format_sse(payload)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import (
|
||||||
extract_and_save_memory,
|
extract_and_save_memory,
|
||||||
extract_and_save_team_memory,
|
extract_and_save_team_memory,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
|
end_turn,
|
||||||
|
get_cancel_state,
|
||||||
|
is_cancel_requested,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.kb_persistence import (
|
from app.agents.new_chat.middleware.kb_persistence import (
|
||||||
commit_staged_filesystem_state,
|
commit_staged_filesystem_state,
|
||||||
)
|
)
|
||||||
|
|
@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content
|
||||||
|
|
||||||
_background_tasks: set[asyncio.Task] = set()
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||||
|
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||||
|
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Return the capped exponential retry delay for a TURN_CANCELLING hint.

    Attempts below 1 count as the first attempt; the result never exceeds
    ``TURN_CANCELLING_MAX_DELAY_MS``.
    """
    effective_attempt = attempt if attempt >= 1 else 1
    growth = TURN_CANCELLING_BACKOFF_FACTOR ** (effective_attempt - 1)
    candidate = TURN_CANCELLING_INITIAL_DELAY_MS * growth
    return min(candidate, TURN_CANCELLING_MAX_DELAY_MS)
|
||||||
|
|
||||||
|
|
||||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||||
|
|
@ -401,15 +418,35 @@ def _classify_stream_exception(
|
||||||
exc: Exception,
|
exc: Exception,
|
||||||
*,
|
*,
|
||||||
flow_label: str,
|
flow_label: str,
|
||||||
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]:
|
) -> tuple[
|
||||||
|
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
|
||||||
|
]:
|
||||||
raw = str(exc)
|
raw = str(exc)
|
||||||
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
|
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
|
||||||
|
busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
|
||||||
|
if busy_thread_id and is_cancel_requested(busy_thread_id):
|
||||||
|
cancel_state = get_cancel_state(busy_thread_id)
|
||||||
|
attempt = cancel_state[0] if cancel_state else 1
|
||||||
|
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
|
||||||
|
retry_after_at = int(time.time() * 1000) + retry_after_ms
|
||||||
|
return (
|
||||||
|
"thread_busy",
|
||||||
|
"TURN_CANCELLING",
|
||||||
|
"info",
|
||||||
|
True,
|
||||||
|
"A previous response is still stopping. Please try again in a moment.",
|
||||||
|
{
|
||||||
|
"retry_after_ms": retry_after_ms,
|
||||||
|
"retry_after_at": retry_after_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
"thread_busy",
|
"thread_busy",
|
||||||
"THREAD_BUSY",
|
"THREAD_BUSY",
|
||||||
"warn",
|
"warn",
|
||||||
True,
|
True,
|
||||||
"Another response is still finishing for this thread. Please try again in a moment.",
|
"Another response is still finishing for this thread. Please try again in a moment.",
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
parsed = _parse_error_payload(raw)
|
parsed = _parse_error_payload(raw)
|
||||||
|
|
@ -431,6 +468,7 @@ def _classify_stream_exception(
|
||||||
"warn",
|
"warn",
|
||||||
True,
|
True,
|
||||||
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
@ -439,6 +477,7 @@ def _classify_stream_exception(
|
||||||
"error",
|
"error",
|
||||||
False,
|
False,
|
||||||
f"Error during {flow_label}: {raw}",
|
f"Error during {flow_label}: {raw}",
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -470,7 +509,7 @@ def _emit_stream_terminal_error(
|
||||||
message=message,
|
message=message,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
)
|
)
|
||||||
return streaming_service.format_error(message, error_code=error_code)
|
return streaming_service.format_error(message, error_code=error_code, extra=extra)
|
||||||
|
|
||||||
|
|
||||||
def _legacy_match_lc_id(
|
def _legacy_match_lc_id(
|
||||||
|
|
@ -2497,6 +2536,7 @@ async def stream_new_chat(
|
||||||
"turn-info",
|
"turn-info",
|
||||||
{"chat_turn_id": stream_result.turn_id},
|
{"chat_turn_id": stream_result.turn_id},
|
||||||
)
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
# Initial thinking step - analyzing the request
|
# Initial thinking step - analyzing the request
|
||||||
if mentioned_surfsense_docs:
|
if mentioned_surfsense_docs:
|
||||||
|
|
@ -2805,6 +2845,7 @@ async def stream_new_chat(
|
||||||
task.add_done_callback(_background_tasks.discard)
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
# Finish the step and message
|
# Finish the step and message
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
@ -2819,11 +2860,19 @@ async def stream_new_chat(
|
||||||
severity,
|
severity,
|
||||||
is_expected,
|
is_expected,
|
||||||
user_message,
|
user_message,
|
||||||
|
error_extra,
|
||||||
) = _classify_stream_exception(e, flow_label="chat")
|
) = _classify_stream_exception(e, flow_label="chat")
|
||||||
error_message = f"Error during chat: {e!s}"
|
error_message = f"Error during chat: {e!s}"
|
||||||
print(f"[stream_new_chat] {error_message}")
|
print(f"[stream_new_chat] {error_message}")
|
||||||
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
||||||
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
||||||
|
if error_code == "TURN_CANCELLING":
|
||||||
|
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||||
|
if error_extra:
|
||||||
|
status_payload.update(error_extra)
|
||||||
|
yield streaming_service.format_data("turn-status", status_payload)
|
||||||
|
else:
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=user_message,
|
message=user_message,
|
||||||
|
|
@ -2831,7 +2880,9 @@ async def stream_new_chat(
|
||||||
error_code=error_code,
|
error_code=error_code,
|
||||||
severity=severity,
|
severity=severity,
|
||||||
is_expected=is_expected,
|
is_expected=is_expected,
|
||||||
|
extra=error_extra,
|
||||||
)
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
@ -2847,6 +2898,10 @@ async def stream_new_chat(
|
||||||
# (CancelledError is a BaseException), and the rest of the
|
# (CancelledError is a BaseException), and the rest of the
|
||||||
# finally block — including session.close() — would never run.
|
# finally block — including session.close() — would never run.
|
||||||
with anyio.CancelScope(shield=True):
|
with anyio.CancelScope(shield=True):
|
||||||
|
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||||
|
# teardown can be skipped on some client-abort paths.
|
||||||
|
end_turn(str(chat_id))
|
||||||
|
|
||||||
# Release premium reservation if not finalized
|
# Release premium reservation if not finalized
|
||||||
if _premium_request_id and _premium_reserved > 0 and user_id:
|
if _premium_request_id and _premium_reserved > 0 and user_id:
|
||||||
try:
|
try:
|
||||||
|
|
@ -3206,6 +3261,7 @@ async def stream_resume_chat(
|
||||||
"turn-info",
|
"turn-info",
|
||||||
{"chat_turn_id": stream_result.turn_id},
|
{"chat_turn_id": stream_result.turn_id},
|
||||||
)
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
|
|
||||||
_t_stream_start = time.perf_counter()
|
_t_stream_start = time.perf_counter()
|
||||||
_first_event_logged = False
|
_first_event_logged = False
|
||||||
|
|
@ -3305,6 +3361,7 @@ async def stream_resume_chat(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
@ -3318,23 +3375,37 @@ async def stream_resume_chat(
|
||||||
severity,
|
severity,
|
||||||
is_expected,
|
is_expected,
|
||||||
user_message,
|
user_message,
|
||||||
|
error_extra,
|
||||||
) = _classify_stream_exception(e, flow_label="resume")
|
) = _classify_stream_exception(e, flow_label="resume")
|
||||||
error_message = f"Error during resume: {e!s}"
|
error_message = f"Error during resume: {e!s}"
|
||||||
print(f"[stream_resume_chat] {error_message}")
|
print(f"[stream_resume_chat] {error_message}")
|
||||||
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
||||||
|
if error_code == "TURN_CANCELLING":
|
||||||
|
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||||
|
if error_extra:
|
||||||
|
status_payload.update(error_extra)
|
||||||
|
yield streaming_service.format_data("turn-status", status_payload)
|
||||||
|
else:
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||||
yield _emit_stream_error(
|
yield _emit_stream_error(
|
||||||
message=user_message,
|
message=user_message,
|
||||||
error_kind=error_kind,
|
error_kind=error_kind,
|
||||||
error_code=error_code,
|
error_code=error_code,
|
||||||
severity=severity,
|
severity=severity,
|
||||||
is_expected=is_expected,
|
is_expected=is_expected,
|
||||||
|
extra=error_extra,
|
||||||
)
|
)
|
||||||
|
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
with anyio.CancelScope(shield=True):
|
with anyio.CancelScope(shield=True):
|
||||||
|
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||||
|
# teardown can be skipped on some client-abort paths.
|
||||||
|
end_turn(str(chat_id))
|
||||||
|
|
||||||
# Release premium reservation if not finalized
|
# Release premium reservation if not finalized
|
||||||
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,9 @@ import pytest
|
||||||
from app.agents.new_chat.errors import BusyError
|
from app.agents.new_chat.errors import BusyError
|
||||||
from app.agents.new_chat.middleware.busy_mutex import (
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
BusyMutexMiddleware,
|
BusyMutexMiddleware,
|
||||||
|
end_turn,
|
||||||
get_cancel_event,
|
get_cancel_event,
|
||||||
|
is_cancel_requested,
|
||||||
manager,
|
manager,
|
||||||
request_cancel,
|
request_cancel,
|
||||||
reset_cancel,
|
reset_cancel,
|
||||||
|
|
@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
|
||||||
def test_reset_cancel_idempotent() -> None:
|
def test_reset_cancel_idempotent() -> None:
|
||||||
# Should not raise even if event was never created
|
# Should not raise even if event was never created
|
||||||
reset_cancel("never-seen")
|
reset_cancel("never-seen")
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_cancel_creates_event_for_unseen_thread() -> None:
|
||||||
|
thread_id = "never-seen-cancel"
|
||||||
|
reset_cancel(thread_id)
|
||||||
|
|
||||||
|
assert request_cancel(thread_id) is True
|
||||||
|
assert get_cancel_event(thread_id).is_set()
|
||||||
|
assert is_cancel_requested(thread_id) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
|
||||||
|
thread_id = "forced-end-turn"
|
||||||
|
mw = BusyMutexMiddleware()
|
||||||
|
runtime = _Runtime(thread_id)
|
||||||
|
|
||||||
|
await mw.abefore_agent({}, runtime)
|
||||||
|
assert manager.lock_for(thread_id).locked()
|
||||||
|
|
||||||
|
request_cancel(thread_id)
|
||||||
|
assert is_cancel_requested(thread_id) is True
|
||||||
|
|
||||||
|
end_turn(thread_id)
|
||||||
|
|
||||||
|
assert not manager.lock_for(thread_id).locked()
|
||||||
|
assert not get_cancel_event(thread_id).is_set()
|
||||||
|
assert is_cancel_requested(thread_id) is False
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import pytest
|
||||||
|
|
||||||
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
||||||
from app.agents.new_chat.errors import BusyError
|
from app.agents.new_chat.errors import BusyError
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
|
||||||
from app.tasks.chat.stream_new_chat import (
|
from app.tasks.chat.stream_new_chat import (
|
||||||
StreamResult,
|
StreamResult,
|
||||||
_classify_stream_exception,
|
_classify_stream_exception,
|
||||||
|
|
@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited():
|
||||||
exc = Exception(
|
exc = Exception(
|
||||||
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
||||||
)
|
)
|
||||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
exc, flow_label="chat"
|
exc, flow_label="chat"
|
||||||
)
|
)
|
||||||
assert kind == "rate_limited"
|
assert kind == "rate_limited"
|
||||||
|
|
@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited():
|
||||||
assert severity == "warn"
|
assert severity == "warn"
|
||||||
assert is_expected is True
|
assert is_expected is True
|
||||||
assert "temporarily rate-limited" in user_message
|
assert "temporarily rate-limited" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy():
|
def test_stream_exception_classifies_thread_busy():
|
||||||
exc = BusyError(request_id="thread-123")
|
exc = BusyError(request_id="thread-123")
|
||||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
exc, flow_label="chat"
|
exc, flow_label="chat"
|
||||||
)
|
)
|
||||||
assert kind == "thread_busy"
|
assert kind == "thread_busy"
|
||||||
|
|
@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy():
|
||||||
assert severity == "warn"
|
assert severity == "warn"
|
||||||
assert is_expected is True
|
assert is_expected is True
|
||||||
assert "still finishing for this thread" in user_message
|
assert "still finishing for this thread" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy_from_message():
|
def test_stream_exception_classifies_thread_busy_from_message():
|
||||||
exc = Exception("Thread is busy with another request")
|
exc = Exception("Thread is busy with another request")
|
||||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
exc, flow_label="chat"
|
exc, flow_label="chat"
|
||||||
)
|
)
|
||||||
assert kind == "thread_busy"
|
assert kind == "thread_busy"
|
||||||
|
|
@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message():
|
||||||
assert severity == "warn"
|
assert severity == "warn"
|
||||||
assert is_expected is True
|
assert is_expected is True
|
||||||
assert "still finishing for this thread" in user_message
|
assert "still finishing for this thread" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
|
||||||
|
thread_id = "thread-cancelling-1"
|
||||||
|
reset_cancel(thread_id)
|
||||||
|
request_cancel(thread_id)
|
||||||
|
exc = BusyError(request_id=thread_id)
|
||||||
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
exc, flow_label="chat"
|
||||||
|
)
|
||||||
|
assert kind == "thread_busy"
|
||||||
|
assert code == "TURN_CANCELLING"
|
||||||
|
assert severity == "info"
|
||||||
|
assert is_expected is True
|
||||||
|
assert "stopping" in user_message
|
||||||
|
assert isinstance(extra, dict)
|
||||||
|
assert "retry_after_ms" in extra
|
||||||
|
|
||||||
|
|
||||||
def test_premium_classification_is_error_code_driven():
|
def test_premium_classification_is_error_code_driven():
|
||||||
|
|
@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
|
||||||
def test_network_send_failures_use_unified_retry_toast_message():
|
def test_network_send_failures_use_unified_retry_toast_message():
|
||||||
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
|
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
|
||||||
classifier_source = classifier_path.read_text(encoding="utf-8")
|
classifier_source = classifier_path.read_text(encoding="utf-8")
|
||||||
page_path = (
|
request_errors_path = (
|
||||||
Path(__file__).resolve().parents[3]
|
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts"
|
||||||
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
|
||||||
)
|
)
|
||||||
page_source = page_path.read_text(encoding="utf-8")
|
request_errors_source = request_errors_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
assert '"send_failed_pre_accept"' in classifier_source
|
assert '"send_failed_pre_accept"' in classifier_source
|
||||||
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
|
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
|
||||||
|
assert 'errorCode === "TURN_CANCELLING"' in classifier_source
|
||||||
assert "if (withCode.code) return withCode.code;" in classifier_source
|
assert "if (withCode.code) return withCode.code;" in classifier_source
|
||||||
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
|
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
|
||||||
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
|
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
|
||||||
assert "tagPreAcceptSendFailure(error)" in page_source
|
assert "const passthroughCodes = new Set([" in request_errors_source
|
||||||
assert "const passthroughCodes = new Set([" in page_source
|
assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
|
||||||
assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source
|
assert '"THREAD_BUSY"' in request_errors_source
|
||||||
assert '"THREAD_BUSY"' in page_source
|
assert '"TURN_CANCELLING"' in request_errors_source
|
||||||
assert '"AUTH_EXPIRED"' in page_source
|
assert '"AUTH_EXPIRED"' in request_errors_source
|
||||||
assert '"UNAUTHORIZED"' in page_source
|
assert '"UNAUTHORIZED"' in request_errors_source
|
||||||
assert '"RATE_LIMITED"' in page_source
|
assert '"RATE_LIMITED"' in request_errors_source
|
||||||
assert '"NETWORK_ERROR"' in page_source
|
assert '"NETWORK_ERROR"' in request_errors_source
|
||||||
assert '"STREAM_PARSE_ERROR"' in page_source
|
assert '"STREAM_PARSE_ERROR"' in request_errors_source
|
||||||
assert '"TOOL_EXECUTION_ERROR"' in page_source
|
assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
|
||||||
assert '"PERSIST_MESSAGE_FAILED"' in page_source
|
assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
|
||||||
assert '"SERVER_ERROR"' in page_source
|
assert '"SERVER_ERROR"' in request_errors_source
|
||||||
assert "passthroughCodes.has(existingCode)" in page_source
|
assert "passthroughCodes.has(existingCode)" in request_errors_source
|
||||||
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source
|
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
|
||||||
assert 'errorCode: "NETWORK_ERROR"' not in page_source
|
assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
|
||||||
assert "Failed to start chat. Please try again." not in page_source
|
assert "Failed to start chat. Please try again." not in classifier_source
|
||||||
|
|
||||||
|
|
||||||
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
||||||
|
|
@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
|
||||||
|
|
||||||
# New flow persists only when accepted and not already persisted.
|
# New flow persists only when accepted and not already persisted.
|
||||||
assert "if (newAccepted && !userPersisted) {" in source
|
assert "if (newAccepted && !userPersisted) {" in source
|
||||||
|
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
|
||||||
|
assert "computeFallbackTurnCancellingRetryDelay" in source
|
||||||
|
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
|
||||||
|
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
|
||||||
|
assert "await fetchWithTurnCancellingRetry(() =>" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_active_turn_route_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
|
||||||
|
assert "response_model=CancelActiveTurnResponse" in source
|
||||||
|
assert 'status="cancelling",' in source
|
||||||
|
assert 'error_code="TURN_CANCELLING",' in source
|
||||||
|
assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
|
||||||
|
assert "retry_after_at=" in source
|
||||||
|
assert 'status="idle",' in source
|
||||||
|
assert 'error_code="NO_ACTIVE_TURN",' in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_status_route_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
|
||||||
|
assert "response_model=TurnStatusResponse" in source
|
||||||
|
assert "_build_turn_status_payload(thread_id)" in source
|
||||||
|
assert "Permission.CHATS_READ.value" in source
|
||||||
|
assert "_raise_if_thread_busy_for_start(" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_cancelling_retry_policy_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
|
||||||
|
assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
|
||||||
|
assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
|
||||||
|
assert "def _compute_turn_cancelling_retry_delay(" in source
|
||||||
|
assert "retry-after-ms" in source
|
||||||
|
assert '"Retry-After"' in source
|
||||||
|
assert '"errorCode": "TURN_CANCELLING"' in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_status_sse_contract_exists():
|
||||||
|
stream_source = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/tasks/chat/stream_new_chat.py"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
state_source = (
|
||||||
|
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
pipeline_source = (
|
||||||
|
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '"turn-status"' in stream_source
|
||||||
|
assert '"status": "busy"' in stream_source
|
||||||
|
assert '"status": "idle"' in stream_source
|
||||||
|
assert "type: \"data-turn-status\"" in state_source
|
||||||
|
assert "case \"data-turn-status\":" in pipeline_source
|
||||||
|
assert "end_turn(str(chat_id))" in stream_source
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat
|
||||||
|
|
||||||
export const resetCurrentThreadAtom = atom(null, (_, set) => {
|
export const resetCurrentThreadAtom = atom(null, (_, set) => {
|
||||||
set(currentThreadAtom, initialState);
|
set(currentThreadAtom, initialState);
|
||||||
set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null });
|
set(reportPanelAtom, {
|
||||||
|
isOpen: false,
|
||||||
|
reportId: null,
|
||||||
|
title: null,
|
||||||
|
wordCount: null,
|
||||||
|
shareToken: null,
|
||||||
|
contentType: "markdown",
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
/** Target comment ID to scroll to (from URL navigation or inbox click) */
|
/** Target comment ID to scroll to (from URL navigation or inbox click) */
|
||||||
|
|
|
||||||
|
|
@ -548,8 +548,10 @@ const AssistantMessageInner: FC = () => {
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 flex items-center gap-2">
|
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 h-6">
|
||||||
<AssistantActionBar />
|
<div className="h-full opacity-100 transition-opacity">
|
||||||
|
<AssistantActionBar />
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</CitationMetadataProvider>
|
</CitationMetadataProvider>
|
||||||
);
|
);
|
||||||
|
|
@ -642,35 +644,41 @@ export const AssistantMessage: FC = () => {
|
||||||
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
||||||
data-role="assistant"
|
data-role="assistant"
|
||||||
>
|
>
|
||||||
{/* Comment trigger — right-aligned, just below user query on all screen sizes */}
|
{/* Fixed trigger slot prevents any vertical reflow when visibility changes */}
|
||||||
{showCommentTrigger && (
|
<div className="mr-2 mb-1 flex h-7 justify-end">
|
||||||
<div className="mr-2 mb-1 flex justify-end">
|
<button
|
||||||
<button
|
ref={isDesktop ? commentTriggerRef : undefined}
|
||||||
ref={isDesktop ? commentTriggerRef : undefined}
|
type="button"
|
||||||
type="button"
|
onClick={
|
||||||
onClick={
|
showCommentTrigger
|
||||||
isDesktop ? () => setIsInlineOpen((prev) => !prev) : () => setIsSheetOpen(true)
|
? isDesktop
|
||||||
}
|
? () => setIsInlineOpen((prev) => !prev)
|
||||||
className={cn(
|
: () => setIsSheetOpen(true)
|
||||||
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
|
: undefined
|
||||||
isDesktop && isInlineOpen
|
}
|
||||||
? "bg-primary/10 text-primary"
|
aria-hidden={!showCommentTrigger}
|
||||||
: hasComments
|
tabIndex={showCommentTrigger ? 0 : -1}
|
||||||
? "text-primary hover:bg-primary/10"
|
className={cn(
|
||||||
: "text-muted-foreground hover:text-foreground hover:bg-muted"
|
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
|
||||||
)}
|
"opacity-0 pointer-events-none",
|
||||||
>
|
showCommentTrigger && "opacity-100 pointer-events-auto",
|
||||||
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
|
isDesktop && isInlineOpen
|
||||||
{hasComments ? (
|
? "bg-primary/10 text-primary"
|
||||||
<span>
|
: hasComments
|
||||||
{commentCount} {commentCount === 1 ? "comment" : "comments"}
|
? "text-primary hover:bg-primary/10"
|
||||||
</span>
|
: "text-muted-foreground hover:text-foreground hover:bg-muted"
|
||||||
) : (
|
)}
|
||||||
<span>Add comment</span>
|
>
|
||||||
)}
|
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
|
||||||
</button>
|
{hasComments ? (
|
||||||
</div>
|
<span>
|
||||||
)}
|
{commentCount} {commentCount === 1 ? "comment" : "comments"}
|
||||||
|
</span>
|
||||||
|
) : (
|
||||||
|
<span>Add comment</span>
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
{/* Desktop floating comment panel — overlays on top of chat content */}
|
{/* Desktop floating comment panel — overlays on top of chat content */}
|
||||||
{showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && (
|
{showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && (
|
||||||
|
|
|
||||||
52
surfsense_web/components/assistant-ui/chat-viewport.tsx
Normal file
52
surfsense_web/components/assistant-ui/chat-viewport.tsx
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { ThreadPrimitive } from "@assistant-ui/react";
|
||||||
|
import { ArrowDownIcon } from "lucide-react";
|
||||||
|
import type { FC, ReactNode } from "react";
|
||||||
|
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||||
|
|
||||||
|
const ChatScrollToBottom: FC = () => (
|
||||||
|
<ThreadPrimitive.ScrollToBottom asChild>
|
||||||
|
<TooltipIconButton
|
||||||
|
tooltip="Scroll to bottom"
|
||||||
|
variant="outline"
|
||||||
|
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||||
|
>
|
||||||
|
<ArrowDownIcon />
|
||||||
|
</TooltipIconButton>
|
||||||
|
</ThreadPrimitive.ScrollToBottom>
|
||||||
|
);
|
||||||
|
|
||||||
|
export interface ChatViewportProps {
|
||||||
|
children: ReactNode;
|
||||||
|
footer?: ReactNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
|
||||||
|
<ThreadPrimitive.Viewport
|
||||||
|
turnAnchor="top"
|
||||||
|
autoScroll
|
||||||
|
scrollToBottomOnRunStart
|
||||||
|
scrollToBottomOnInitialize
|
||||||
|
scrollToBottomOnThreadSwitch
|
||||||
|
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
|
||||||
|
style={{ scrollbarGutter: "stable" }}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
aria-hidden
|
||||||
|
className="aui-chat-viewport-top-fade pointer-events-none sticky top-0 z-10 -mx-4 h-2 shrink-0 bg-gradient-to-b from-main-panel from-20% to-transparent"
|
||||||
|
/>
|
||||||
|
{children}
|
||||||
|
{footer ? (
|
||||||
|
<ThreadPrimitive.ViewportFooter
|
||||||
|
className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 mt-auto flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6"
|
||||||
|
style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }}
|
||||||
|
>
|
||||||
|
<div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible">
|
||||||
|
<ChatScrollToBottom />
|
||||||
|
{footer}
|
||||||
|
</div>
|
||||||
|
</ThreadPrimitive.ViewportFooter>
|
||||||
|
) : null}
|
||||||
|
</ThreadPrimitive.Viewport>
|
||||||
|
);
|
||||||
File diff suppressed because it is too large
Load diff
24
surfsense_web/components/assistant-ui/nested-scroll.tsx
Normal file
24
surfsense_web/components/assistant-ui/nested-scroll.tsx
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react";
|
||||||
|
|
||||||
|
export type NestedScrollProps = ComponentPropsWithoutRef<"div">;
|
||||||
|
|
||||||
|
export const NestedScroll = forwardRef<HTMLDivElement, NestedScrollProps>(
|
||||||
|
({ onWheel, ...props }, ref) => {
|
||||||
|
const handleWheel = (event: WheelEvent<HTMLDivElement>) => {
|
||||||
|
const el = event.currentTarget;
|
||||||
|
const canScrollUp = el.scrollTop > 0;
|
||||||
|
const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1;
|
||||||
|
const goingUp = event.deltaY < 0;
|
||||||
|
const goingDown = event.deltaY > 0;
|
||||||
|
if ((goingUp && canScrollUp) || (goingDown && canScrollDown)) {
|
||||||
|
event.stopPropagation();
|
||||||
|
}
|
||||||
|
onWheel?.(event);
|
||||||
|
};
|
||||||
|
return <div ref={ref} onWheel={handleWheel} {...props} />;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
NestedScroll.displayName = "NestedScroll";
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
import { ThreadPrimitive } from "@assistant-ui/react";
|
|
||||||
import { ArrowDownIcon } from "lucide-react";
|
|
||||||
import type { FC } from "react";
|
|
||||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
|
||||||
|
|
||||||
export const ThreadScrollToBottom: FC = () => {
|
|
||||||
return (
|
|
||||||
<ThreadPrimitive.ScrollToBottom asChild>
|
|
||||||
<TooltipIconButton
|
|
||||||
tooltip="Scroll to bottom"
|
|
||||||
variant="outline"
|
|
||||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
|
||||||
>
|
|
||||||
<ArrowDownIcon />
|
|
||||||
</TooltipIconButton>
|
|
||||||
</ThreadPrimitive.ScrollToBottom>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
@ -5,12 +5,10 @@ import {
|
||||||
ThreadPrimitive,
|
ThreadPrimitive,
|
||||||
useAui,
|
useAui,
|
||||||
useAuiState,
|
useAuiState,
|
||||||
useThreadViewportStore,
|
|
||||||
} from "@assistant-ui/react";
|
} from "@assistant-ui/react";
|
||||||
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
||||||
import {
|
import {
|
||||||
AlertCircle,
|
AlertCircle,
|
||||||
ArrowDownIcon,
|
|
||||||
ArrowUpIcon,
|
ArrowUpIcon,
|
||||||
Camera,
|
Camera,
|
||||||
ChevronDown,
|
ChevronDown,
|
||||||
|
|
@ -55,6 +53,7 @@ import {
|
||||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||||
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
|
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
|
||||||
|
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
|
||||||
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
|
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
|
||||||
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
|
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
|
||||||
import {
|
import {
|
||||||
|
|
@ -112,10 +111,13 @@ const ThreadContent: FC = () => {
|
||||||
["--thread-max-width" as string]: "44rem",
|
["--thread-max-width" as string]: "44rem",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ThreadPrimitive.Viewport
|
<ChatViewport
|
||||||
turnAnchor="top"
|
footer={
|
||||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||||
style={{ scrollbarGutter: "stable" }}
|
<PremiumQuotaPinnedAlert />
|
||||||
|
<Composer />
|
||||||
|
</AuiIf>
|
||||||
|
}
|
||||||
>
|
>
|
||||||
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
||||||
<ThreadWelcome />
|
<ThreadWelcome />
|
||||||
|
|
@ -128,24 +130,7 @@ const ThreadContent: FC = () => {
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
</ChatViewport>
|
||||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
|
||||||
<div className="grow" />
|
|
||||||
</AuiIf>
|
|
||||||
|
|
||||||
<ThreadPrimitive.ViewportFooter
|
|
||||||
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
|
|
||||||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
|
||||||
>
|
|
||||||
<ThreadScrollToBottom />
|
|
||||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
|
||||||
<PremiumQuotaPinnedAlert />
|
|
||||||
</AuiIf>
|
|
||||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
|
||||||
<Composer />
|
|
||||||
</AuiIf>
|
|
||||||
</ThreadPrimitive.ViewportFooter>
|
|
||||||
</ThreadPrimitive.Viewport>
|
|
||||||
</ThreadPrimitive.Root>
|
</ThreadPrimitive.Root>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
@ -181,20 +166,6 @@ const PremiumQuotaPinnedAlert: FC = () => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const ThreadScrollToBottom: FC = () => {
|
|
||||||
return (
|
|
||||||
<ThreadPrimitive.ScrollToBottom asChild>
|
|
||||||
<TooltipIconButton
|
|
||||||
tooltip="Scroll to bottom"
|
|
||||||
variant="outline"
|
|
||||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
|
||||||
>
|
|
||||||
<ArrowDownIcon />
|
|
||||||
</TooltipIconButton>
|
|
||||||
</ThreadPrimitive.ScrollToBottom>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => {
|
const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => {
|
||||||
const hour = new Date().getHours();
|
const hour = new Date().getHours();
|
||||||
|
|
||||||
|
|
@ -411,23 +382,9 @@ const Composer: FC = () => {
|
||||||
>(new Map());
|
>(new Map());
|
||||||
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
||||||
const promptPickerRef = useRef<PromptPickerRef>(null);
|
const promptPickerRef = useRef<PromptPickerRef>(null);
|
||||||
const viewportRef = useRef<Element | null>(null);
|
|
||||||
const { search_space_id, chat_id } = useParams();
|
const { search_space_id, chat_id } = useParams();
|
||||||
const aui = useAui();
|
const aui = useAui();
|
||||||
const threadViewportStore = useThreadViewportStore();
|
|
||||||
const hasAutoFocusedRef = useRef(false);
|
const hasAutoFocusedRef = useRef(false);
|
||||||
const submitCleanupRef = useRef<(() => void) | null>(null);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return () => {
|
|
||||||
submitCleanupRef.current?.();
|
|
||||||
};
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
// Store viewport element reference on mount
|
|
||||||
useEffect(() => {
|
|
||||||
viewportRef.current = document.querySelector(".aui-thread-viewport");
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const electronAPI = useElectronAPI();
|
const electronAPI = useElectronAPI();
|
||||||
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
|
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
|
||||||
|
|
@ -626,7 +583,6 @@ const Composer: FC = () => {
|
||||||
[showDocumentPopover, showPromptPicker]
|
[showDocumentPopover, showPromptPicker]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Submit message (blocked during streaming, document picker open, or AI responding to another user)
|
|
||||||
const handleSubmit = useCallback(() => {
|
const handleSubmit = useCallback(() => {
|
||||||
if (isThreadRunning || isBlockedByOtherUser) return;
|
if (isThreadRunning || isBlockedByOtherUser) return;
|
||||||
if (showDocumentPopover || showPromptPicker) return;
|
if (showDocumentPopover || showPromptPicker) return;
|
||||||
|
|
@ -638,50 +594,9 @@ const Composer: FC = () => {
|
||||||
setClipboardInitialText(undefined);
|
setClipboardInitialText(undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
const viewportEl = viewportRef.current;
|
|
||||||
const heightBefore = viewportEl?.scrollHeight ?? 0;
|
|
||||||
|
|
||||||
aui.composer().send();
|
aui.composer().send();
|
||||||
editorRef.current?.clear();
|
editorRef.current?.clear();
|
||||||
setMentionedDocuments([]);
|
setMentionedDocuments([]);
|
||||||
|
|
||||||
// With turnAnchor="top", ViewportSlack adds min-height to the last
|
|
||||||
// assistant message so that scrolling-to-bottom actually positions the
|
|
||||||
// user message at the TOP of the viewport. That slack height is
|
|
||||||
// calculated asynchronously (ResizeObserver → style → layout).
|
|
||||||
// Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes.
|
|
||||||
const scrollToBottom = () =>
|
|
||||||
threadViewportStore.getState().scrollToBottom({ behavior: "instant" });
|
|
||||||
|
|
||||||
let lastHeight = heightBefore;
|
|
||||||
let frames = 0;
|
|
||||||
let cancelled = false;
|
|
||||||
const POLL_FRAMES = 30;
|
|
||||||
|
|
||||||
const pollAndScroll = () => {
|
|
||||||
if (cancelled) return;
|
|
||||||
const el = viewportRef.current;
|
|
||||||
if (el) {
|
|
||||||
const h = el.scrollHeight;
|
|
||||||
if (h !== lastHeight) {
|
|
||||||
lastHeight = h;
|
|
||||||
scrollToBottom();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (++frames < POLL_FRAMES) {
|
|
||||||
requestAnimationFrame(pollAndScroll);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
requestAnimationFrame(pollAndScroll);
|
|
||||||
|
|
||||||
const t1 = setTimeout(scrollToBottom, 100);
|
|
||||||
const t2 = setTimeout(scrollToBottom, 300);
|
|
||||||
|
|
||||||
submitCleanupRef.current = () => {
|
|
||||||
cancelled = true;
|
|
||||||
clearTimeout(t1);
|
|
||||||
clearTimeout(t2);
|
|
||||||
};
|
|
||||||
}, [
|
}, [
|
||||||
showDocumentPopover,
|
showDocumentPopover,
|
||||||
showPromptPicker,
|
showPromptPicker,
|
||||||
|
|
@ -690,7 +605,6 @@ const Composer: FC = () => {
|
||||||
clipboardInitialText,
|
clipboardInitialText,
|
||||||
aui,
|
aui,
|
||||||
setMentionedDocuments,
|
setMentionedDocuments,
|
||||||
threadViewportStore,
|
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const handleDocumentRemove = useCallback(
|
const handleDocumentRemove = useCallback(
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ import {
|
||||||
isDoomLoopInterrupt,
|
isDoomLoopInterrupt,
|
||||||
} from "@/components/tool-ui/doom-loop-approval";
|
} from "@/components/tool-ui/doom-loop-approval";
|
||||||
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
||||||
|
import { NestedScroll } from "@/components/assistant-ui/nested-scroll";
|
||||||
import {
|
import {
|
||||||
AlertDialog,
|
AlertDialog,
|
||||||
AlertDialogAction,
|
AlertDialogAction,
|
||||||
|
|
@ -475,7 +476,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
||||||
{(argsText || isRunning) && (
|
{(argsText || isRunning) && (
|
||||||
<div className="flex flex-col gap-1 min-w-0">
|
<div className="flex flex-col gap-1 min-w-0">
|
||||||
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
|
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
|
||||||
<div className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
<NestedScroll className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
||||||
{argsText ? (
|
{argsText ? (
|
||||||
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||||
{argsText}
|
{argsText}
|
||||||
|
|
@ -489,7 +490,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
||||||
Waiting for input…
|
Waiting for input…
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
</div>
|
</NestedScroll>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{!isCancelled && result !== undefined && (
|
{!isCancelled && result !== undefined && (
|
||||||
|
|
@ -497,11 +498,11 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
||||||
<Separator />
|
<Separator />
|
||||||
<div className="flex flex-col gap-1 min-w-0">
|
<div className="flex flex-col gap-1 min-w-0">
|
||||||
<p className="text-xs font-medium text-muted-foreground">Result</p>
|
<p className="text-xs font-medium text-muted-foreground">Result</p>
|
||||||
<div className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
<NestedScroll className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
||||||
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||||
{typeof result === "string" ? result : serializedResult}
|
{typeof result === "string" ? result : serializedResult}
|
||||||
</pre>
|
</pre>
|
||||||
</div>
|
</NestedScroll>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,10 @@
|
||||||
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
|
import {
|
||||||
|
ActionBarPrimitive,
|
||||||
|
AuiIf,
|
||||||
|
MessagePrimitive,
|
||||||
|
useAuiState,
|
||||||
|
useMessagePartText,
|
||||||
|
} from "@assistant-ui/react";
|
||||||
import { useAtomValue } from "jotai";
|
import { useAtomValue } from "jotai";
|
||||||
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
|
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
|
||||||
import Image from "next/image";
|
import Image from "next/image";
|
||||||
|
|
@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
||||||
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
|
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||||
|
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||||
|
import { parseMentionSegments } from "@/lib/chat/parse-mention-segments";
|
||||||
|
|
||||||
interface AuthorMetadata {
|
interface AuthorMetadata {
|
||||||
displayName: string | null;
|
displayName: string | null;
|
||||||
|
|
@ -47,23 +55,40 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export const UserMessage: FC = () => {
|
const UserTextPart: FC = () => {
|
||||||
const messageId = useAuiState(({ message }) => message?.id);
|
const messageId = useAuiState(({ message }) => message?.id);
|
||||||
const messageText = useAuiState(({ message }) =>
|
const part = useMessagePartText();
|
||||||
(message?.content ?? [])
|
const text = (part as { text?: string }).text ?? "";
|
||||||
.map((part) =>
|
|
||||||
typeof part === "object" &&
|
|
||||||
part !== null &&
|
|
||||||
"type" in part &&
|
|
||||||
(part as { type?: string }).type === "text" &&
|
|
||||||
"text" in part
|
|
||||||
? String((part as { text?: string }).text ?? "")
|
|
||||||
: ""
|
|
||||||
)
|
|
||||||
.join("")
|
|
||||||
);
|
|
||||||
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
||||||
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
|
const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? [];
|
||||||
|
|
||||||
|
const segments = parseMentionSegments(text, mentionedDocs);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<p style={{ whiteSpace: "pre-line" }} className="break-words">
|
||||||
|
{segments.map((segment) =>
|
||||||
|
segment.type === "text" ? (
|
||||||
|
<span key={`txt-${segment.start}`}>{segment.value}</span>
|
||||||
|
) : (
|
||||||
|
<span
|
||||||
|
key={`mention-${getMentionDocKey(segment.doc)}-${segment.start}`}
|
||||||
|
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-middle leading-none"
|
||||||
|
title={segment.doc.title}
|
||||||
|
>
|
||||||
|
<span className="flex items-center text-muted-foreground">
|
||||||
|
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
|
||||||
|
</span>
|
||||||
|
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
)}
|
||||||
|
</p>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const userMessageParts = { Text: UserTextPart };
|
||||||
|
|
||||||
|
export const UserMessage: FC = () => {
|
||||||
const metadata = useAuiState(({ message }) => message?.metadata);
|
const metadata = useAuiState(({ message }) => message?.metadata);
|
||||||
const author = metadata?.custom?.author as AuthorMetadata | undefined;
|
const author = metadata?.custom?.author as AuthorMetadata | undefined;
|
||||||
const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE";
|
const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE";
|
||||||
|
|
@ -78,11 +103,7 @@ export const UserMessage: FC = () => {
|
||||||
<div className="aui-user-message-content-wrapper flex items-end gap-2">
|
<div className="aui-user-message-content-wrapper flex items-end gap-2">
|
||||||
<div className="relative flex-1 min-w-0">
|
<div className="relative flex-1 min-w-0">
|
||||||
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
||||||
{mentionedDocs && mentionedDocs.length > 0 ? (
|
<MessagePrimitive.Parts components={userMessageParts} />
|
||||||
<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
|
|
||||||
) : (
|
|
||||||
<MessagePrimitive.Parts />
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
|
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
|
||||||
<UserActionBar />
|
<UserActionBar />
|
||||||
|
|
@ -99,64 +120,6 @@ export const UserMessage: FC = () => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const UserMessageWithMentionChips: FC<{
|
|
||||||
text: string;
|
|
||||||
mentionedDocs: { id: number; title: string; document_type: string }[];
|
|
||||||
}> = ({ text, mentionedDocs }) => {
|
|
||||||
type Segment =
|
|
||||||
| { type: "text"; value: string; start: number }
|
|
||||||
| { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number };
|
|
||||||
|
|
||||||
const tokens = mentionedDocs
|
|
||||||
.map((doc) => ({ doc, token: `@${doc.title}` }))
|
|
||||||
.sort((a, b) => b.token.length - a.token.length);
|
|
||||||
|
|
||||||
const segments: Segment[] = [];
|
|
||||||
let i = 0;
|
|
||||||
let buffer = "";
|
|
||||||
let bufferStart = 0;
|
|
||||||
while (i < text.length) {
|
|
||||||
const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i));
|
|
||||||
if (tokenMatch) {
|
|
||||||
if (buffer) {
|
|
||||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
|
||||||
buffer = "";
|
|
||||||
}
|
|
||||||
segments.push({ type: "mention", doc: tokenMatch.doc, start: i });
|
|
||||||
i += tokenMatch.token.length;
|
|
||||||
bufferStart = i;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!buffer) bufferStart = i;
|
|
||||||
buffer += text[i];
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
if (buffer) {
|
|
||||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<span className="whitespace-pre-wrap break-words">
|
|
||||||
{segments.map((segment) =>
|
|
||||||
segment.type === "text" ? (
|
|
||||||
<span key={`txt-${segment.start}`}>{segment.value}</span>
|
|
||||||
) : (
|
|
||||||
<span
|
|
||||||
key={`mention-${segment.doc.document_type}:${segment.doc.id}-${segment.start}`}
|
|
||||||
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline"
|
|
||||||
title={segment.doc.title}
|
|
||||||
>
|
|
||||||
<span className="flex items-center text-muted-foreground">
|
|
||||||
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
|
|
||||||
</span>
|
|
||||||
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
|
|
||||||
</span>
|
|
||||||
)
|
|
||||||
)}
|
|
||||||
</span>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const UserActionBar: FC = () => {
|
const UserActionBar: FC = () => {
|
||||||
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
|
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { AuiIf, ThreadPrimitive } from "@assistant-ui/react";
|
import { AuiIf, ThreadPrimitive } from "@assistant-ui/react";
|
||||||
import { ArrowDownIcon } from "lucide-react";
|
|
||||||
import type { FC } from "react";
|
import type { FC } from "react";
|
||||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||||
|
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
|
||||||
import { EditComposer } from "@/components/assistant-ui/edit-composer";
|
import { EditComposer } from "@/components/assistant-ui/edit-composer";
|
||||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
|
||||||
import { UserMessage } from "@/components/assistant-ui/user-message";
|
import { UserMessage } from "@/components/assistant-ui/user-message";
|
||||||
import { FreeComposer } from "./free-composer";
|
import { FreeComposer } from "./free-composer";
|
||||||
|
|
||||||
|
|
@ -24,20 +23,6 @@ const FreeThreadWelcome: FC = () => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const ThreadScrollToBottom: FC = () => {
|
|
||||||
return (
|
|
||||||
<ThreadPrimitive.ScrollToBottom asChild>
|
|
||||||
<TooltipIconButton
|
|
||||||
tooltip="Scroll to bottom"
|
|
||||||
variant="outline"
|
|
||||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
|
||||||
>
|
|
||||||
<ArrowDownIcon />
|
|
||||||
</TooltipIconButton>
|
|
||||||
</ThreadPrimitive.ScrollToBottom>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const FreeThread: FC = () => {
|
export const FreeThread: FC = () => {
|
||||||
return (
|
return (
|
||||||
<ThreadPrimitive.Root
|
<ThreadPrimitive.Root
|
||||||
|
|
@ -46,10 +31,12 @@ export const FreeThread: FC = () => {
|
||||||
["--thread-max-width" as string]: "44rem",
|
["--thread-max-width" as string]: "44rem",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ThreadPrimitive.Viewport
|
<ChatViewport
|
||||||
turnAnchor="top"
|
footer={
|
||||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||||
style={{ scrollbarGutter: "stable" }}
|
<FreeComposer />
|
||||||
|
</AuiIf>
|
||||||
|
}
|
||||||
>
|
>
|
||||||
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
||||||
<FreeThreadWelcome />
|
<FreeThreadWelcome />
|
||||||
|
|
@ -62,21 +49,7 @@ export const FreeThread: FC = () => {
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
</ChatViewport>
|
||||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
|
||||||
<div className="grow" />
|
|
||||||
</AuiIf>
|
|
||||||
|
|
||||||
<ThreadPrimitive.ViewportFooter
|
|
||||||
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
|
|
||||||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
|
||||||
>
|
|
||||||
<ThreadScrollToBottom />
|
|
||||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
|
||||||
<FreeComposer />
|
|
||||||
</AuiIf>
|
|
||||||
</ThreadPrimitive.ViewportFooter>
|
|
||||||
</ThreadPrimitive.Viewport>
|
|
||||||
</ThreadPrimitive.Root>
|
</ThreadPrimitive.Root>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -236,6 +236,93 @@ interface DisplayItem {
|
||||||
isAutoMode: boolean;
|
isAutoMode: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Renders `text` in a block-level span and, when `enableTooltip` is true,
 * shows the full text in a tooltip — but only if the span is actually
 * truncated (its `scrollWidth` exceeds its `clientWidth`).
 *
 * Props:
 * - `text`: the label to render (also used as the tooltip content).
 * - `className`: extra classes merged onto the span via `cn`.
 * - `enableTooltip`: when false, a plain span is rendered and no measuring
 *   or observers are set up.
 */
const TruncatedNameWithTooltip: React.FC<{
	text: string;
	className?: string;
	enableTooltip: boolean;
}> = ({ text, className, enableTooltip }) => {
	const textRef = useRef<HTMLSpanElement>(null);
	// Pending "open after delay" timer id, if any.
	const openTimerRef = useRef<number | undefined>(undefined);
	const [isTruncated, setIsTruncated] = useState(false);
	const [open, setOpen] = useState(false);

	const recalcTruncation = useCallback(() => {
		const el = textRef.current;
		if (!el) return;
		// +1 tolerates sub-pixel rounding differences between the two widths.
		setIsTruncated(el.scrollWidth > el.clientWidth + 1);
	}, []);

	useEffect(() => {
		if (!enableTooltip) return;
		const el = textRef.current;
		if (!el) return;

		// Measure now and once more on the next frame, after layout has settled.
		const raf = requestAnimationFrame(recalcTruncation);
		recalcTruncation();

		const observer = new ResizeObserver(recalcTruncation);
		observer.observe(el);
		if (el.parentElement) observer.observe(el.parentElement);
		window.addEventListener("resize", recalcTruncation);

		return () => {
			cancelAnimationFrame(raf);
			observer.disconnect();
			window.removeEventListener("resize", recalcTruncation);
		};
	}, [enableTooltip, recalcTruncation]);

	useEffect(() => {
		// Recompute when row text changes.
		void text;
		// Nothing consumes the measurement while the tooltip is disabled.
		if (!enableTooltip) return;
		const raf = requestAnimationFrame(recalcTruncation);
		// Cancel the pending frame so it cannot run after unmount or after a
		// newer text change has already scheduled its own measurement.
		return () => cancelAnimationFrame(raf);
	}, [text, enableTooltip, recalcTruncation]);

	useEffect(
		() => () => {
			// Unmount-only cleanup for a pending open timer.
			if (openTimerRef.current) window.clearTimeout(openTimerRef.current);
		},
		[]
	);

	if (!enableTooltip) {
		return (
			<span ref={textRef} className={cn("block max-w-full", className)}>
				{text}
			</span>
		);
	}

	const handleOpenChange = (nextOpen: boolean) => {
		// Any open/close request supersedes a previously scheduled open.
		if (openTimerRef.current) {
			window.clearTimeout(openTimerRef.current);
			openTimerRef.current = undefined;
		}
		if (!nextOpen) {
			setOpen(false);
			return;
		}
		// Fully visible text needs no tooltip.
		if (!isTruncated) return;
		// Open after a short delay (220 ms) rather than immediately.
		openTimerRef.current = window.setTimeout(() => {
			setOpen(true);
			openTimerRef.current = undefined;
		}, 220);
	};

	return (
		<Tooltip open={open} onOpenChange={handleOpenChange}>
			<TooltipTrigger asChild>
				<span ref={textRef} className={cn("block max-w-full", className)}>
					{text}
				</span>
			</TooltipTrigger>
			<TooltipContent side="top" align="start">
				{text}
			</TooltipContent>
		</Tooltip>
	);
};
|
||||||
|
|
||||||
// ─── Component ──────────────────────────────────────────────────────
|
// ─── Component ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
interface ModelSelectorProps {
|
interface ModelSelectorProps {
|
||||||
|
|
@ -936,7 +1023,11 @@ export function ModelSelector({
|
||||||
{/* Model info */}
|
{/* Model info */}
|
||||||
<div className="flex-1 min-w-0">
|
<div className="flex-1 min-w-0">
|
||||||
<div className="flex items-center gap-1.5">
|
<div className="flex items-center gap-1.5">
|
||||||
<span className="font-medium text-sm truncate">{config.name}</span>
|
<TruncatedNameWithTooltip
|
||||||
|
text={config.name}
|
||||||
|
enableTooltip={!isMobile}
|
||||||
|
className="font-medium text-sm truncate"
|
||||||
|
/>
|
||||||
{isAutoMode && (
|
{isAutoMode && (
|
||||||
<Badge
|
<Badge
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
|
|
|
||||||
|
|
@ -45,20 +45,21 @@ export const PublicThread: FC<PublicThreadProps> = ({ footer }) => {
|
||||||
["--thread-max-width" as string]: "44rem",
|
["--thread-max-width" as string]: "44rem",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ThreadPrimitive.Viewport className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4">
|
<ThreadPrimitive.Viewport
|
||||||
|
scrollToBottomOnInitialize
|
||||||
|
scrollToBottomOnThreadSwitch
|
||||||
|
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4 pb-6"
|
||||||
|
>
|
||||||
<ThreadPrimitive.Messages
|
<ThreadPrimitive.Messages
|
||||||
components={{
|
components={{
|
||||||
UserMessage: PublicUserMessage,
|
UserMessage: PublicUserMessage,
|
||||||
AssistantMessage: PublicAssistantMessage,
|
AssistantMessage: PublicAssistantMessage,
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{/* Spacer to ensure footer doesn't overlap last message */}
|
|
||||||
<div className="h-24" />
|
|
||||||
</ThreadPrimitive.Viewport>
|
</ThreadPrimitive.Viewport>
|
||||||
|
|
||||||
{footer && (
|
{footer && (
|
||||||
<div className="sticky bottom-0 z-20 border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60">
|
<div className="border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60">
|
||||||
{footer}
|
{footer}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
errorCode === "TURN_CANCELLING"
|
||||||
|
) {
|
||||||
|
return {
|
||||||
|
kind: "thread_busy",
|
||||||
|
channel: "toast",
|
||||||
|
severity: "info",
|
||||||
|
telemetryEvent: "chat_blocked",
|
||||||
|
isExpected: true,
|
||||||
|
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
||||||
|
rawMessage,
|
||||||
|
errorCode: errorCode ?? "TURN_CANCELLING",
|
||||||
|
details: { flow: input.flow },
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
errorCode === "THREAD_BUSY"
|
errorCode === "THREAD_BUSY"
|
||||||
) {
|
) {
|
||||||
|
|
@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
|
||||||
severity: "warn",
|
severity: "warn",
|
||||||
telemetryEvent: "chat_blocked",
|
telemetryEvent: "chat_blocked",
|
||||||
isExpected: true,
|
isExpected: true,
|
||||||
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
userMessage: "Another response is still finishing for this thread. Please try again in a moment.",
|
||||||
rawMessage,
|
rawMessage,
|
||||||
errorCode: errorCode ?? "THREAD_BUSY",
|
errorCode: errorCode ?? "THREAD_BUSY",
|
||||||
details: { flow: input.flow },
|
details: { flow: input.flow },
|
||||||
|
|
|
||||||
114
surfsense_web/lib/chat/chat-request-errors.ts
Normal file
114
surfsense_web/lib/chat/chat-request-errors.ts
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
/** Narrows an unknown JSON value to a record, or `null` when it is not a non-null object. */
function asJsonRecord(value: unknown): Record<string, unknown> | null {
	return typeof value === "object" && value !== null ? (value as Record<string, unknown>) : null;
}

/** Returns the first string-valued field among `keys`, probing in order. */
function firstStringField(
	source: Record<string, unknown> | null,
	keys: readonly string[]
): string | undefined {
	for (const key of keys) {
		const value = source?.[key];
		if (typeof value === "string") return value;
	}
	return undefined;
}

/** Returns the first finite numeric field among `keys`, clamped to be non-negative. */
function firstRetryDelayField(
	source: Record<string, unknown> | null,
	keys: readonly string[]
): number | undefined {
	for (const key of keys) {
		const value = source?.[key];
		if (typeof value === "number" && Number.isFinite(value)) return Math.max(0, value);
	}
	return undefined;
}

/**
 * Derives a retry delay in milliseconds from response headers.
 *
 * `retry-after-ms` (milliseconds) wins over `retry-after`. `retry-after` is
 * read as delta-seconds first; if that is not numeric it is parsed as an
 * HTTP-date and converted to a delay relative to now.
 */
function retryAfterMsFromHeaders(headers: Headers): number | undefined {
	const rawMs = headers.get("retry-after-ms");
	const ms = rawMs ? Number.parseFloat(rawMs) : Number.NaN;
	if (Number.isFinite(ms)) return Math.max(0, Math.round(ms));

	const rawRetryAfter = headers.get("retry-after");
	if (!rawRetryAfter) return undefined;

	const seconds = Number.parseFloat(rawRetryAfter);
	if (Number.isFinite(seconds)) return Math.max(0, Math.round(seconds * 1000));

	// Not numeric: Retry-After may also be an HTTP-date.
	const retryAtMs = Date.parse(rawRetryAfter);
	return Number.isFinite(retryAtMs) ? Math.max(0, retryAtMs - Date.now()) : undefined;
}

/**
 * Builds an `Error` from an HTTP response, decorated with `errorCode` and
 * `retryAfterMs`.
 *
 * - `errorCode`: `detail.errorCode` / `detail.error_code`, then the same keys
 *   at the top level of the JSON body, then a status-derived default
 *   (409 → THREAD_BUSY, 429 → RATE_LIMITED, 401/403 → AUTH_EXPIRED,
 *   anything else → SERVER_ERROR).
 * - `retryAfterMs`: `retry_after_ms` / `retryAfterMs` from `detail`, then from
 *   the top level, then from the `retry-after-ms` / `retry-after` headers;
 *   `undefined` when none is present.
 * - `message`: `detail.message`, then a string `detail`, then the top-level
 *   `message`, then `Backend error: <status>`.
 *
 * Never throws: an unreadable or non-JSON body falls back to the defaults.
 *
 * @param response The HTTP response to convert; its body is consumed.
 */
export async function toHttpResponseError(
	response: Response
): Promise<Error & { errorCode?: string; retryAfterMs?: number }> {
	const statusDefaultCode =
		response.status === 409
			? "THREAD_BUSY"
			: response.status === 429
				? "RATE_LIMITED"
				: response.status === 401 || response.status === 403
					? "AUTH_EXPIRED"
					: "SERVER_ERROR";

	let rawBody = "";
	try {
		rawBody = await response.text();
	} catch {
		// Body unreadable: continue with status/header-derived values only.
	}

	let parsedBody: Record<string, unknown> | null = null;
	if (rawBody) {
		try {
			parsedBody = asJsonRecord(JSON.parse(rawBody));
		} catch {
			// Non-JSON body: continue with status/header-derived values only.
		}
	}

	const detail = parsedBody?.detail;
	const detailObject = asJsonRecord(detail);

	const codeKeys = ["errorCode", "error_code"] as const;
	const retryKeys = ["retry_after_ms", "retryAfterMs"] as const;

	const errorCode =
		firstStringField(detailObject, codeKeys) ??
		firstStringField(parsedBody, codeKeys) ??
		statusDefaultCode;

	const retryAfterMs =
		firstRetryDelayField(detailObject, retryKeys) ??
		firstRetryDelayField(parsedBody, retryKeys) ??
		retryAfterMsFromHeaders(response.headers);

	const message =
		firstStringField(detailObject, ["message"]) ??
		(typeof detail === "string" ? detail : undefined) ??
		firstStringField(parsedBody, ["message"]) ??
		`Backend error: ${response.status}`;

	return Object.assign(new Error(message), { errorCode, retryAfterMs });
}
|
||||||
|
|
||||||
|
// Error codes that are kept as-is instead of being replaced by
// SEND_FAILED_PRE_ACCEPT.
const PASSTHROUGH_ERROR_CODES: ReadonlySet<string> = new Set([
	"PREMIUM_QUOTA_EXHAUSTED",
	"THREAD_BUSY",
	"TURN_CANCELLING",
	"AUTH_EXPIRED",
	"UNAUTHORIZED",
	"RATE_LIMITED",
	"NETWORK_ERROR",
	"STREAM_PARSE_ERROR",
	"TOOL_EXECUTION_ERROR",
	"PERSIST_MESSAGE_FAILED",
	"SERVER_ERROR",
]);

/**
 * Ensures a send failure carries an `errorCode`.
 *
 * - For an `Error` whose `errorCode` (or, failing that, `code`) is one of the
 *   passthrough codes, that code is written to `errorCode` on the same object.
 * - Any other `Error` is tagged, in place, with `SEND_FAILED_PRE_ACCEPT`.
 * - A non-`Error` value is replaced by a new `Error` tagged
 *   `SEND_FAILED_PRE_ACCEPT`.
 *
 * @param error The thrown value to tag.
 * @returns The (mutated) input error, or a new tagged error for non-Error input.
 */
export function tagPreAcceptSendFailure(error: unknown): unknown {
	if (!(error instanceof Error)) {
		return Object.assign(new Error("Failed to send message before stream acceptance"), {
			errorCode: "SEND_FAILED_PRE_ACCEPT",
		});
	}

	const coded = error as Error & { errorCode?: string; code?: string };
	const knownCode = coded.errorCode ?? coded.code;
	const errorCode =
		knownCode && PASSTHROUGH_ERROR_CODES.has(knownCode) ? knownCode : "SEND_FAILED_PRE_ACCEPT";

	return Object.assign(error, { errorCode });
}
|
||||||
54
surfsense_web/lib/chat/parse-mention-segments.ts
Normal file
54
surfsense_web/lib/chat/parse-mention-segments.ts
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
import type { MentionedDocumentInfo } from "@/atoms/chat/mentioned-documents.atom";
|
||||||
|
|
||||||
|
export type MentionSegment =
|
||||||
|
| { type: "text"; value: string; start: number }
|
||||||
|
| { type: "mention"; doc: MentionedDocumentInfo; start: number };
|
||||||
|
|
||||||
|
/**
 * Splits a user message into plain-text and `@mention` segments.
 *
 * Pure function: no React, no DOM, no side effects.
 *
 * Each document contributes the token `@<title>`. At every position the
 * longest matching token wins, so a longer title (e.g. `@Project Roadmap`)
 * is never shadowed by a shorter prefix (e.g. `@Project`). Every segment
 * records the index in `text` where it starts.
 *
 * @param text The raw message text.
 * @param docs Documents that may be mentioned in `text`.
 * @returns Segments in source order; empty for empty `text`, a single text
 *   segment when `docs` is empty.
 */
export function parseMentionSegments(
	text: string,
	docs: ReadonlyArray<MentionedDocumentInfo>
): MentionSegment[] {
	if (text.length === 0) return [];
	if (docs.length === 0) return [{ type: "text", value: text, start: 0 }];

	// Longest token first so `find` prefers the most specific title.
	const candidates = docs
		.map((doc) => ({ doc, token: `@${doc.title}` }))
		.sort((left, right) => right.token.length - left.token.length);

	const result: MentionSegment[] = [];
	let cursor = 0;
	// Start of the plain text that has not been emitted yet.
	let pendingStart = 0;

	const emitPendingText = (end: number) => {
		if (end > pendingStart) {
			result.push({ type: "text", value: text.slice(pendingStart, end), start: pendingStart });
		}
	};

	while (cursor < text.length) {
		const hit = candidates.find(({ token }) => text.startsWith(token, cursor));
		if (!hit) {
			cursor += 1;
			continue;
		}
		emitPendingText(cursor);
		result.push({ type: "mention", doc: hit.doc, start: cursor });
		cursor += hit.token.length;
		pendingStart = cursor;
	}

	emitPendingText(text.length);
	return result;
}
|
||||||
19
surfsense_web/lib/chat/stream-flush.ts
Normal file
19
surfsense_web/lib/chat/stream-flush.ts
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
import { FrameBatchedUpdater } from "@/lib/chat/streaming-state";
|
||||||
|
|
||||||
|
/**
 * Creates a `FrameBatchedUpdater` together with two helpers bound to
 * `flushMessages`.
 *
 * @param flushMessages Callback handed to the batcher on every schedule.
 * @returns The batcher itself, `scheduleFlush` (schedules `flushMessages` on
 *   the batcher) and `forceFlush` (schedules, then flushes immediately).
 */
export function createStreamFlushHelpers(flushMessages: () => void): {
	batcher: FrameBatchedUpdater;
	scheduleFlush: () => void;
	forceFlush: () => void;
} {
	const batcher = new FrameBatchedUpdater();

	function scheduleFlush() {
		return batcher.schedule(flushMessages);
	}

	// ``batcher.flush()`` alone is a no-op when ``dirty=false`` (e.g. a tool
	// starts before any text streamed). Scheduling first sets the dirty bit
	// so terminal events render promptly without the throttle delay.
	function forceFlush() {
		scheduleFlush();
		batcher.flush();
	}

	return { batcher, scheduleFlush, forceFlush };
}
|
||||||
196
surfsense_web/lib/chat/stream-pipeline.ts
Normal file
196
surfsense_web/lib/chat/stream-pipeline.ts
Normal file
|
|
@ -0,0 +1,196 @@
|
||||||
|
import {
|
||||||
|
addStepSeparator,
|
||||||
|
addToolCall,
|
||||||
|
appendReasoning,
|
||||||
|
appendText,
|
||||||
|
appendToolInputDelta,
|
||||||
|
type ContentPartsState,
|
||||||
|
endReasoning,
|
||||||
|
readSSEStream,
|
||||||
|
type SSEEvent,
|
||||||
|
type ThinkingStepData,
|
||||||
|
type ToolUIGate,
|
||||||
|
updateThinkingSteps,
|
||||||
|
updateToolCall,
|
||||||
|
} from "@/lib/chat/streaming-state";
|
||||||
|
|
||||||
|
export type SharedStreamEventContext = {
|
||||||
|
contentPartsState: ContentPartsState;
|
||||||
|
toolsWithUI: ToolUIGate;
|
||||||
|
currentThinkingSteps: Map<string, ThinkingStepData>;
|
||||||
|
scheduleFlush: () => void;
|
||||||
|
forceFlush: () => void;
|
||||||
|
onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void;
|
||||||
|
onTurnStatus?: (data: Extract<SSEEvent, { type: "data-turn-status" }>["data"]) => void;
|
||||||
|
onToolOutputAvailable?: (
|
||||||
|
event: Extract<SSEEvent, { type: "tool-output-available" }>,
|
||||||
|
context: {
|
||||||
|
contentPartsState: ContentPartsState;
|
||||||
|
toolCallIndices: Map<string, number>;
|
||||||
|
}
|
||||||
|
) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
 * Marks previously-decided interrupt tool calls as completed, in place, so
 * the ApprovalCard can transition from shimmer to done once a tool has
 * produced output.
 *
 * A part is updated when it is a `tool-call` whose `result` is an object with
 * `__interrupt__ === true`, a truthy `__decided__`, and a falsy
 * `__completed__`. Its `result` is replaced by a copy with
 * `__completed__: true`; all other parts are left untouched.
 *
 * @param contentParts Message parts to scan; matching entries are mutated.
 */
export function markInterruptsCompleted(
	contentParts: Array<{ type: string; result?: unknown }>
): void {
	for (const part of contentParts) {
		if (part.type !== "tool-call") continue;

		const outcome = part.result;
		if (typeof outcome !== "object" || outcome === null) continue;

		const flags = outcome as Record<string, unknown>;
		const isDecidedInterrupt = flags.__interrupt__ === true && Boolean(flags.__decided__);
		if (isDecidedInterrupt && !flags.__completed__) {
			part.result = { ...flags, __completed__: true };
		}
	}
}
|
||||||
|
|
||||||
|
/**
 * Reports whether the accumulated parts contain anything worth persisting:
 * non-empty text, non-empty reasoning, or a tool call admitted by the gate
 * (`toolsWithUI` is "all" or contains the tool's name).
 */
export function hasPersistableContent(
	contentParts: ContentPartsState["contentParts"],
	toolsWithUI: ToolUIGate
) {
	return contentParts.some((part) => {
		switch (part.type) {
			case "text":
			case "reasoning":
				return part.text.length > 0;
			case "tool-call":
				return toolsWithUI === "all" || toolsWithUI.has(part.toolName);
			default:
				return false;
		}
	});
}
|
||||||
|
|
||||||
|
/**
 * Converts an "error" SSE event into a throwable Error. The message is the
 * event's `errorText` (falling back to "Server error" when it is empty or
 * missing) and the event's `errorCode` is attached as a property.
 */
function toStreamTerminalError(
	event: Extract<SSEEvent, { type: "error" }>
): Error & { errorCode?: string } {
	const message = event.errorText || "Server error";
	const { errorCode } = event;
	return Object.assign(new Error(message), { errorCode });
}
|
||||||
|
|
||||||
|
/**
 * Applies one parsed SSE event to the shared streaming state.
 *
 * @param parsed  The event read from the stream.
 * @param context Mutable state, flush hooks and optional callbacks.
 * @returns `true` when the event type was handled here, `false` when it is
 *          not one of the shared event types (the caller may handle it).
 * @throws  An Error built from the event when `parsed.type` is "error".
 */
export function processSharedStreamEvent(
	parsed: SSEEvent,
	context: SharedStreamEventContext
): boolean {
	const {
		contentPartsState: state,
		toolsWithUI: uiGate,
		currentThinkingSteps: thinkingSteps,
		scheduleFlush,
		forceFlush,
	} = context;
	const { contentParts: parts, toolCallIndices: indexByCallId } = state;

	switch (parsed.type) {
		case "text-delta":
			appendText(state, parsed.delta);
			scheduleFlush();
			break;

		case "reasoning-delta":
			appendReasoning(state, parsed.delta);
			scheduleFlush();
			break;

		case "reasoning-end":
			endReasoning(state);
			scheduleFlush();
			break;

		case "start-step":
			addStepSeparator(state);
			scheduleFlush();
			break;

		case "finish-step":
			// Recognised, but nothing to apply.
			break;

		case "tool-input-start":
			addToolCall(
				state,
				uiGate,
				parsed.toolCallId,
				parsed.toolName,
				{},
				false,
				parsed.langchainToolCallId
			);
			forceFlush();
			break;

		case "tool-input-delta":
			// Deltas can arrive dozens of times per call, so coalesce them
			// through the throttled scheduleFlush rather than forceFlush.
			appendToolInputDelta(state, parsed.toolCallId, parsed.inputTextDelta);
			scheduleFlush();
			break;

		case "tool-input-available": {
			const prettyArgs = JSON.stringify(parsed.input ?? {}, null, 2);
			const alreadyTracked = indexByCallId.has(parsed.toolCallId);

			if (alreadyTracked) {
				updateToolCall(state, parsed.toolCallId, {
					args: parsed.input || {},
					argsText: prettyArgs,
					langchainToolCallId: parsed.langchainToolCallId,
				});
			} else {
				addToolCall(
					state,
					uiGate,
					parsed.toolCallId,
					parsed.toolName,
					parsed.input || {},
					false,
					parsed.langchainToolCallId
				);
				// addToolCall takes no argsText, so backfill it with a
				// follow-up update to get pretty-printed JSON on the new card.
				updateToolCall(state, parsed.toolCallId, { argsText: prettyArgs });
			}
			forceFlush();
			break;
		}

		case "tool-output-available":
			updateToolCall(state, parsed.toolCallId, {
				result: parsed.output,
				langchainToolCallId: parsed.langchainToolCallId,
			});
			markInterruptsCompleted(parts);
			context.onToolOutputAvailable?.(parsed, {
				contentPartsState: state,
				toolCallIndices: indexByCallId,
			});
			forceFlush();
			break;

		case "data-thinking-step": {
			const step = parsed.data as ThinkingStepData;
			if (step?.id) {
				thinkingSteps.set(step.id, step);
				// Only schedule a flush when the update reports a change.
				if (updateThinkingSteps(state, thinkingSteps)) {
					scheduleFlush();
				}
			}
			break;
		}

		case "data-token-usage":
			context.onTokenUsage?.(parsed.data);
			break;

		case "data-turn-status":
			context.onTurnStatus?.(parsed.data);
			break;

		case "error":
			throw toStreamTerminalError(parsed);

		default:
			return false;
	}

	return true;
}
|
||||||
|
|
||||||
|
/**
 * Reads every SSE event from `response` via `readSSEStream` and hands each
 * one to `onEvent` in order, awaiting the handler before reading the next.
 * Resolves when the stream is exhausted; handler errors propagate.
 */
export async function consumeSseEvents(
	response: Response,
	onEvent: (event: SSEEvent) => void | Promise<void>
): Promise<void> {
	const events = readSSEStream(response);
	for await (const event of events) {
		await onEvent(event);
	}
}
|
||||||
127
surfsense_web/lib/chat/stream-side-effects.ts
Normal file
127
surfsense_web/lib/chat/stream-side-effects.ts
Normal file
|
|
@ -0,0 +1,127 @@
|
||||||
|
import type { ThreadMessageLike } from "@assistant-ui/react";
|
||||||
|
import {
|
||||||
|
addToolCall,
|
||||||
|
type ContentPartsState,
|
||||||
|
type ToolUIGate,
|
||||||
|
updateToolCall,
|
||||||
|
} from "@/lib/chat/streaming-state";
|
||||||
|
|
||||||
|
/** One entry of an interrupt payload's `action_requests` list. */
type InterruptActionRequest = {
	/** Tool name; matched against `toolName` of existing tool-call parts. */
	name: string;
	/** Arguments used when a tool-call part has to be created for the request. */
	args: Record<string, unknown>;
};
|
||||||
|
|
||||||
|
/** An edited action that is merged back into its tool-call part. */
export type EditedInterruptAction = {
	/** Tool name identifying which tool-call part receives the edit. */
	name: string;
	/** Argument overrides, shallow-merged over the part's current args. */
	args: Record<string, unknown>;
};
|
||||||
|
|
||||||
|
/**
 * Extracts the `action_requests` list from an interrupt payload.
 *
 * The payload is untyped, so the value is checked at runtime: anything that
 * is not an array (missing, null, object, string, ...) yields an empty list
 * instead of being cast blindly, which would make callers' `for...of` throw
 * or iterate garbage.
 *
 * @param interruptData Raw interrupt payload.
 * @returns The `action_requests` array itself when present, otherwise `[]`.
 */
function readInterruptActions(
	interruptData: Record<string, unknown>
): InterruptActionRequest[] {
	const requests = interruptData.action_requests;
	return Array.isArray(requests) ? (requests as InterruptActionRequest[]) : [];
}
|
||||||
|
|
||||||
|
/**
 * Applies an interrupt request payload to the tool-call parts.
 *
 * For each requested action, the first tracked tool call whose `toolName`
 * equals the action name gets its result replaced with the interrupt payload
 * (tagged `__interrupt__: true`). When no such tool call exists, one is added
 * under the id `interrupt-<name>` first, so the approval UI always has a card
 * to render.
 */
export function applyInterruptRequestToContentParts(
	contentPartsState: ContentPartsState,
	toolsWithUI: ToolUIGate,
	interruptData: Record<string, unknown>
): void {
	const { contentParts, toolCallIndices } = contentPartsState;

	for (const action of readInterruptActions(interruptData)) {
		// Look up the id of the first tracked tool call with this tool name.
		const matchedCallId = Array.from(toolCallIndices.keys()).find((callId) => {
			const candidate = contentParts[toolCallIndices.get(callId) as number];
			return candidate?.type === "tool-call" && candidate.toolName === action.name;
		});

		const targetCallId = matchedCallId ?? `interrupt-${action.name}`;
		if (matchedCallId === undefined) {
			addToolCall(contentPartsState, toolsWithUI, targetCallId, action.name, action.args, true);
		}

		updateToolCall(contentPartsState, targetCallId, {
			result: { __interrupt__: true, ...interruptData },
		});
	}
}
|
||||||
|
|
||||||
|
/**
 * Merges an edited action's args into the first tool-call part with the same
 * tool name. The part is mutated in place: `args` becomes the shallow merge
 * (edited values win) and `argsText` is regenerated from it. Does nothing
 * when `editedAction` is undefined or no part matches.
 */
export function mergeEditedInterruptAction(
	contentParts: ContentPartsState["contentParts"],
	editedAction: EditedInterruptAction | undefined
): void {
	if (!editedAction) return;

	const target = contentParts.find(
		(part) => part.type === "tool-call" && part.toolName === editedAction.name
	);
	if (!target || target.type !== "tool-call") return;

	const nextArgs = { ...target.args, ...editedAction.args };
	target.args = nextArgs;
	// assistant-ui prefers argsText over JSON.stringify(args)
	target.argsText = JSON.stringify(nextArgs, null, 2);
}
|
||||||
|
|
||||||
|
/**
 * Records the user's decision on every interrupt tool call. Each "tool-call"
 * part whose `result` is an object containing an `__interrupt__` key has its
 * result replaced with a copy carrying `__decided__: decisionType`. Does
 * nothing when `decisionType` is undefined.
 */
export function markInterruptDecisionOnContentParts(
	contentParts: ContentPartsState["contentParts"],
	decisionType: "approve" | "reject" | undefined
): void {
	if (!decisionType) return;

	for (const entry of contentParts) {
		if (entry.type !== "tool-call") continue;

		const payload = entry.result;
		if (typeof payload !== "object" || payload === null) continue;
		if (!("__interrupt__" in (payload as Record<string, unknown>))) continue;

		entry.result = {
			...(payload as Record<string, unknown>),
			__decided__: decisionType,
		};
	}
}
|
||||||
|
|
||||||
|
/**
 * When a streamed message is persisted, the backend returns the durable
 * turn_id; this stores it as `metadata.custom.chatTurnId` so turn-scoped
 * actions can find it.
 *
 * Returns the same message object when `turnId` is empty or already stored;
 * otherwise returns a shallow copy with the merged metadata, leaving the
 * input untouched.
 */
export function mergeChatTurnIdIntoMessage(
	msg: ThreadMessageLike,
	turnId: string | null | undefined
): ThreadMessageLike {
	if (!turnId) return msg;

	const metadata = (msg.metadata ?? {}) as { custom?: Record<string, unknown> };
	const custom = metadata.custom ?? {};

	const alreadyTagged = (custom as { chatTurnId?: string }).chatTurnId === turnId;
	if (alreadyTagged) return msg;

	return {
		...msg,
		metadata: {
			...metadata,
			custom: { ...custom, chatTurnId: turnId },
		},
	};
}
|
||||||
|
|
||||||
|
/**
 * Pulls `chat_turn_id` out of an untyped streamed payload.
 *
 * @returns The id when `data` is an object whose `chat_turn_id` is a
 *          non-empty string; `null` in every other case.
 */
export function readStreamedChatTurnId(data: unknown): string | null {
	if (data === null || typeof data !== "object") return null;

	const candidate = (data as { chat_turn_id?: unknown }).chat_turn_id;
	if (typeof candidate !== "string" || candidate.length === 0) return null;

	return candidate;
}
|
||||||
|
|
||||||
|
/**
 * Returns a new message list in which every message whose id equals
 * `assistantMsgId` has `turnId` merged into its metadata via
 * `mergeChatTurnIdIntoMessage`; all other messages are passed through as-is.
 */
export function applyTurnIdToAssistantMessageList(
	messages: ThreadMessageLike[],
	assistantMsgId: string,
	turnId: string
): ThreadMessageLike[] {
	return messages.map((message) => {
		if (message.id !== assistantMsgId) return message;
		return mergeChatTurnIdIntoMessage(message, turnId);
	});
}
|
||||||
|
|
@ -528,6 +528,14 @@ export type SSEEvent =
|
||||||
}>;
|
}>;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
| {
|
||||||
|
type: "data-turn-status";
|
||||||
|
data: {
|
||||||
|
status: "idle" | "busy" | "cancelling";
|
||||||
|
retry_after_ms?: number;
|
||||||
|
retry_after_at?: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
| {
|
| {
|
||||||
type: "data-token-usage";
|
type: "data-token-usage";
|
||||||
data: {
|
data: {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue