mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 13:52:40 +02:00
Merge pull request #1327 from AnishSarkar22/feat/chat-state-unification
refactor(chat): unify streaming state flow and improve chat viewport + mention UX
This commit is contained in:
commit
d335e96ec2
27 changed files with 1953 additions and 1647 deletions
|
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -58,6 +59,8 @@ class _ThreadLockManager:
|
|||
weakref.WeakValueDictionary()
|
||||
)
|
||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
|
|
@ -76,14 +79,45 @@ class _ThreadLockManager:
|
|||
def request_cancel(self, thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id``.

    The event is created lazily for threads that never registered one, so
    the call always succeeds. An already-existing event is reused (never
    replaced) so tasks already waiting on it observe the signal.

    Returns:
        True, always. The bool return is kept for backward compatibility
        with callers that treated it as "event found".
    """
    event = self._cancel_events.get(thread_id)
    if event is None:
        # Lazily create the event so cancellation works even for
        # threads we have never seen before.
        event = asyncio.Event()
        self._cancel_events[thread_id] = event
    event.set()
    now_ms = int(time.time() * 1000)
    self._cancel_requested_at_ms[thread_id] = now_ms
    self._cancel_attempt_count[thread_id] = (
        self._cancel_attempt_count.get(thread_id, 0) + 1
    )
    return True
|
||||
|
||||
def is_cancel_requested(self, thread_id: str) -> bool:
    """Report whether a cancel signal is currently pending for ``thread_id``."""
    pending = self._cancel_events.get(thread_id)
    if pending is None:
        return False
    return pending.is_set()
|
||||
|
||||
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
|
||||
if not self.is_cancel_requested(thread_id):
|
||||
return None
|
||||
attempts = self._cancel_attempt_count.get(thread_id, 1)
|
||||
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
|
||||
return attempts, requested_at_ms
|
||||
|
||||
def reset(self, thread_id: str) -> None:
    """Clear any pending cancel signal and its bookkeeping for ``thread_id``.

    Safe to call for threads that never had an event (no-op then).
    """
    existing = self._cancel_events.get(thread_id)
    if existing is not None:
        existing.clear()
    # Drop timing/attempt bookkeeping; pop with default keeps this idempotent.
    for bookkeeping in (self._cancel_requested_at_ms, self._cancel_attempt_count):
        bookkeeping.pop(thread_id, None)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
    """Best-effort terminal cleanup for a thread turn.

    Intentionally idempotent and safe to call from outer stream
    finally-blocks where middleware teardown might be skipped due to
    abort or disconnect edge-cases.
    """
    held = self._locks.get(thread_id)
    if held is not None and held.locked():
        # Force-release a lock a crashed/aborted turn left behind.
        held.release()
    self.reset(thread_id)
|
||||
|
||||
|
||||
# Module-level singleton — process-local but reused across all agent
|
||||
|
|
@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
|
|||
|
||||
|
||||
def request_cancel(thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id``. Always returns True."""
    # Delegates to the module-level singleton manager; the stale duplicate
    # docstring ("Returns True if found.") left behind by the refactor is
    # removed — it was a dead bare-string statement, not a docstring.
    return manager.request_cancel(thread_id)
|
||||
|
||||
|
||||
def is_cancel_requested(thread_id: str) -> bool:
|
||||
"""Return whether ``thread_id`` currently has a pending cancel signal."""
|
||||
return manager.is_cancel_requested(thread_id)
|
||||
|
||||
|
||||
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
|
||||
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
|
||||
return manager.cancel_state(thread_id)
|
||||
|
||||
|
||||
def reset_cancel(thread_id: str) -> None:
|
||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||
manager.reset(thread_id)
|
||||
|
||||
|
||||
def end_turn(thread_id: str) -> None:
|
||||
"""Force end-of-turn cleanup for lock + cancel state."""
|
||||
manager.end_turn(thread_id)
|
||||
|
||||
|
||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Block concurrent prompts on the same thread.
|
||||
|
||||
|
|
@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
|||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"end_turn",
|
||||
"get_cancel_event",
|
||||
"get_cancel_state",
|
||||
"is_cancel_requested",
|
||||
"manager",
|
||||
"request_cancel",
|
||||
"reset_cancel",
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import json
|
|||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
|
@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
|
|||
FilesystemSelection,
|
||||
LocalFilesystemMount,
|
||||
)
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
manager,
|
||||
request_cancel,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
|
|
@ -44,6 +50,7 @@ from app.db import (
|
|||
)
|
||||
from app.schemas.new_chat import (
|
||||
AgentToolInfo,
|
||||
CancelActiveTurnResponse,
|
||||
LocalFilesystemMountPayload,
|
||||
NewChatMessageRead,
|
||||
NewChatRequest,
|
||||
|
|
@ -60,6 +67,7 @@ from app.schemas.new_chat import (
|
|||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
TokenUsageSummary,
|
||||
TurnStatusResponse,
|
||||
)
|
||||
from app.services.token_tracking_service import record_token_usage
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
|
|
@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import (
|
|||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -137,6 +148,72 @@ def _resolve_filesystem_selection(
|
|||
)
|
||||
|
||||
|
||||
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential delay for TURN_CANCELLING retry hints."""
    # Clamp to attempt >= 1 so the first hint uses the initial delay.
    exponent = max(attempt, 1) - 1
    raw_delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    return min(raw_delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||
|
||||
|
||||
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Describe the current turn state for ``thread_id``.

    Returns ``{"status": "idle"}`` when no turn holds the thread lock,
    ``{"status": "busy"}`` when one does, and a ``"cancelling"`` payload
    with retry hints when a cancel signal is pending.
    """
    key = str(thread_id)
    if not manager.lock_for(key).locked():
        return {"status": "idle"}

    if not is_cancel_requested(key):
        return {"status": "busy"}

    state = get_cancel_state(key)
    attempt = state[0] if state else 1
    retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    return {
        "status": "cancelling",
        "retry_after_ms": retry_after_ms,
        "retry_after_at": now_ms + retry_after_ms,
    }
|
||||
|
||||
|
||||
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
|
||||
response.headers["retry-after-ms"] = str(retry_after_ms)
|
||||
response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
|
||||
|
||||
|
||||
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Guard for turn-starting endpoints: raise 409 unless the thread is idle.

    Raises:
        HTTPException: 409 with ``TURN_CANCELLING`` detail (plus retry-hint
            headers when available) while a previous turn is stopping, or
            409 with ``THREAD_BUSY`` detail while one is still running.
    """
    payload = _build_turn_status_payload(thread_id)
    state = payload["status"]
    if state == "idle":
        return

    if state == "cancelling":
        retry_after_ms = int(payload.get("retry_after_ms") or 0)
        headers = None
        if retry_after_ms > 0:
            headers = {
                "retry-after-ms": str(retry_after_ms),
                "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
            }
        raise HTTPException(
            status_code=409,
            detail={
                "errorCode": "TURN_CANCELLING",
                "message": "A previous response is still stopping. Please try again in a moment.",
                "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
                "retry_after_at": payload.get("retry_after_at"),
            },
            headers=headers,
        )

    # Remaining state is "busy": an active turn still holds the thread lock.
    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "THREAD_BUSY",
            "message": "Another response is still finishing for this thread. Please try again in a moment.",
        },
    )
|
||||
|
||||
|
||||
def _find_pre_turn_checkpoint_id(
|
||||
checkpoint_tuples: list,
|
||||
*,
|
||||
|
|
@ -1476,6 +1553,7 @@ async def handle_new_chat(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(request.chat_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
@ -1550,6 +1628,93 @@ async def handle_new_chat(
|
|||
) from None
|
||||
|
||||
|
||||
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``.

    Returns an ``idle``/``NO_ACTIVE_TURN`` payload when nothing is running;
    otherwise trips the cancel signal and responds 202 with retry hints.
    """
    result = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Cancelling a turn mutates chat state, so CHATS_UPDATE is required.
    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    status_payload = _build_turn_status_payload(thread_id)
    if status_payload["status"] == "idle":
        # Nothing to cancel — report that without raising.
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )

    # Trip the process-local cancel event for this thread.
    request_cancel(str(thread_id))
    # 202 Accepted: cancellation is asynchronous, not yet complete.
    response.status_code = 202
    # Re-read status so retry hints reflect the just-recorded attempt.
    updated_payload = _build_turn_status_payload(thread_id)
    retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
    retry_after_at = (
        int(updated_payload["retry_after_at"])
        if "retry_after_at" in updated_payload
        else None
    )
    if retry_after_ms > 0:
        _set_retry_after_headers(response, retry_after_ms)
    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
        retry_after_at=retry_after_at,
    )
|
||||
|
||||
|
||||
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return the current turn execution status (idle/busy/cancelling) for a thread."""
    result = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Read-only endpoint: CHATS_READ is sufficient.
    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    status_payload = _build_turn_status_payload(thread_id)
    return TurnStatusResponse(
        status=status_payload["status"],  # type: ignore[arg-type]
        active_turn_id=None,  # not populated here — presumably reserved; TODO confirm
        retry_after_ms=status_payload.get("retry_after_ms"),  # type: ignore[arg-type]
        retry_after_at=status_payload.get("retry_after_at"),  # type: ignore[arg-type]
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Regeneration Endpoint (Edit/Reload)
|
||||
# =============================================================================
|
||||
|
|
@ -1605,6 +1770,7 @@ async def regenerate_response(
|
|||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(thread_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
@ -2012,6 +2178,7 @@ async def resume_chat(
|
|||
)
|
||||
|
||||
await check_thread_access(session, thread, user)
|
||||
_raise_if_thread_busy_for_start(thread_id)
|
||||
filesystem_selection = _resolve_filesystem_selection(
|
||||
mode=request.filesystem_mode,
|
||||
client_platform=request.client_platform,
|
||||
|
|
|
|||
|
|
@ -335,6 +335,24 @@ class ResumeRequest(BaseModel):
|
|||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||
|
||||
|
||||
class CancelActiveTurnResponse(BaseModel):
    """Response for canceling an active turn on a chat thread."""

    # "cancelling" when a running turn was signalled; "idle" when nothing was active.
    status: Literal["cancelling", "idle"]
    # Machine-readable companion to ``status`` for client-side handling.
    error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
    # Suggested client retry delay in milliseconds; None when not applicable.
    retry_after_ms: int | None = None
    # Absolute epoch-milliseconds timestamp after which a retry is suggested.
    retry_after_at: int | None = None
|
||||
|
||||
|
||||
class TurnStatusResponse(BaseModel):
    """Current turn execution status for a thread."""

    # "idle": no turn running; "busy": a turn holds the thread lock;
    # "cancelling": a cancel signal is pending for the running turn.
    status: Literal["idle", "busy", "cancelling"]
    # Identifier of the active turn, when known; None otherwise.
    active_turn_id: str | None = None
    # Suggested client retry delay in milliseconds (cancelling state only).
    retry_after_ms: int | None = None
    # Absolute epoch-milliseconds timestamp after which a retry is suggested.
    retry_after_at: int | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Public Chat Snapshot Schemas
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -565,7 +565,12 @@ class VercelStreamingService:
|
|||
# Error Part
|
||||
# =========================================================================
|
||||
|
||||
def format_error(self, error_text: str, error_code: str | None = None) -> str:
|
||||
def format_error(
|
||||
self,
|
||||
error_text: str,
|
||||
error_code: str | None = None,
|
||||
extra: dict[str, object] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format an error message.
|
||||
|
||||
|
|
@ -579,9 +584,11 @@ class VercelStreamingService:
|
|||
Example output:
|
||||
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
||||
"""
|
||||
payload: dict[str, str] = {"type": "error", "errorText": error_text}
|
||||
payload: dict[str, object] = {"type": "error", "errorText": error_text}
|
||||
if error_code:
|
||||
payload["errorCode"] = error_code
|
||||
if extra:
|
||||
payload.update(extra)
|
||||
return self._format_sse(payload)
|
||||
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import (
|
|||
extract_and_save_memory,
|
||||
extract_and_save_team_memory,
|
||||
)
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
end_turn,
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
)
|
||||
from app.agents.new_chat.middleware.kb_persistence import (
|
||||
commit_staged_filesystem_state,
|
||||
)
|
||||
|
|
@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content
|
|||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
_perf_log = get_perf_logger()
|
||||
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||
|
||||
|
||||
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential backoff (ms) used in TURN_CANCELLING retry hints."""
    normalized = attempt if attempt >= 1 else 1
    delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR ** (normalized - 1)
    )
    # Cap the hint so clients never back off longer than the configured max.
    if delay > TURN_CANCELLING_MAX_DELAY_MS:
        return TURN_CANCELLING_MAX_DELAY_MS
    return delay
|
||||
|
||||
|
||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||
|
|
@ -401,15 +418,35 @@ def _classify_stream_exception(
|
|||
exc: Exception,
|
||||
*,
|
||||
flow_label: str,
|
||||
) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]:
|
||||
) -> tuple[
|
||||
str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None
|
||||
]:
|
||||
raw = str(exc)
|
||||
if isinstance(exc, BusyError) or "Thread is busy with another request" in raw:
|
||||
busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None
|
||||
if busy_thread_id and is_cancel_requested(busy_thread_id):
|
||||
cancel_state = get_cancel_state(busy_thread_id)
|
||||
attempt = cancel_state[0] if cancel_state else 1
|
||||
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
|
||||
retry_after_at = int(time.time() * 1000) + retry_after_ms
|
||||
return (
|
||||
"thread_busy",
|
||||
"TURN_CANCELLING",
|
||||
"info",
|
||||
True,
|
||||
"A previous response is still stopping. Please try again in a moment.",
|
||||
{
|
||||
"retry_after_ms": retry_after_ms,
|
||||
"retry_after_at": retry_after_at,
|
||||
},
|
||||
)
|
||||
return (
|
||||
"thread_busy",
|
||||
"THREAD_BUSY",
|
||||
"warn",
|
||||
True,
|
||||
"Another response is still finishing for this thread. Please try again in a moment.",
|
||||
None,
|
||||
)
|
||||
|
||||
parsed = _parse_error_payload(raw)
|
||||
|
|
@ -431,6 +468,7 @@ def _classify_stream_exception(
|
|||
"warn",
|
||||
True,
|
||||
"This model is temporarily rate-limited. Please try again in a few seconds or switch models.",
|
||||
None,
|
||||
)
|
||||
|
||||
return (
|
||||
|
|
@ -439,6 +477,7 @@ def _classify_stream_exception(
|
|||
"error",
|
||||
False,
|
||||
f"Error during {flow_label}: {raw}",
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -470,7 +509,7 @@ def _emit_stream_terminal_error(
|
|||
message=message,
|
||||
extra=extra,
|
||||
)
|
||||
return streaming_service.format_error(message, error_code=error_code)
|
||||
return streaming_service.format_error(message, error_code=error_code, extra=extra)
|
||||
|
||||
|
||||
def _legacy_match_lc_id(
|
||||
|
|
@ -2497,6 +2536,7 @@ async def stream_new_chat(
|
|||
"turn-info",
|
||||
{"chat_turn_id": stream_result.turn_id},
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
# Initial thinking step - analyzing the request
|
||||
if mentioned_surfsense_docs:
|
||||
|
|
@ -2805,6 +2845,7 @@ async def stream_new_chat(
|
|||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -2819,11 +2860,19 @@ async def stream_new_chat(
|
|||
severity,
|
||||
is_expected,
|
||||
user_message,
|
||||
error_extra,
|
||||
) = _classify_stream_exception(e, flow_label="chat")
|
||||
error_message = f"Error during chat: {e!s}"
|
||||
print(f"[stream_new_chat] {error_message}")
|
||||
print(f"[stream_new_chat] Exception type: {type(e).__name__}")
|
||||
print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}")
|
||||
if error_code == "TURN_CANCELLING":
|
||||
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||
if error_extra:
|
||||
status_payload.update(error_extra)
|
||||
yield streaming_service.format_data("turn-status", status_payload)
|
||||
else:
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
yield _emit_stream_error(
|
||||
message=user_message,
|
||||
|
|
@ -2831,7 +2880,9 @@ async def stream_new_chat(
|
|||
error_code=error_code,
|
||||
severity=severity,
|
||||
is_expected=is_expected,
|
||||
extra=error_extra,
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -2847,6 +2898,10 @@ async def stream_new_chat(
|
|||
# (CancelledError is a BaseException), and the rest of the
|
||||
# finally block — including session.close() — would never run.
|
||||
with anyio.CancelScope(shield=True):
|
||||
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||
# teardown can be skipped on some client-abort paths.
|
||||
end_turn(str(chat_id))
|
||||
|
||||
# Release premium reservation if not finalized
|
||||
if _premium_request_id and _premium_reserved > 0 and user_id:
|
||||
try:
|
||||
|
|
@ -3206,6 +3261,7 @@ async def stream_resume_chat(
|
|||
"turn-info",
|
||||
{"chat_turn_id": stream_result.turn_id},
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
|
||||
_t_stream_start = time.perf_counter()
|
||||
_first_event_logged = False
|
||||
|
|
@ -3305,6 +3361,7 @@ async def stream_resume_chat(
|
|||
},
|
||||
)
|
||||
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
|
@ -3318,23 +3375,37 @@ async def stream_resume_chat(
|
|||
severity,
|
||||
is_expected,
|
||||
user_message,
|
||||
error_extra,
|
||||
) = _classify_stream_exception(e, flow_label="resume")
|
||||
error_message = f"Error during resume: {e!s}"
|
||||
print(f"[stream_resume_chat] {error_message}")
|
||||
print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}")
|
||||
if error_code == "TURN_CANCELLING":
|
||||
status_payload: dict[str, Any] = {"status": "cancelling"}
|
||||
if error_extra:
|
||||
status_payload.update(error_extra)
|
||||
yield streaming_service.format_data("turn-status", status_payload)
|
||||
else:
|
||||
yield streaming_service.format_data("turn-status", {"status": "busy"})
|
||||
yield _emit_stream_error(
|
||||
message=user_message,
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
severity=severity,
|
||||
is_expected=is_expected,
|
||||
extra=error_extra,
|
||||
)
|
||||
yield streaming_service.format_data("turn-status", {"status": "idle"})
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
yield streaming_service.format_done()
|
||||
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
# Authoritative fallback cleanup for lock/cancel state. Middleware
|
||||
# teardown can be skipped on some client-abort paths.
|
||||
end_turn(str(chat_id))
|
||||
|
||||
# Release premium reservation if not finalized
|
||||
if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ import pytest
|
|||
from app.agents.new_chat.errors import BusyError
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
BusyMutexMiddleware,
|
||||
end_turn,
|
||||
get_cancel_event,
|
||||
is_cancel_requested,
|
||||
manager,
|
||||
request_cancel,
|
||||
reset_cancel,
|
||||
|
|
@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
|
|||
def test_reset_cancel_idempotent() -> None:
    """reset_cancel must be a no-op for a thread whose event never existed."""
    # Should not raise even if event was never created
    reset_cancel("never-seen")
|
||||
|
||||
|
||||
def test_request_cancel_creates_event_for_unseen_thread() -> None:
    """request_cancel must lazily create (and trip) the event for a new thread."""
    thread_id = "never-seen-cancel"
    # Start from a clean slate in the shared module-level manager.
    reset_cancel(thread_id)

    assert request_cancel(thread_id) is True
    assert get_cancel_event(thread_id).is_set()
    assert is_cancel_requested(thread_id) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
    """end_turn must force-release the thread lock and clear cancel state."""
    thread_id = "forced-end-turn"
    mw = BusyMutexMiddleware()
    runtime = _Runtime(thread_id)

    # Acquire the per-thread lock via the middleware entry hook.
    await mw.abefore_agent({}, runtime)
    assert manager.lock_for(thread_id).locked()

    request_cancel(thread_id)
    assert is_cancel_requested(thread_id) is True

    # Simulates the outer-stream finally-block fallback cleanup path.
    end_turn(thread_id)

    assert not manager.lock_for(thread_id).locked()
    assert not get_cancel_event(thread_id).is_set()
    assert is_cancel_requested(thread_id) is False
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
|
||||
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
_classify_stream_exception,
|
||||
|
|
@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited():
|
|||
exc = Exception(
|
||||
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
||||
)
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "rate_limited"
|
||||
|
|
@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited():
|
|||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "temporarily rate-limited" in user_message
|
||||
assert extra is None
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy():
|
||||
exc = BusyError(request_id="thread-123")
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "thread_busy"
|
||||
|
|
@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy():
|
|||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "still finishing for this thread" in user_message
|
||||
assert extra is None
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy_from_message():
|
||||
exc = Exception("Thread is busy with another request")
|
||||
kind, code, severity, is_expected, user_message = _classify_stream_exception(
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "thread_busy"
|
||||
|
|
@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message():
|
|||
assert severity == "warn"
|
||||
assert is_expected is True
|
||||
assert "still finishing for this thread" in user_message
|
||||
assert extra is None
|
||||
|
||||
|
||||
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
|
||||
thread_id = "thread-cancelling-1"
|
||||
reset_cancel(thread_id)
|
||||
request_cancel(thread_id)
|
||||
exc = BusyError(request_id=thread_id)
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
exc, flow_label="chat"
|
||||
)
|
||||
assert kind == "thread_busy"
|
||||
assert code == "TURN_CANCELLING"
|
||||
assert severity == "info"
|
||||
assert is_expected is True
|
||||
assert "stopping" in user_message
|
||||
assert isinstance(extra, dict)
|
||||
assert "retry_after_ms" in extra
|
||||
|
||||
|
||||
def test_premium_classification_is_error_code_driven():
|
||||
|
|
@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
|
|||
def test_network_send_failures_use_unified_retry_toast_message():
|
||||
classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts"
|
||||
classifier_source = classifier_path.read_text(encoding="utf-8")
|
||||
page_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
||||
request_errors_path = (
|
||||
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts"
|
||||
)
|
||||
page_source = page_path.read_text(encoding="utf-8")
|
||||
request_errors_source = request_errors_path.read_text(encoding="utf-8")
|
||||
|
||||
assert '"send_failed_pre_accept"' in classifier_source
|
||||
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
|
||||
assert 'errorCode === "TURN_CANCELLING"' in classifier_source
|
||||
assert "if (withCode.code) return withCode.code;" in classifier_source
|
||||
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
|
||||
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
|
||||
assert "tagPreAcceptSendFailure(error)" in page_source
|
||||
assert "const passthroughCodes = new Set([" in page_source
|
||||
assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source
|
||||
assert '"THREAD_BUSY"' in page_source
|
||||
assert '"AUTH_EXPIRED"' in page_source
|
||||
assert '"UNAUTHORIZED"' in page_source
|
||||
assert '"RATE_LIMITED"' in page_source
|
||||
assert '"NETWORK_ERROR"' in page_source
|
||||
assert '"STREAM_PARSE_ERROR"' in page_source
|
||||
assert '"TOOL_EXECUTION_ERROR"' in page_source
|
||||
assert '"PERSIST_MESSAGE_FAILED"' in page_source
|
||||
assert '"SERVER_ERROR"' in page_source
|
||||
assert "passthroughCodes.has(existingCode)" in page_source
|
||||
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source
|
||||
assert 'errorCode: "NETWORK_ERROR"' not in page_source
|
||||
assert "Failed to start chat. Please try again." not in page_source
|
||||
assert "const passthroughCodes = new Set([" in request_errors_source
|
||||
assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
|
||||
assert '"THREAD_BUSY"' in request_errors_source
|
||||
assert '"TURN_CANCELLING"' in request_errors_source
|
||||
assert '"AUTH_EXPIRED"' in request_errors_source
|
||||
assert '"UNAUTHORIZED"' in request_errors_source
|
||||
assert '"RATE_LIMITED"' in request_errors_source
|
||||
assert '"NETWORK_ERROR"' in request_errors_source
|
||||
assert '"STREAM_PARSE_ERROR"' in request_errors_source
|
||||
assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
|
||||
assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
|
||||
assert '"SERVER_ERROR"' in request_errors_source
|
||||
assert "passthroughCodes.has(existingCode)" in request_errors_source
|
||||
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
|
||||
assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
|
||||
assert "Failed to start chat. Please try again." not in classifier_source
|
||||
|
||||
|
||||
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
||||
|
|
@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows()
|
|||
|
||||
# New flow persists only when accepted and not already persisted.
|
||||
assert "if (newAccepted && !userPersisted) {" in source
|
||||
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
|
||||
assert "computeFallbackTurnCancellingRetryDelay" in source
|
||||
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
|
||||
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
|
||||
assert "await fetchWithTurnCancellingRetry(() =>" in source
|
||||
|
||||
|
||||
def test_cancel_active_turn_route_contract_exists():
|
||||
routes_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||
)
|
||||
source = routes_path.read_text(encoding="utf-8")
|
||||
|
||||
assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
|
||||
assert "response_model=CancelActiveTurnResponse" in source
|
||||
assert 'status="cancelling",' in source
|
||||
assert 'error_code="TURN_CANCELLING",' in source
|
||||
assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
|
||||
assert "retry_after_at=" in source
|
||||
assert 'status="idle",' in source
|
||||
assert 'error_code="NO_ACTIVE_TURN",' in source
|
||||
|
||||
|
||||
def test_turn_status_route_contract_exists():
|
||||
routes_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||
)
|
||||
source = routes_path.read_text(encoding="utf-8")
|
||||
|
||||
assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
|
||||
assert "response_model=TurnStatusResponse" in source
|
||||
assert "_build_turn_status_payload(thread_id)" in source
|
||||
assert "Permission.CHATS_READ.value" in source
|
||||
assert "_raise_if_thread_busy_for_start(" in source
|
||||
|
||||
|
||||
def test_turn_cancelling_retry_policy_contract_exists():
|
||||
routes_path = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||
)
|
||||
source = routes_path.read_text(encoding="utf-8")
|
||||
|
||||
assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
|
||||
assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
|
||||
assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
|
||||
assert "def _compute_turn_cancelling_retry_delay(" in source
|
||||
assert "retry-after-ms" in source
|
||||
assert '"Retry-After"' in source
|
||||
assert '"errorCode": "TURN_CANCELLING"' in source
|
||||
|
||||
|
||||
def test_turn_status_sse_contract_exists():
|
||||
stream_source = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "surfsense_backend/app/tasks/chat/stream_new_chat.py"
|
||||
).read_text(encoding="utf-8")
|
||||
state_source = (
|
||||
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts"
|
||||
).read_text(encoding="utf-8")
|
||||
pipeline_source = (
|
||||
Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts"
|
||||
).read_text(encoding="utf-8")
|
||||
|
||||
assert '"turn-status"' in stream_source
|
||||
assert '"status": "busy"' in stream_source
|
||||
assert '"status": "idle"' in stream_source
|
||||
assert "type: \"data-turn-status\"" in state_source
|
||||
assert "case \"data-turn-status\":" in pipeline_source
|
||||
assert "end_turn(str(chat_id))" in stream_source
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat
|
|||
|
||||
export const resetCurrentThreadAtom = atom(null, (_, set) => {
|
||||
set(currentThreadAtom, initialState);
|
||||
set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null });
|
||||
set(reportPanelAtom, {
|
||||
isOpen: false,
|
||||
reportId: null,
|
||||
title: null,
|
||||
wordCount: null,
|
||||
shareToken: null,
|
||||
contentType: "markdown",
|
||||
});
|
||||
});
|
||||
|
||||
/** Target comment ID to scroll to (from URL navigation or inbox click) */
|
||||
|
|
|
|||
|
|
@ -548,8 +548,10 @@ const AssistantMessageInner: FC = () => {
|
|||
</div>
|
||||
)}
|
||||
|
||||
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 flex items-center gap-2">
|
||||
<AssistantActionBar />
|
||||
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 h-6">
|
||||
<div className="h-full opacity-100 transition-opacity">
|
||||
<AssistantActionBar />
|
||||
</div>
|
||||
</div>
|
||||
</CitationMetadataProvider>
|
||||
);
|
||||
|
|
@ -642,35 +644,41 @@ export const AssistantMessage: FC = () => {
|
|||
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
||||
data-role="assistant"
|
||||
>
|
||||
{/* Comment trigger — right-aligned, just below user query on all screen sizes */}
|
||||
{showCommentTrigger && (
|
||||
<div className="mr-2 mb-1 flex justify-end">
|
||||
<button
|
||||
ref={isDesktop ? commentTriggerRef : undefined}
|
||||
type="button"
|
||||
onClick={
|
||||
isDesktop ? () => setIsInlineOpen((prev) => !prev) : () => setIsSheetOpen(true)
|
||||
}
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
|
||||
isDesktop && isInlineOpen
|
||||
? "bg-primary/10 text-primary"
|
||||
: hasComments
|
||||
? "text-primary hover:bg-primary/10"
|
||||
: "text-muted-foreground hover:text-foreground hover:bg-muted"
|
||||
)}
|
||||
>
|
||||
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
|
||||
{hasComments ? (
|
||||
<span>
|
||||
{commentCount} {commentCount === 1 ? "comment" : "comments"}
|
||||
</span>
|
||||
) : (
|
||||
<span>Add comment</span>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* Fixed trigger slot prevents any vertical reflow when visibility changes */}
|
||||
<div className="mr-2 mb-1 flex h-7 justify-end">
|
||||
<button
|
||||
ref={isDesktop ? commentTriggerRef : undefined}
|
||||
type="button"
|
||||
onClick={
|
||||
showCommentTrigger
|
||||
? isDesktop
|
||||
? () => setIsInlineOpen((prev) => !prev)
|
||||
: () => setIsSheetOpen(true)
|
||||
: undefined
|
||||
}
|
||||
aria-hidden={!showCommentTrigger}
|
||||
tabIndex={showCommentTrigger ? 0 : -1}
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
|
||||
"opacity-0 pointer-events-none",
|
||||
showCommentTrigger && "opacity-100 pointer-events-auto",
|
||||
isDesktop && isInlineOpen
|
||||
? "bg-primary/10 text-primary"
|
||||
: hasComments
|
||||
? "text-primary hover:bg-primary/10"
|
||||
: "text-muted-foreground hover:text-foreground hover:bg-muted"
|
||||
)}
|
||||
>
|
||||
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
|
||||
{hasComments ? (
|
||||
<span>
|
||||
{commentCount} {commentCount === 1 ? "comment" : "comments"}
|
||||
</span>
|
||||
) : (
|
||||
<span>Add comment</span>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Desktop floating comment panel — overlays on top of chat content */}
|
||||
{showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && (
|
||||
|
|
|
|||
52
surfsense_web/components/assistant-ui/chat-viewport.tsx
Normal file
52
surfsense_web/components/assistant-ui/chat-viewport.tsx
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"use client";
|
||||
|
||||
import { ThreadPrimitive } from "@assistant-ui/react";
|
||||
import { ArrowDownIcon } from "lucide-react";
|
||||
import type { FC, ReactNode } from "react";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
|
||||
const ChatScrollToBottom: FC = () => (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
<TooltipIconButton
|
||||
tooltip="Scroll to bottom"
|
||||
variant="outline"
|
||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||
>
|
||||
<ArrowDownIcon />
|
||||
</TooltipIconButton>
|
||||
</ThreadPrimitive.ScrollToBottom>
|
||||
);
|
||||
|
||||
export interface ChatViewportProps {
|
||||
children: ReactNode;
|
||||
footer?: ReactNode;
|
||||
}
|
||||
|
||||
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
autoScroll
|
||||
scrollToBottomOnRunStart
|
||||
scrollToBottomOnInitialize
|
||||
scrollToBottomOnThreadSwitch
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
|
||||
style={{ scrollbarGutter: "stable" }}
|
||||
>
|
||||
<div
|
||||
aria-hidden
|
||||
className="aui-chat-viewport-top-fade pointer-events-none sticky top-0 z-10 -mx-4 h-2 shrink-0 bg-gradient-to-b from-main-panel from-20% to-transparent"
|
||||
/>
|
||||
{children}
|
||||
{footer ? (
|
||||
<ThreadPrimitive.ViewportFooter
|
||||
className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 mt-auto flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6"
|
||||
style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }}
|
||||
>
|
||||
<div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible">
|
||||
<ChatScrollToBottom />
|
||||
{footer}
|
||||
</div>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
) : null}
|
||||
</ThreadPrimitive.Viewport>
|
||||
);
|
||||
File diff suppressed because it is too large
Load diff
24
surfsense_web/components/assistant-ui/nested-scroll.tsx
Normal file
24
surfsense_web/components/assistant-ui/nested-scroll.tsx
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"use client";
|
||||
|
||||
import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react";
|
||||
|
||||
export type NestedScrollProps = ComponentPropsWithoutRef<"div">;
|
||||
|
||||
export const NestedScroll = forwardRef<HTMLDivElement, NestedScrollProps>(
|
||||
({ onWheel, ...props }, ref) => {
|
||||
const handleWheel = (event: WheelEvent<HTMLDivElement>) => {
|
||||
const el = event.currentTarget;
|
||||
const canScrollUp = el.scrollTop > 0;
|
||||
const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1;
|
||||
const goingUp = event.deltaY < 0;
|
||||
const goingDown = event.deltaY > 0;
|
||||
if ((goingUp && canScrollUp) || (goingDown && canScrollDown)) {
|
||||
event.stopPropagation();
|
||||
}
|
||||
onWheel?.(event);
|
||||
};
|
||||
return <div ref={ref} onWheel={handleWheel} {...props} />;
|
||||
}
|
||||
);
|
||||
|
||||
NestedScroll.displayName = "NestedScroll";
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
import { ThreadPrimitive } from "@assistant-ui/react";
|
||||
import { ArrowDownIcon } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
|
||||
export const ThreadScrollToBottom: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
<TooltipIconButton
|
||||
tooltip="Scroll to bottom"
|
||||
variant="outline"
|
||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||
>
|
||||
<ArrowDownIcon />
|
||||
</TooltipIconButton>
|
||||
</ThreadPrimitive.ScrollToBottom>
|
||||
);
|
||||
};
|
||||
|
|
@ -5,12 +5,10 @@ import {
|
|||
ThreadPrimitive,
|
||||
useAui,
|
||||
useAuiState,
|
||||
useThreadViewportStore,
|
||||
} from "@assistant-ui/react";
|
||||
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
||||
import {
|
||||
AlertCircle,
|
||||
ArrowDownIcon,
|
||||
ArrowUpIcon,
|
||||
Camera,
|
||||
ChevronDown,
|
||||
|
|
@ -55,6 +53,7 @@ import {
|
|||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
|
||||
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
|
||||
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
|
||||
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
|
||||
import {
|
||||
|
|
@ -112,10 +111,13 @@ const ThreadContent: FC = () => {
|
|||
["--thread-max-width" as string]: "44rem",
|
||||
}}
|
||||
>
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
||||
style={{ scrollbarGutter: "stable" }}
|
||||
<ChatViewport
|
||||
footer={
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<PremiumQuotaPinnedAlert />
|
||||
<Composer />
|
||||
</AuiIf>
|
||||
}
|
||||
>
|
||||
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
||||
<ThreadWelcome />
|
||||
|
|
@ -128,24 +130,7 @@ const ThreadContent: FC = () => {
|
|||
AssistantMessage,
|
||||
}}
|
||||
/>
|
||||
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<div className="grow" />
|
||||
</AuiIf>
|
||||
|
||||
<ThreadPrimitive.ViewportFooter
|
||||
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
|
||||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
||||
>
|
||||
<ThreadScrollToBottom />
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<PremiumQuotaPinnedAlert />
|
||||
</AuiIf>
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<Composer />
|
||||
</AuiIf>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
</ThreadPrimitive.Viewport>
|
||||
</ChatViewport>
|
||||
</ThreadPrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
|
@ -181,20 +166,6 @@ const PremiumQuotaPinnedAlert: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const ThreadScrollToBottom: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
<TooltipIconButton
|
||||
tooltip="Scroll to bottom"
|
||||
variant="outline"
|
||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||
>
|
||||
<ArrowDownIcon />
|
||||
</TooltipIconButton>
|
||||
</ThreadPrimitive.ScrollToBottom>
|
||||
);
|
||||
};
|
||||
|
||||
const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => {
|
||||
const hour = new Date().getHours();
|
||||
|
||||
|
|
@ -411,23 +382,9 @@ const Composer: FC = () => {
|
|||
>(new Map());
|
||||
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
||||
const promptPickerRef = useRef<PromptPickerRef>(null);
|
||||
const viewportRef = useRef<Element | null>(null);
|
||||
const { search_space_id, chat_id } = useParams();
|
||||
const aui = useAui();
|
||||
const threadViewportStore = useThreadViewportStore();
|
||||
const hasAutoFocusedRef = useRef(false);
|
||||
const submitCleanupRef = useRef<(() => void) | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
submitCleanupRef.current?.();
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Store viewport element reference on mount
|
||||
useEffect(() => {
|
||||
viewportRef.current = document.querySelector(".aui-thread-viewport");
|
||||
}, []);
|
||||
|
||||
const electronAPI = useElectronAPI();
|
||||
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
|
||||
|
|
@ -626,7 +583,6 @@ const Composer: FC = () => {
|
|||
[showDocumentPopover, showPromptPicker]
|
||||
);
|
||||
|
||||
// Submit message (blocked during streaming, document picker open, or AI responding to another user)
|
||||
const handleSubmit = useCallback(() => {
|
||||
if (isThreadRunning || isBlockedByOtherUser) return;
|
||||
if (showDocumentPopover || showPromptPicker) return;
|
||||
|
|
@ -638,50 +594,9 @@ const Composer: FC = () => {
|
|||
setClipboardInitialText(undefined);
|
||||
}
|
||||
|
||||
const viewportEl = viewportRef.current;
|
||||
const heightBefore = viewportEl?.scrollHeight ?? 0;
|
||||
|
||||
aui.composer().send();
|
||||
editorRef.current?.clear();
|
||||
setMentionedDocuments([]);
|
||||
|
||||
// With turnAnchor="top", ViewportSlack adds min-height to the last
|
||||
// assistant message so that scrolling-to-bottom actually positions the
|
||||
// user message at the TOP of the viewport. That slack height is
|
||||
// calculated asynchronously (ResizeObserver → style → layout).
|
||||
// Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes.
|
||||
const scrollToBottom = () =>
|
||||
threadViewportStore.getState().scrollToBottom({ behavior: "instant" });
|
||||
|
||||
let lastHeight = heightBefore;
|
||||
let frames = 0;
|
||||
let cancelled = false;
|
||||
const POLL_FRAMES = 30;
|
||||
|
||||
const pollAndScroll = () => {
|
||||
if (cancelled) return;
|
||||
const el = viewportRef.current;
|
||||
if (el) {
|
||||
const h = el.scrollHeight;
|
||||
if (h !== lastHeight) {
|
||||
lastHeight = h;
|
||||
scrollToBottom();
|
||||
}
|
||||
}
|
||||
if (++frames < POLL_FRAMES) {
|
||||
requestAnimationFrame(pollAndScroll);
|
||||
}
|
||||
};
|
||||
requestAnimationFrame(pollAndScroll);
|
||||
|
||||
const t1 = setTimeout(scrollToBottom, 100);
|
||||
const t2 = setTimeout(scrollToBottom, 300);
|
||||
|
||||
submitCleanupRef.current = () => {
|
||||
cancelled = true;
|
||||
clearTimeout(t1);
|
||||
clearTimeout(t2);
|
||||
};
|
||||
}, [
|
||||
showDocumentPopover,
|
||||
showPromptPicker,
|
||||
|
|
@ -690,7 +605,6 @@ const Composer: FC = () => {
|
|||
clipboardInitialText,
|
||||
aui,
|
||||
setMentionedDocuments,
|
||||
threadViewportStore,
|
||||
]);
|
||||
|
||||
const handleDocumentRemove = useCallback(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import {
|
|||
isDoomLoopInterrupt,
|
||||
} from "@/components/tool-ui/doom-loop-approval";
|
||||
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
||||
import { NestedScroll } from "@/components/assistant-ui/nested-scroll";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
|
|
@ -475,7 +476,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
|||
{(argsText || isRunning) && (
|
||||
<div className="flex flex-col gap-1 min-w-0">
|
||||
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
|
||||
<div className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
||||
<NestedScroll className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
||||
{argsText ? (
|
||||
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||
{argsText}
|
||||
|
|
@ -489,7 +490,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
|||
Waiting for input…
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</NestedScroll>
|
||||
</div>
|
||||
)}
|
||||
{!isCancelled && result !== undefined && (
|
||||
|
|
@ -497,11 +498,11 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
|||
<Separator />
|
||||
<div className="flex flex-col gap-1 min-w-0">
|
||||
<p className="text-xs font-medium text-muted-foreground">Result</p>
|
||||
<div className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
||||
<NestedScroll className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
||||
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||
{typeof result === "string" ? result : serializedResult}
|
||||
</pre>
|
||||
</div>
|
||||
</NestedScroll>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
|
||||
import {
|
||||
ActionBarPrimitive,
|
||||
AuiIf,
|
||||
MessagePrimitive,
|
||||
useAuiState,
|
||||
useMessagePartText,
|
||||
} from "@assistant-ui/react";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
|
||||
import Image from "next/image";
|
||||
|
|
@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
|||
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||
import { parseMentionSegments } from "@/lib/chat/parse-mention-segments";
|
||||
|
||||
interface AuthorMetadata {
|
||||
displayName: string | null;
|
||||
|
|
@ -47,23 +55,40 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
|
|||
);
|
||||
};
|
||||
|
||||
export const UserMessage: FC = () => {
|
||||
const UserTextPart: FC = () => {
|
||||
const messageId = useAuiState(({ message }) => message?.id);
|
||||
const messageText = useAuiState(({ message }) =>
|
||||
(message?.content ?? [])
|
||||
.map((part) =>
|
||||
typeof part === "object" &&
|
||||
part !== null &&
|
||||
"type" in part &&
|
||||
(part as { type?: string }).type === "text" &&
|
||||
"text" in part
|
||||
? String((part as { text?: string }).text ?? "")
|
||||
: ""
|
||||
)
|
||||
.join("")
|
||||
);
|
||||
const part = useMessagePartText();
|
||||
const text = (part as { text?: string }).text ?? "";
|
||||
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
||||
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
|
||||
const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? [];
|
||||
|
||||
const segments = parseMentionSegments(text, mentionedDocs);
|
||||
|
||||
return (
|
||||
<p style={{ whiteSpace: "pre-line" }} className="break-words">
|
||||
{segments.map((segment) =>
|
||||
segment.type === "text" ? (
|
||||
<span key={`txt-${segment.start}`}>{segment.value}</span>
|
||||
) : (
|
||||
<span
|
||||
key={`mention-${getMentionDocKey(segment.doc)}-${segment.start}`}
|
||||
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-middle leading-none"
|
||||
title={segment.doc.title}
|
||||
>
|
||||
<span className="flex items-center text-muted-foreground">
|
||||
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
|
||||
</span>
|
||||
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
|
||||
</span>
|
||||
)
|
||||
)}
|
||||
</p>
|
||||
);
|
||||
};
|
||||
|
||||
const userMessageParts = { Text: UserTextPart };
|
||||
|
||||
export const UserMessage: FC = () => {
|
||||
const metadata = useAuiState(({ message }) => message?.metadata);
|
||||
const author = metadata?.custom?.author as AuthorMetadata | undefined;
|
||||
const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE";
|
||||
|
|
@ -78,11 +103,7 @@ export const UserMessage: FC = () => {
|
|||
<div className="aui-user-message-content-wrapper flex items-end gap-2">
|
||||
<div className="relative flex-1 min-w-0">
|
||||
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
||||
{mentionedDocs && mentionedDocs.length > 0 ? (
|
||||
<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
|
||||
) : (
|
||||
<MessagePrimitive.Parts />
|
||||
)}
|
||||
<MessagePrimitive.Parts components={userMessageParts} />
|
||||
</div>
|
||||
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
|
||||
<UserActionBar />
|
||||
|
|
@ -99,64 +120,6 @@ export const UserMessage: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const UserMessageWithMentionChips: FC<{
|
||||
text: string;
|
||||
mentionedDocs: { id: number; title: string; document_type: string }[];
|
||||
}> = ({ text, mentionedDocs }) => {
|
||||
type Segment =
|
||||
| { type: "text"; value: string; start: number }
|
||||
| { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number };
|
||||
|
||||
const tokens = mentionedDocs
|
||||
.map((doc) => ({ doc, token: `@${doc.title}` }))
|
||||
.sort((a, b) => b.token.length - a.token.length);
|
||||
|
||||
const segments: Segment[] = [];
|
||||
let i = 0;
|
||||
let buffer = "";
|
||||
let bufferStart = 0;
|
||||
while (i < text.length) {
|
||||
const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i));
|
||||
if (tokenMatch) {
|
||||
if (buffer) {
|
||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
||||
buffer = "";
|
||||
}
|
||||
segments.push({ type: "mention", doc: tokenMatch.doc, start: i });
|
||||
i += tokenMatch.token.length;
|
||||
bufferStart = i;
|
||||
continue;
|
||||
}
|
||||
if (!buffer) bufferStart = i;
|
||||
buffer += text[i];
|
||||
i += 1;
|
||||
}
|
||||
if (buffer) {
|
||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
||||
}
|
||||
|
||||
return (
|
||||
<span className="whitespace-pre-wrap break-words">
|
||||
{segments.map((segment) =>
|
||||
segment.type === "text" ? (
|
||||
<span key={`txt-${segment.start}`}>{segment.value}</span>
|
||||
) : (
|
||||
<span
|
||||
key={`mention-${segment.doc.document_type}:${segment.doc.id}-${segment.start}`}
|
||||
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline"
|
||||
title={segment.doc.title}
|
||||
>
|
||||
<span className="flex items-center text-muted-foreground">
|
||||
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
|
||||
</span>
|
||||
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
|
||||
</span>
|
||||
)
|
||||
)}
|
||||
</span>
|
||||
);
|
||||
};
|
||||
|
||||
const UserActionBar: FC = () => {
|
||||
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
"use client";
|
||||
|
||||
import { AuiIf, ThreadPrimitive } from "@assistant-ui/react";
|
||||
import { ArrowDownIcon } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
|
||||
import { EditComposer } from "@/components/assistant-ui/edit-composer";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
import { UserMessage } from "@/components/assistant-ui/user-message";
|
||||
import { FreeComposer } from "./free-composer";
|
||||
|
||||
|
|
@ -24,20 +23,6 @@ const FreeThreadWelcome: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const ThreadScrollToBottom: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.ScrollToBottom asChild>
|
||||
<TooltipIconButton
|
||||
tooltip="Scroll to bottom"
|
||||
variant="outline"
|
||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||
>
|
||||
<ArrowDownIcon />
|
||||
</TooltipIconButton>
|
||||
</ThreadPrimitive.ScrollToBottom>
|
||||
);
|
||||
};
|
||||
|
||||
export const FreeThread: FC = () => {
|
||||
return (
|
||||
<ThreadPrimitive.Root
|
||||
|
|
@ -46,10 +31,12 @@ export const FreeThread: FC = () => {
|
|||
["--thread-max-width" as string]: "44rem",
|
||||
}}
|
||||
>
|
||||
<ThreadPrimitive.Viewport
|
||||
turnAnchor="top"
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
||||
style={{ scrollbarGutter: "stable" }}
|
||||
<ChatViewport
|
||||
footer={
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<FreeComposer />
|
||||
</AuiIf>
|
||||
}
|
||||
>
|
||||
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
||||
<FreeThreadWelcome />
|
||||
|
|
@ -62,21 +49,7 @@ export const FreeThread: FC = () => {
|
|||
AssistantMessage,
|
||||
}}
|
||||
/>
|
||||
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<div className="grow" />
|
||||
</AuiIf>
|
||||
|
||||
<ThreadPrimitive.ViewportFooter
|
||||
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6"
|
||||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
||||
>
|
||||
<ThreadScrollToBottom />
|
||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||
<FreeComposer />
|
||||
</AuiIf>
|
||||
</ThreadPrimitive.ViewportFooter>
|
||||
</ThreadPrimitive.Viewport>
|
||||
</ChatViewport>
|
||||
</ThreadPrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -236,6 +236,93 @@ interface DisplayItem {
|
|||
isAutoMode: boolean;
|
||||
}
|
||||
|
||||
const TruncatedNameWithTooltip: React.FC<{
|
||||
text: string;
|
||||
className?: string;
|
||||
enableTooltip: boolean;
|
||||
}> = ({ text, className, enableTooltip }) => {
|
||||
const textRef = useRef<HTMLSpanElement>(null);
|
||||
const openTimerRef = useRef<number | undefined>(undefined);
|
||||
const [isTruncated, setIsTruncated] = useState(false);
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
const recalcTruncation = useCallback(() => {
|
||||
const el = textRef.current;
|
||||
if (!el) return;
|
||||
setIsTruncated(el.scrollWidth > el.clientWidth + 1);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!enableTooltip) return;
|
||||
const el = textRef.current;
|
||||
if (!el) return;
|
||||
|
||||
const raf = requestAnimationFrame(recalcTruncation);
|
||||
recalcTruncation();
|
||||
|
||||
const observer = new ResizeObserver(recalcTruncation);
|
||||
observer.observe(el);
|
||||
if (el.parentElement) observer.observe(el.parentElement);
|
||||
window.addEventListener("resize", recalcTruncation);
|
||||
|
||||
return () => {
|
||||
cancelAnimationFrame(raf);
|
||||
observer.disconnect();
|
||||
window.removeEventListener("resize", recalcTruncation);
|
||||
};
|
||||
}, [enableTooltip, recalcTruncation]);
|
||||
|
||||
useEffect(() => {
|
||||
// Recompute when row text changes.
|
||||
void text;
|
||||
requestAnimationFrame(recalcTruncation);
|
||||
}, [text, recalcTruncation]);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
if (openTimerRef.current) window.clearTimeout(openTimerRef.current);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
if (!enableTooltip) {
|
||||
return (
|
||||
<span ref={textRef} className={cn("block max-w-full", className)}>
|
||||
{text}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
const handleOpenChange = (nextOpen: boolean) => {
|
||||
if (openTimerRef.current) {
|
||||
window.clearTimeout(openTimerRef.current);
|
||||
openTimerRef.current = undefined;
|
||||
}
|
||||
if (!nextOpen) {
|
||||
setOpen(false);
|
||||
return;
|
||||
}
|
||||
if (!isTruncated) return;
|
||||
openTimerRef.current = window.setTimeout(() => {
|
||||
setOpen(true);
|
||||
openTimerRef.current = undefined;
|
||||
}, 220);
|
||||
};
|
||||
|
||||
return (
|
||||
<Tooltip open={open} onOpenChange={handleOpenChange}>
|
||||
<TooltipTrigger asChild>
|
||||
<span ref={textRef} className={cn("block max-w-full", className)}>
|
||||
{text}
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" align="start">
|
||||
{text}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
// ─── Component ──────────────────────────────────────────────────────
|
||||
|
||||
interface ModelSelectorProps {
|
||||
|
|
@ -936,7 +1023,11 @@ export function ModelSelector({
|
|||
{/* Model info */}
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<span className="font-medium text-sm truncate">{config.name}</span>
|
||||
<TruncatedNameWithTooltip
|
||||
text={config.name}
|
||||
enableTooltip={!isMobile}
|
||||
className="font-medium text-sm truncate"
|
||||
/>
|
||||
{isAutoMode && (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
|
|
|
|||
|
|
@ -45,20 +45,21 @@ export const PublicThread: FC<PublicThreadProps> = ({ footer }) => {
|
|||
["--thread-max-width" as string]: "44rem",
|
||||
}}
|
||||
>
|
||||
<ThreadPrimitive.Viewport className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4">
|
||||
<ThreadPrimitive.Viewport
|
||||
scrollToBottomOnInitialize
|
||||
scrollToBottomOnThreadSwitch
|
||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4 pb-6"
|
||||
>
|
||||
<ThreadPrimitive.Messages
|
||||
components={{
|
||||
UserMessage: PublicUserMessage,
|
||||
AssistantMessage: PublicAssistantMessage,
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* Spacer to ensure footer doesn't overlap last message */}
|
||||
<div className="h-24" />
|
||||
</ThreadPrimitive.Viewport>
|
||||
|
||||
{footer && (
|
||||
<div className="sticky bottom-0 z-20 border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60">
|
||||
<div className="border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60">
|
||||
{footer}
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
|
|||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "TURN_CANCELLING"
|
||||
) {
|
||||
return {
|
||||
kind: "thread_busy",
|
||||
channel: "toast",
|
||||
severity: "info",
|
||||
telemetryEvent: "chat_blocked",
|
||||
isExpected: true,
|
||||
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "TURN_CANCELLING",
|
||||
details: { flow: input.flow },
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
errorCode === "THREAD_BUSY"
|
||||
) {
|
||||
|
|
@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError
|
|||
severity: "warn",
|
||||
telemetryEvent: "chat_blocked",
|
||||
isExpected: true,
|
||||
userMessage: "A previous response is still stopping. Please try again in a moment.",
|
||||
userMessage: "Another response is still finishing for this thread. Please try again in a moment.",
|
||||
rawMessage,
|
||||
errorCode: errorCode ?? "THREAD_BUSY",
|
||||
details: { flow: input.flow },
|
||||
|
|
|
|||
114
surfsense_web/lib/chat/chat-request-errors.ts
Normal file
114
surfsense_web/lib/chat/chat-request-errors.ts
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
export async function toHttpResponseError(
|
||||
response: Response
|
||||
): Promise<Error & { errorCode?: string; retryAfterMs?: number }> {
|
||||
const statusDefaultCode =
|
||||
response.status === 409
|
||||
? "THREAD_BUSY"
|
||||
: response.status === 429
|
||||
? "RATE_LIMITED"
|
||||
: response.status === 401 || response.status === 403
|
||||
? "AUTH_EXPIRED"
|
||||
: "SERVER_ERROR";
|
||||
|
||||
let rawBody = "";
|
||||
try {
|
||||
rawBody = await response.text();
|
||||
} catch {
|
||||
// noop
|
||||
}
|
||||
|
||||
let parsedBody: Record<string, unknown> | null = null;
|
||||
if (rawBody) {
|
||||
try {
|
||||
const parsed = JSON.parse(rawBody);
|
||||
if (typeof parsed === "object" && parsed !== null) {
|
||||
parsedBody = parsed as Record<string, unknown>;
|
||||
}
|
||||
} catch {
|
||||
// noop
|
||||
}
|
||||
}
|
||||
|
||||
const detail = parsedBody?.detail;
|
||||
const detailObject =
|
||||
typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null;
|
||||
const detailMessage = typeof detail === "string" ? detail : undefined;
|
||||
const topLevelMessage =
|
||||
typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined;
|
||||
const detailNestedMessage =
|
||||
typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined;
|
||||
|
||||
const topLevelCode =
|
||||
typeof parsedBody?.errorCode === "string"
|
||||
? parsedBody.errorCode
|
||||
: typeof parsedBody?.error_code === "string"
|
||||
? parsedBody.error_code
|
||||
: undefined;
|
||||
const detailCode =
|
||||
typeof detailObject?.errorCode === "string"
|
||||
? detailObject.errorCode
|
||||
: typeof detailObject?.error_code === "string"
|
||||
? detailObject.error_code
|
||||
: undefined;
|
||||
|
||||
const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode;
|
||||
|
||||
const detailRetryAfterMs =
|
||||
typeof detailObject?.retry_after_ms === "number"
|
||||
? detailObject.retry_after_ms
|
||||
: typeof detailObject?.retryAfterMs === "number"
|
||||
? detailObject.retryAfterMs
|
||||
: undefined;
|
||||
const topRetryAfterMs =
|
||||
typeof parsedBody?.retry_after_ms === "number"
|
||||
? parsedBody.retry_after_ms
|
||||
: typeof parsedBody?.retryAfterMs === "number"
|
||||
? parsedBody.retryAfterMs
|
||||
: undefined;
|
||||
const headerRetryAfterMsRaw = response.headers.get("retry-after-ms");
|
||||
const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN;
|
||||
const retryAfterHeader = response.headers.get("retry-after");
|
||||
const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN;
|
||||
const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs)
|
||||
? Math.max(0, Math.round(headerRetryAfterMs))
|
||||
: Number.isFinite(retryAfterSeconds)
|
||||
? Math.max(0, Math.round(retryAfterSeconds * 1000))
|
||||
: undefined;
|
||||
const retryAfterMs =
|
||||
detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined;
|
||||
const message =
|
||||
detailNestedMessage ??
|
||||
detailMessage ??
|
||||
topLevelMessage ??
|
||||
`Backend error: ${response.status}`;
|
||||
|
||||
return Object.assign(new Error(message), { errorCode, retryAfterMs });
|
||||
}
|
||||
|
||||
export function tagPreAcceptSendFailure(error: unknown): unknown {
|
||||
if (error instanceof Error) {
|
||||
const withCode = error as Error & { errorCode?: string; code?: string };
|
||||
const existingCode = withCode.errorCode ?? withCode.code;
|
||||
const passthroughCodes = new Set([
|
||||
"PREMIUM_QUOTA_EXHAUSTED",
|
||||
"THREAD_BUSY",
|
||||
"TURN_CANCELLING",
|
||||
"AUTH_EXPIRED",
|
||||
"UNAUTHORIZED",
|
||||
"RATE_LIMITED",
|
||||
"NETWORK_ERROR",
|
||||
"STREAM_PARSE_ERROR",
|
||||
"TOOL_EXECUTION_ERROR",
|
||||
"PERSIST_MESSAGE_FAILED",
|
||||
"SERVER_ERROR",
|
||||
]);
|
||||
if (existingCode && passthroughCodes.has(existingCode)) {
|
||||
return Object.assign(error, { errorCode: existingCode });
|
||||
}
|
||||
return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" });
|
||||
}
|
||||
|
||||
return Object.assign(new Error("Failed to send message before stream acceptance"), {
|
||||
errorCode: "SEND_FAILED_PRE_ACCEPT",
|
||||
});
|
||||
}
|
||||
54
surfsense_web/lib/chat/parse-mention-segments.ts
Normal file
54
surfsense_web/lib/chat/parse-mention-segments.ts
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import type { MentionedDocumentInfo } from "@/atoms/chat/mentioned-documents.atom";
|
||||
|
||||
export type MentionSegment =
|
||||
| { type: "text"; value: string; start: number }
|
||||
| { type: "mention"; doc: MentionedDocumentInfo; start: number };
|
||||
|
||||
/**
|
||||
* Tokenizes a user message into text and `@mention` segments.
|
||||
*
|
||||
* Pure: no React, no DOM, no side effects. Safe to unit-test and reuse.
|
||||
*
|
||||
* Mentions are matched greedily by longest title first so that a longer title
|
||||
* (e.g. `@Project Roadmap`) is never shadowed by a shorter prefix
|
||||
* (e.g. `@Project`).
|
||||
*/
|
||||
export function parseMentionSegments(
|
||||
text: string,
|
||||
docs: ReadonlyArray<MentionedDocumentInfo>
|
||||
): MentionSegment[] {
|
||||
if (text.length === 0) return [];
|
||||
if (docs.length === 0) return [{ type: "text", value: text, start: 0 }];
|
||||
|
||||
const tokens = docs
|
||||
.map((doc) => ({ doc, token: `@${doc.title}` }))
|
||||
.sort((a, b) => b.token.length - a.token.length);
|
||||
|
||||
const segments: MentionSegment[] = [];
|
||||
let i = 0;
|
||||
let buffer = "";
|
||||
let bufferStart = 0;
|
||||
|
||||
while (i < text.length) {
|
||||
const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i));
|
||||
if (tokenMatch) {
|
||||
if (buffer) {
|
||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
||||
buffer = "";
|
||||
}
|
||||
segments.push({ type: "mention", doc: tokenMatch.doc, start: i });
|
||||
i += tokenMatch.token.length;
|
||||
bufferStart = i;
|
||||
continue;
|
||||
}
|
||||
if (!buffer) bufferStart = i;
|
||||
buffer += text[i];
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if (buffer) {
|
||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
||||
}
|
||||
|
||||
return segments;
|
||||
}
|
||||
19
surfsense_web/lib/chat/stream-flush.ts
Normal file
19
surfsense_web/lib/chat/stream-flush.ts
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import { FrameBatchedUpdater } from "@/lib/chat/streaming-state";
|
||||
|
||||
export function createStreamFlushHelpers(flushMessages: () => void): {
|
||||
batcher: FrameBatchedUpdater;
|
||||
scheduleFlush: () => void;
|
||||
forceFlush: () => void;
|
||||
} {
|
||||
const batcher = new FrameBatchedUpdater();
|
||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||
// Force-flush helper: ``batcher.flush()`` is a no-op when
|
||||
// ``dirty=false`` (e.g. a tool starts before any text streamed).
|
||||
// ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so
|
||||
// terminal events render promptly without the throttle delay.
|
||||
const forceFlush = () => {
|
||||
scheduleFlush();
|
||||
batcher.flush();
|
||||
};
|
||||
return { batcher, scheduleFlush, forceFlush };
|
||||
}
|
||||
196
surfsense_web/lib/chat/stream-pipeline.ts
Normal file
196
surfsense_web/lib/chat/stream-pipeline.ts
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
import {
|
||||
addStepSeparator,
|
||||
addToolCall,
|
||||
appendReasoning,
|
||||
appendText,
|
||||
appendToolInputDelta,
|
||||
type ContentPartsState,
|
||||
endReasoning,
|
||||
readSSEStream,
|
||||
type SSEEvent,
|
||||
type ThinkingStepData,
|
||||
type ToolUIGate,
|
||||
updateThinkingSteps,
|
||||
updateToolCall,
|
||||
} from "@/lib/chat/streaming-state";
|
||||
|
||||
/**
 * Dependencies shared by every streaming flow that feeds SSE events into the
 * content-parts state. The optional callbacks let each flow opt in to the
 * event types it cares about.
 */
export type SharedStreamEventContext = {
	// Mutable accumulator for the assistant message being streamed.
	contentPartsState: ContentPartsState;
	// Gate deciding which tool calls count for UI ("all" or a Set of tool names).
	toolsWithUI: ToolUIGate;
	// Thinking steps keyed by step id; updated in place as step events arrive.
	currentThinkingSteps: Map<string, ThinkingStepData>;
	// Throttled flush for high-frequency delta events.
	scheduleFlush: () => void;
	// Immediate flush for structural/terminal events.
	forceFlush: () => void;
	onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void;
	onTurnStatus?: (data: Extract<SSEEvent, { type: "data-turn-status" }>["data"]) => void;
	// Invoked after a tool's output has been applied to the content parts.
	onToolOutputAvailable?: (
		event: Extract<SSEEvent, { type: "tool-output-available" }>,
		context: {
			contentPartsState: ContentPartsState;
			toolCallIndices: Map<string, number>;
		}
	) => void;
};
|
||||
|
||||
/**
|
||||
* After a tool produces output, mark any previously-decided interrupt tool
|
||||
* calls as completed so the ApprovalCard can transition from shimmer to done.
|
||||
*/
|
||||
export function markInterruptsCompleted(
|
||||
contentParts: Array<{ type: string; result?: unknown }>
|
||||
): void {
|
||||
for (const part of contentParts) {
|
||||
if (
|
||||
part.type === "tool-call" &&
|
||||
typeof part.result === "object" &&
|
||||
part.result !== null &&
|
||||
(part.result as Record<string, unknown>).__interrupt__ === true &&
|
||||
(part.result as Record<string, unknown>).__decided__ &&
|
||||
!(part.result as Record<string, unknown>).__completed__
|
||||
) {
|
||||
part.result = { ...(part.result as Record<string, unknown>), __completed__: true };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function hasPersistableContent(
|
||||
contentParts: ContentPartsState["contentParts"],
|
||||
toolsWithUI: ToolUIGate
|
||||
) {
|
||||
return contentParts.some(
|
||||
(part) =>
|
||||
(part.type === "text" && part.text.length > 0) ||
|
||||
(part.type === "reasoning" && part.text.length > 0) ||
|
||||
(part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName)))
|
||||
);
|
||||
}
|
||||
|
||||
function toStreamTerminalError(
|
||||
event: Extract<SSEEvent, { type: "error" }>
|
||||
): Error & { errorCode?: string } {
|
||||
return Object.assign(new Error(event.errorText || "Server error"), {
|
||||
errorCode: event.errorCode,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Applies one parsed SSE event to the shared streaming state.
 *
 * Flush cadence is deliberate per event type: high-frequency deltas use the
 * throttled `scheduleFlush` so bursts coalesce, while structural events
 * (tool start/availability, tool output) use `forceFlush` so they render
 * immediately.
 *
 * @returns true when the event type was recognized and handled here; false
 *          so the caller can run flow-specific handling for other types.
 * @throws the stream's terminal error when an `error` event arrives.
 */
export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean {
	const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context;
	const { contentParts, toolCallIndices } = contentPartsState;

	switch (parsed.type) {
		case "text-delta":
			appendText(contentPartsState, parsed.delta);
			scheduleFlush();
			return true;

		case "reasoning-delta":
			appendReasoning(contentPartsState, parsed.delta);
			scheduleFlush();
			return true;

		case "reasoning-end":
			endReasoning(contentPartsState);
			scheduleFlush();
			return true;

		case "start-step":
			addStepSeparator(contentPartsState);
			scheduleFlush();
			return true;

		// Recognized but intentionally a no-op: nothing to render at step end.
		case "finish-step":
			return true;

		case "tool-input-start":
			// Args start empty; deltas / tool-input-available fill them in.
			addToolCall(
				contentPartsState,
				toolsWithUI,
				parsed.toolCallId,
				parsed.toolName,
				{},
				false,
				parsed.langchainToolCallId
			);
			forceFlush();
			return true;

		case "tool-input-delta":
			// High-frequency event: deltas can fire dozens of times per call,
			// so use throttled scheduleFlush (NOT forceFlush) to coalesce.
			appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
			scheduleFlush();
			return true;

		case "tool-input-available": {
			const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
			if (toolCallIndices.has(parsed.toolCallId)) {
				// Known call: overwrite streamed partial args with the final input.
				updateToolCall(contentPartsState, parsed.toolCallId, {
					args: parsed.input || {},
					argsText: finalArgsText,
					langchainToolCallId: parsed.langchainToolCallId,
				});
			} else {
				// Unknown call (no tool-input-start seen): upsert it now.
				addToolCall(
					contentPartsState,
					toolsWithUI,
					parsed.toolCallId,
					parsed.toolName,
					parsed.input || {},
					false,
					parsed.langchainToolCallId
				);
				// addToolCall doesn't accept argsText today; backfill via
				// updateToolCall so the new card renders pretty-printed JSON.
				updateToolCall(contentPartsState, parsed.toolCallId, {
					argsText: finalArgsText,
				});
			}
			forceFlush();
			return true;
		}

		case "tool-output-available":
			updateToolCall(contentPartsState, parsed.toolCallId, {
				result: parsed.output,
				langchainToolCallId: parsed.langchainToolCallId,
			});
			// Any decided interrupt cards can now leave their pending state.
			markInterruptsCompleted(contentParts);
			context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices });
			forceFlush();
			return true;

		case "data-thinking-step": {
			const stepData = parsed.data as ThinkingStepData;
			if (stepData?.id) {
				currentThinkingSteps.set(stepData.id, stepData);
				// Only flush when the steps part actually changed.
				const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps);
				if (didUpdate) {
					scheduleFlush();
				}
			}
			return true;
		}

		case "data-token-usage":
			context.onTokenUsage?.(parsed.data);
			return true;

		case "data-turn-status":
			context.onTurnStatus?.(parsed.data);
			return true;

		case "error":
			throw toStreamTerminalError(parsed);

		default:
			return false;
	}
}
|
||||
|
||||
export async function consumeSseEvents(
|
||||
response: Response,
|
||||
onEvent: (event: SSEEvent) => void | Promise<void>
|
||||
): Promise<void> {
|
||||
for await (const parsed of readSSEStream(response)) {
|
||||
await onEvent(parsed);
|
||||
}
|
||||
}
|
||||
127
surfsense_web/lib/chat/stream-side-effects.ts
Normal file
127
surfsense_web/lib/chat/stream-side-effects.ts
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import type { ThreadMessageLike } from "@assistant-ui/react";
|
||||
import {
|
||||
addToolCall,
|
||||
type ContentPartsState,
|
||||
type ToolUIGate,
|
||||
updateToolCall,
|
||||
} from "@/lib/chat/streaming-state";
|
||||
|
||||
type InterruptActionRequest = {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type EditedInterruptAction = {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
|
||||
function readInterruptActions(
|
||||
interruptData: Record<string, unknown>
|
||||
): InterruptActionRequest[] {
|
||||
return (interruptData.action_requests ?? []) as InterruptActionRequest[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies an interrupt request payload to tool-call parts. Existing tool cards
|
||||
* are updated in-place; missing ones are upserted so approval UI always shows.
|
||||
*/
|
||||
export function applyInterruptRequestToContentParts(
|
||||
contentPartsState: ContentPartsState,
|
||||
toolsWithUI: ToolUIGate,
|
||||
interruptData: Record<string, unknown>
|
||||
): void {
|
||||
const { contentParts, toolCallIndices } = contentPartsState;
|
||||
const actionRequests = readInterruptActions(interruptData);
|
||||
for (const action of actionRequests) {
|
||||
const existingEntry = Array.from(toolCallIndices.entries()).find(([, idx]) => {
|
||||
const part = contentParts[idx];
|
||||
return part?.type === "tool-call" && part.toolName === action.name;
|
||||
});
|
||||
|
||||
if (existingEntry) {
|
||||
updateToolCall(contentPartsState, existingEntry[0], {
|
||||
result: { __interrupt__: true, ...interruptData },
|
||||
});
|
||||
} else {
|
||||
const toolCallId = `interrupt-${action.name}`;
|
||||
addToolCall(contentPartsState, toolsWithUI, toolCallId, action.name, action.args, true);
|
||||
updateToolCall(contentPartsState, toolCallId, {
|
||||
result: { __interrupt__: true, ...interruptData },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function mergeEditedInterruptAction(
|
||||
contentParts: ContentPartsState["contentParts"],
|
||||
editedAction: EditedInterruptAction | undefined
|
||||
): void {
|
||||
if (!editedAction) return;
|
||||
for (const part of contentParts) {
|
||||
if (part.type === "tool-call" && part.toolName === editedAction.name) {
|
||||
const mergedArgs = { ...part.args, ...editedAction.args };
|
||||
part.args = mergedArgs;
|
||||
// assistant-ui prefers argsText over JSON.stringify(args)
|
||||
part.argsText = JSON.stringify(mergedArgs, null, 2);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function markInterruptDecisionOnContentParts(
|
||||
contentParts: ContentPartsState["contentParts"],
|
||||
decisionType: "approve" | "reject" | undefined
|
||||
): void {
|
||||
if (!decisionType) return;
|
||||
for (const part of contentParts) {
|
||||
if (
|
||||
part.type === "tool-call" &&
|
||||
typeof part.result === "object" &&
|
||||
part.result !== null &&
|
||||
"__interrupt__" in (part.result as Record<string, unknown>)
|
||||
) {
|
||||
part.result = {
|
||||
...(part.result as Record<string, unknown>),
|
||||
__decided__: decisionType,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When a streamed message is persisted, the backend returns the durable
|
||||
* turn_id; merge it into assistant-ui metadata for turn-scoped actions.
|
||||
*/
|
||||
export function mergeChatTurnIdIntoMessage(
|
||||
msg: ThreadMessageLike,
|
||||
turnId: string | null | undefined
|
||||
): ThreadMessageLike {
|
||||
if (!turnId) return msg;
|
||||
const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> };
|
||||
const existingCustom = existingMeta.custom ?? {};
|
||||
if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg;
|
||||
return {
|
||||
...msg,
|
||||
metadata: {
|
||||
...existingMeta,
|
||||
custom: { ...existingCustom, chatTurnId: turnId },
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function readStreamedChatTurnId(data: unknown): string | null {
|
||||
if (typeof data !== "object" || data === null) return null;
|
||||
const value = (data as { chat_turn_id?: unknown }).chat_turn_id;
|
||||
return typeof value === "string" && value.length > 0 ? value : null;
|
||||
}
|
||||
|
||||
export function applyTurnIdToAssistantMessageList(
|
||||
messages: ThreadMessageLike[],
|
||||
assistantMsgId: string,
|
||||
turnId: string
|
||||
): ThreadMessageLike[] {
|
||||
return messages.map((m) =>
|
||||
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m
|
||||
);
|
||||
}
|
||||
|
|
@ -528,6 +528,14 @@ export type SSEEvent =
|
|||
}>;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: "data-turn-status";
|
||||
data: {
|
||||
status: "idle" | "busy" | "cancelling";
|
||||
retry_after_ms?: number;
|
||||
retry_after_at?: number;
|
||||
};
|
||||
}
|
||||
| {
|
||||
type: "data-token-usage";
|
||||
data: {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue