From 345cb88224803a5fb0b18340882409a046e02ff9 Mon Sep 17 00:00:00 2001 From: guangyang1206 Date: Wed, 29 Apr 2026 12:14:08 +0800 Subject: [PATCH 01/68] refactor(settings): use key prop to reset LLM role manager form state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #1018 Remove the sync useEffect that copied preferences into local state, along with the savingRef guard that prevented mid-save overwrites. Instead, pass key={searchSpaceId} on the LLMRoleManager component so React remounts the form with correct initial state whenever the search space changes — no extra re-render, no effect dependency array. Changes: - llm-role-manager.tsx: remove useEffect + useRef + savingRef pattern; drop useEffect and useRef from imports (now only useCallback, useState) - search-space-settings-dialog.tsx: add key={searchSpaceId} to <LLMRoleManager> so the component remounts on search-space change Before: useEffect synced preferences → assignments on each preference update, with savingRef to avoid overwriting an in-flight save. After: React remounts the component with correct initial state from the preferences selector; no mid-save race possible. 
--- .../components/settings/llm-role-manager.tsx | 21 +------------------ .../settings/search-space-settings-dialog.tsx | 2 +- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index 015027111..e21dc9028 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -11,7 +11,7 @@ import { RefreshCw, ScanEye, } from "lucide-react"; -import { useCallback, useEffect, useRef, useState } from "react"; +import { useCallback, useState } from "react"; import { toast } from "sonner"; import { globalImageGenConfigsAtom, @@ -143,23 +143,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { })); const [savingRole, setSavingRole] = useState(null); - const savingRef = useRef(false); - - useEffect(() => { - if (!savingRef.current) { - setAssignments({ - agent_llm_id: preferences.agent_llm_id ?? "", - document_summary_llm_id: preferences.document_summary_llm_id ?? "", - image_generation_config_id: preferences.image_generation_config_id ?? "", - vision_llm_config_id: preferences.vision_llm_config_id ?? 
"", - }); - } - }, [ - preferences?.agent_llm_id, - preferences?.document_summary_llm_id, - preferences?.image_generation_config_id, - preferences?.vision_llm_config_id, - ]); const handleRoleAssignment = useCallback( async (prefKey: string, configId: string) => { @@ -167,7 +150,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { setAssignments((prev) => ({ ...prev, [prefKey]: value })); setSavingRole(prefKey); - savingRef.current = true; try { await updatePreferences({ @@ -177,7 +159,6 @@ export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { toast.success("Role assignment updated"); } finally { setSavingRole(null); - savingRef.current = false; } }, [updatePreferences, searchSpaceId] diff --git a/surfsense_web/components/settings/search-space-settings-dialog.tsx b/surfsense_web/components/settings/search-space-settings-dialog.tsx index aefe1efd2..2a7ba82b6 100644 --- a/surfsense_web/components/settings/search-space-settings-dialog.tsx +++ b/surfsense_web/components/settings/search-space-settings-dialog.tsx @@ -116,7 +116,7 @@ export function SearchSpaceSettingsDialog({ searchSpaceId }: SearchSpaceSettings const content: Record = { general: , models: , - roles: , + roles: , "image-models": , "vision-models": , "team-roles": , From 57db198919bbd1e7da8d8364aa90eba01525e7d0 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:14:56 +0530 Subject: [PATCH 02/68] feat(chat): add thread-level auto model pinning fields --- ...34_add_thread_auto_model_pinning_fields.py | 63 +++++++++++++++++++ surfsense_backend/app/db.py | 7 +++ 2 files changed, 70 insertions(+) create mode 100644 surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py diff --git a/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py new file mode 100644 index 
000000000..ab1643b02 --- /dev/null +++ b/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py @@ -0,0 +1,63 @@ +"""134_add_thread_auto_model_pinning_fields + +Revision ID: 134 +Revises: 133 +Create Date: 2026-04-29 + +Add thread-level fields to persist Auto (Fastest) model pinning metadata: +- pinned_llm_config_id: concrete resolved config id used for this thread +- pinned_auto_mode: auto policy identifier (currently "auto_fastest") +- pinned_at: timestamp when the pin was created/refreshed +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "134" +down_revision: str | None = "133" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "new_chat_threads", + sa.Column("pinned_llm_config_id", sa.Integer(), nullable=True), + ) + op.add_column( + "new_chat_threads", + sa.Column("pinned_auto_mode", sa.String(length=32), nullable=True), + ) + op.add_column( + "new_chat_threads", + sa.Column("pinned_at", sa.TIMESTAMP(timezone=True), nullable=True), + ) + + op.create_index( + "ix_new_chat_threads_pinned_llm_config_id", + "new_chat_threads", + ["pinned_llm_config_id"], + unique=False, + ) + op.create_index( + "ix_new_chat_threads_pinned_auto_mode", + "new_chat_threads", + ["pinned_auto_mode"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index("ix_new_chat_threads_pinned_auto_mode", table_name="new_chat_threads") + op.drop_index( + "ix_new_chat_threads_pinned_llm_config_id", table_name="new_chat_threads" + ) + + op.drop_column("new_chat_threads", "pinned_at") + op.drop_column("new_chat_threads", "pinned_auto_mode") + op.drop_column("new_chat_threads", "pinned_llm_config_id") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 75342a8e1..f8b1390d9 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py 
@@ -638,6 +638,13 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) + # Auto model pinning metadata: + # - pinned_llm_config_id stores the concrete resolved model config id. + # - pinned_auto_mode indicates which auto policy produced the pin. + # This allows Auto (Fastest) to resolve once per thread and stay stable. + pinned_llm_config_id = Column(Integer, nullable=True, index=True) + pinned_auto_mode = Column(String(32), nullable=True, index=True) + pinned_at = Column(TIMESTAMP(timezone=True), nullable=True) # Relationships search_space = relationship("SearchSpace", back_populates="new_chat_threads") From 41849fe10f5fbe9a4792ad665308f5fea4c37721 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:15:15 +0530 Subject: [PATCH 03/68] feat(chat): add auto model pin resolution service --- .../app/services/auto_model_pin_service.py | 205 ++++++++++++ .../services/test_auto_model_pin_service.py | 291 ++++++++++++++++++ 2 files changed, 496 insertions(+) create mode 100644 surfsense_backend/app/services/auto_model_pin_service.py create mode 100644 surfsense_backend/tests/unit/services/test_auto_model_pin_service.py diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py new file mode 100644 index 000000000..ce417a26d --- /dev/null +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -0,0 +1,205 @@ +"""Resolve and persist Auto (Fastest) model pins per chat thread. + +Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we +resolve that virtual mode to one concrete global LLM config exactly once and +persist the chosen config id on ``new_chat_threads`` so subsequent turns are +stable. 
+""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import config +from app.db import NewChatThread +from app.services.token_quota_service import TokenQuotaService + +logger = logging.getLogger(__name__) + +AUTO_FASTEST_ID = 0 +AUTO_FASTEST_MODE = "auto_fastest" + + +@dataclass +class AutoPinResolution: + resolved_llm_config_id: int + resolved_tier: str + from_existing_pin: bool + + +def _is_usable_global_config(cfg: dict) -> bool: + return bool( + cfg.get("id") is not None + and cfg.get("model_name") + and cfg.get("provider") + and cfg.get("api_key") + ) + + +def _global_candidates() -> list[dict]: + candidates = [cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)] + return sorted(candidates, key=lambda c: int(c.get("id", 0))) + + +def _tier_of(cfg: dict) -> str: + return str(cfg.get("billing_tier", "free")).lower() + + +def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict: + digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() + idx = int.from_bytes(digest[:8], "big") % len(candidates) + return candidates[idx] + + +def _to_uuid(user_id: str | UUID | None) -> UUID | None: + if user_id is None: + return None + if isinstance(user_id, UUID): + return user_id + try: + return UUID(str(user_id)) + except Exception: + return None + + +async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool: + parsed = _to_uuid(user_id) + if parsed is None: + return False + usage = await TokenQuotaService.premium_get_usage(session, parsed) + return bool(usage.allowed) + + +async def resolve_or_get_pinned_llm_config_id( + session: AsyncSession, + *, + thread_id: int, + search_space_id: int, + user_id: str | UUID | None, + selected_llm_config_id: int, +) -> AutoPinResolution: 
+ """Resolve Auto (Fastest) to one concrete config id and persist pin metadata. + + For non-auto selections, this function clears existing auto pin metadata and + returns the selected id as-is. + """ + thread = ( + ( + await session.execute( + select(NewChatThread) + .where(NewChatThread.id == thread_id) + .with_for_update(of=NewChatThread) + ) + ) + .unique() + .scalar_one_or_none() + ) + if thread is None: + raise ValueError(f"Thread {thread_id} not found") + if thread.search_space_id != search_space_id: + raise ValueError( + f"Thread {thread_id} does not belong to search space {search_space_id}" + ) + + # Explicit model selected: clear stale auto pin metadata. + if selected_llm_config_id != AUTO_FASTEST_ID: + if ( + thread.pinned_llm_config_id is not None + or thread.pinned_auto_mode is not None + or thread.pinned_at is not None + ): + thread.pinned_llm_config_id = None + thread.pinned_auto_mode = None + thread.pinned_at = None + await session.commit() + return AutoPinResolution( + resolved_llm_config_id=selected_llm_config_id, + resolved_tier="explicit", + from_existing_pin=False, + ) + + candidates = _global_candidates() + if not candidates: + raise ValueError("No usable global LLM configs are available for Auto mode") + candidate_by_id = {int(c["id"]): c for c in candidates} + + # Reuse existing valid pin without re-checking current quota (no silent tier switch). 
+ pinned_id = thread.pinned_llm_config_id + if ( + thread.pinned_auto_mode == AUTO_FASTEST_MODE + and pinned_id is not None + and int(pinned_id) in candidate_by_id + ): + pinned_cfg = candidate_by_id[int(pinned_id)] + logger.info( + "auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s", + thread_id, + search_space_id, + pinned_id, + _tier_of(pinned_cfg), + ) + return AutoPinResolution( + resolved_llm_config_id=int(pinned_id), + resolved_tier=_tier_of(pinned_cfg), + from_existing_pin=True, + ) + if pinned_id is not None: + logger.info( + "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s", + thread_id, + search_space_id, + pinned_id, + thread.pinned_auto_mode, + ) + + premium_eligible = await _is_premium_eligible(session, user_id) + if premium_eligible: + eligible = candidates + else: + eligible = [c for c in candidates if _tier_of(c) != "premium"] + + if not eligible: + raise ValueError( + "Auto mode could not find an eligible LLM config for this user and quota state" + ) + + selected_cfg = _deterministic_pick(eligible, thread_id) + selected_id = int(selected_cfg["id"]) + selected_tier = _tier_of(selected_cfg) + + thread.pinned_llm_config_id = selected_id + thread.pinned_auto_mode = AUTO_FASTEST_MODE + thread.pinned_at = datetime.now(UTC) + await session.commit() + + if pinned_id is None: + logger.info( + "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", + thread_id, + search_space_id, + selected_id, + selected_tier, + premium_eligible, + ) + else: + logger.info( + "auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", + thread_id, + search_space_id, + pinned_id, + selected_id, + selected_tier, + premium_eligible, + ) + return AutoPinResolution( + resolved_llm_config_id=selected_id, + resolved_tier=selected_tier, + from_existing_pin=False, + ) diff --git 
a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py new file mode 100644 index 000000000..a9853c980 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from app.services.auto_model_pin_service import ( + AUTO_FASTEST_MODE, + resolve_or_get_pinned_llm_config_id, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _FakeQuotaResult: + allowed: bool + + +class _FakeExecResult: + def __init__(self, thread): + self._thread = thread + + def unique(self): + return self + + def scalar_one_or_none(self): + return self._thread + + +class _FakeSession: + def __init__(self, thread): + self.thread = thread + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self.thread) + + async def commit(self): + self.commit_count += 1 + + +def _thread( + *, + search_space_id: int = 10, + pinned_llm_config_id: int | None = None, + pinned_auto_mode: str | None = None, +): + return SimpleNamespace( + id=1, + search_space_id=search_space_id, + pinned_llm_config_id=pinned_llm_config_id, + pinned_auto_mode=pinned_auto_mode, + pinned_at=None, + ) + + +@pytest.mark.asyncio +async def test_auto_first_turn_pins_one_model(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await 
resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id in {-1, -2} + assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id + assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE + assert session.thread.pinned_at is not None + assert session.commit_count == 1 + + +@pytest.mark.asyncio +async def test_next_turn_reuses_existing_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not be called for valid pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_premium_eligible_auto_can_pin_premium(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( 
+ session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_premium_ineligible_auto_pins_free_only(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "free" + + +@pytest.mark.asyncio +async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + 
user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_explicit_user_model_change_clears_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + ], + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=7, + ) + assert result.resolved_llm_config_id == 7 + assert session.thread.pinned_llm_config_id is None + assert session.thread.pinned_auto_mode is None + assert session.thread.pinned_at is None + assert session.commit_count == 1 + + +@pytest.mark.asyncio +async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert session.thread.pinned_llm_config_id == -2 + assert session.commit_count == 1 From 835bd9f65df2abfd80ecd8def501b2db8595c326 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> 
Date: Wed, 29 Apr 2026 19:15:36 +0530 Subject: [PATCH 04/68] fix(chat): enforce pinned model quota flow and reset stale pins --- .../app/routes/search_spaces_routes.py | 25 +++- .../app/tasks/chat/stream_new_chat.py | 107 +++++++++++------- 2 files changed, 88 insertions(+), 44 deletions(-) diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 828137518..7944e7d66 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -3,7 +3,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException from langchain_core.messages import HumanMessage from pydantic import BaseModel as PydanticBaseModel -from sqlalchemy import func +from sqlalchemy import func, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem from app.config import config from app.db import ( ImageGenerationConfig, + NewChatThread, NewLLMConfig, Permission, SearchSpace, @@ -790,9 +791,31 @@ async def update_llm_preferences( # Update preferences update_data = preferences.model_dump(exclude_unset=True) + previous_agent_llm_id = search_space.agent_llm_id for key, value in update_data.items(): setattr(search_space, key, value) + agent_llm_changed = ( + "agent_llm_id" in update_data + and update_data["agent_llm_id"] != previous_agent_llm_id + ) + if agent_llm_changed: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values( + pinned_llm_config_id=None, + pinned_auto_mode=None, + pinned_at=None, + ) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", + search_space_id, + previous_agent_llm_id, + update_data["agent_llm_id"], + ) + await session.commit() await session.refresh(search_space) diff --git 
a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index c254e66e2..1a56547ca 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -56,6 +56,7 @@ from app.db import ( shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT +from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -1456,6 +1457,21 @@ async def stream_new_chat( agent_config: AgentConfig | None = None _t0 = time.perf_counter() + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=llm_config_id, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield streaming_service.format_error(str(pin_error)) + yield streaming_service.format_done() + return + if llm_config_id >= 0: # Positive ID: Load from NewLLMConfig database table agent_config = await load_agent_config( @@ -1491,12 +1507,11 @@ async def stream_new_chat( llm_config_id, ) - # Premium quota reservation — applies to explicitly premium configs - # AND Auto mode (which may route to premium models). + # Premium quota reservation for pinned premium model only. _needs_premium_quota = ( agent_config is not None and user_id - and (agent_config.is_premium or agent_config.is_auto_mode) + and agent_config.is_premium ) if _needs_premium_quota: import uuid as _uuid @@ -1519,16 +1534,18 @@ async def stream_new_chat( ) _premium_reserved = reserve_amount if not quota_result.allowed: - if agent_config.is_premium: - yield streaming_service.format_error( - "Premium token quota exceeded. Please purchase more tokens to continue using premium models." 
- ) - yield streaming_service.format_done() - return - # Auto mode: quota exhausted but we can still proceed - # (the router may pick a free model). Reset reservation. - _premium_request_id = None - _premium_reserved = 0 + logging.getLogger(__name__).info( + "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s", + chat_id, + search_space_id, + user_id, + llm_config_id, + ) + yield streaming_service.format_error( + "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin." + ) + yield streaming_service.format_done() + return if not llm: yield streaming_service.format_error("Failed to create LLM instance") @@ -1961,28 +1978,20 @@ async def stream_new_chat( ) # Finalize premium quota with actual tokens. - # For Auto mode, only count tokens from calls that used premium models. if _premium_request_id and user_id: try: from app.services.token_quota_service import TokenQuotaService - if agent_config and agent_config.is_auto_mode: - from app.services.llm_router_service import LLMRouterService - - actual_premium_tokens = LLMRouterService.compute_premium_tokens( - accumulator.calls - ) - else: - actual_premium_tokens = accumulator.grand_total - async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_premium_request_id, - actual_tokens=actual_premium_tokens, + actual_tokens=accumulator.grand_total, reserved_tokens=_premium_reserved, ) + _premium_request_id = None + _premium_reserved = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s", @@ -2175,6 +2184,21 @@ async def stream_resume_chat( agent_config: AgentConfig | None = None _t0 = time.perf_counter() + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + 
selected_llm_config_id=llm_config_id, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield streaming_service.format_error(str(pin_error)) + yield streaming_service.format_done() + return + if llm_config_id >= 0: agent_config = await load_agent_config( session=session, @@ -2208,7 +2232,7 @@ async def stream_resume_chat( _resume_needs_premium = ( agent_config is not None and user_id - and (agent_config.is_premium or agent_config.is_auto_mode) + and agent_config.is_premium ) if _resume_needs_premium: import uuid as _uuid @@ -2231,14 +2255,18 @@ async def stream_resume_chat( ) _resume_premium_reserved = reserve_amount if not quota_result.allowed: - if agent_config.is_premium: - yield streaming_service.format_error( - "Premium token quota exceeded. Please purchase more tokens to continue using premium models." - ) - yield streaming_service.format_done() - return - _resume_premium_request_id = None - _resume_premium_reserved = 0 + logging.getLogger(__name__).info( + "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s", + chat_id, + search_space_id, + user_id, + llm_config_id, + ) + yield streaming_service.format_error( + "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin." 
+ ) + yield streaming_service.format_done() + return if not llm: yield streaming_service.format_error("Failed to create LLM instance") @@ -2370,23 +2398,16 @@ async def stream_resume_chat( try: from app.services.token_quota_service import TokenQuotaService - if agent_config and agent_config.is_auto_mode: - from app.services.llm_router_service import LLMRouterService - - actual_premium_tokens = LLMRouterService.compute_premium_tokens( - accumulator.calls - ) - else: - actual_premium_tokens = accumulator.grand_total - async with shielded_async_session() as quota_session: await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=UUID(user_id), request_id=_resume_premium_request_id, - actual_tokens=actual_premium_tokens, + actual_tokens=accumulator.grand_total, reserved_tokens=_resume_premium_reserved, ) + _resume_premium_request_id = None + _resume_premium_reserved = 0 except Exception: logging.getLogger(__name__).warning( "Failed to finalize premium quota for user %s (resume)", From d5ef0d2598573578d3abf0140c58da6d4e63401d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:15:46 +0530 Subject: [PATCH 05/68] feat(ui): surface pinned premium quota alerts in chat thread --- .../new-chat/[[...chat_id]]/page.tsx | 81 +++++++++++++++++-- .../atoms/chat/premium-alert.atom.ts | 33 ++++++++ .../components/assistant-ui/thread.tsx | 44 +++++++++- 3 files changed, 148 insertions(+), 10 deletions(-) create mode 100644 surfsense_web/atoms/chat/premium-alert.atom.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 7773a438a..a5461e17f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -19,6 +19,7 @@ import { currentThreadAtom, 
setTargetCommentIdAtom, } from "@/atoms/chat/current-thread.atom"; +import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom"; import { type MentionedDocumentInfo, mentionedDocumentIdsAtom, @@ -200,6 +201,19 @@ const BASE_TOOLS_WITH_UI = new Set([ // "write_todos", // Disabled for now ]); +const PINNED_PREMIUM_QUOTA_MESSAGE = "Premium token quota exceeded for this pinned model."; + +function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { + if (!(error instanceof Error)) return null; + if (!error.message.toLowerCase().includes("premium token quota exceeded")) { + return null; + } + if (!error.message.toLowerCase().includes("pinned model")) { + return null; + } + return error.message || PINNED_PREMIUM_QUOTA_MESSAGE; +} + export default function NewChatPage() { const params = useParams(); const queryClient = useQueryClient(); @@ -226,6 +240,7 @@ export default function NewChatPage() { const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom); const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const setCurrentThreadState = useSetAtom(currentThreadAtom); + const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom); const setTargetCommentId = useSetAtom(setTargetCommentIdAtom); const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom); const closeReportPanel = useSetAtom(closeReportPanelAtom); @@ -951,6 +966,7 @@ export default function NewChatPage() { return; } console.error("[NewChatPage] Chat error:", error); + const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); // Track chat error trackChatError( @@ -959,7 +975,15 @@ export default function NewChatPage() { error instanceof Error ? error.message : "Unknown error" ); - toast.error("Failed to get response. 
Please try again."); + if (premiumQuotaAlertMessage) { + setPremiumAlertForThread({ + threadId: currentThreadId, + message: premiumQuotaAlertMessage, + }); + toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); + } else { + toast.error("Failed to get response. Please try again."); + } // Update assistant message with error setMessages((prev) => prev.map((m) => @@ -969,7 +993,9 @@ export default function NewChatPage() { content: [ { type: "text", - text: "Sorry, there was an error. Please try again.", + text: + premiumQuotaAlertMessage ?? + "Sorry, there was an error. Please try again.", }, ], } @@ -998,6 +1024,7 @@ export default function NewChatPage() { pendingUserImageUrls, setPendingUserImageUrls, toolsWithUI, + setPremiumAlertForThread, ] ); @@ -1257,13 +1284,29 @@ export default function NewChatPage() { return; } console.error("[NewChatPage] Resume error:", error); - toast.error("Failed to resume. Please try again."); + const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); + if (premiumQuotaAlertMessage) { + setPremiumAlertForThread({ + threadId: resumeThreadId, + message: premiumQuotaAlertMessage, + }); + toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); + } else { + toast.error("Failed to resume. Please try again."); + } } finally { setIsRunning(false); abortControllerRef.current = null; } }, - [pendingInterrupt, messages, searchSpaceId, tokenUsageStore, toolsWithUI] + [ + pendingInterrupt, + messages, + searchSpaceId, + tokenUsageStore, + toolsWithUI, + setPremiumAlertForThread, + ] ); useEffect(() => { @@ -1584,18 +1627,34 @@ export default function NewChatPage() { } batcher.dispose(); console.error("[NewChatPage] Regeneration error:", error); + const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); trackChatError( searchSpaceId, threadId, error instanceof Error ? error.message : "Unknown error" ); - toast.error("Failed to regenerate response. 
Please try again."); + if (premiumQuotaAlertMessage) { + setPremiumAlertForThread({ + threadId, + message: premiumQuotaAlertMessage, + }); + toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); + } else { + toast.error("Failed to regenerate response. Please try again."); + } setMessages((prev) => prev.map((m) => m.id === assistantMsgId ? { ...m, - content: [{ type: "text", text: "Sorry, there was an error. Please try again." }], + content: [ + { + type: "text", + text: + premiumQuotaAlertMessage ?? + "Sorry, there was an error. Please try again.", + }, + ], } : m ) @@ -1605,7 +1664,15 @@ export default function NewChatPage() { abortControllerRef.current = null; } }, - [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI] + [ + threadId, + searchSpaceId, + messages, + disabledTools, + tokenUsageStore, + toolsWithUI, + setPremiumAlertForThread, + ] ); // Handle editing a message - truncates history and regenerates with new query diff --git a/surfsense_web/atoms/chat/premium-alert.atom.ts b/surfsense_web/atoms/chat/premium-alert.atom.ts new file mode 100644 index 000000000..c0efc174f --- /dev/null +++ b/surfsense_web/atoms/chat/premium-alert.atom.ts @@ -0,0 +1,33 @@ +import { atom } from "jotai"; + +export type PremiumAlertState = { + message: string; +}; + +export const premiumAlertByThreadAtom = atom>({}); + +export const setPremiumAlertForThreadAtom = atom( + null, + ( + get, + set, + payload: { + threadId: number; + message: string; + } + ) => { + const current = get(premiumAlertByThreadAtom); + set(premiumAlertByThreadAtom, { + ...current, + [payload.threadId]: { message: payload.message }, + }); + } +); + +export const clearPremiumAlertForThreadAtom = atom(null, (get, set, threadId: number) => { + const current = get(premiumAlertByThreadAtom); + if (!(threadId in current)) return; + const next = { ...current }; + delete next[threadId]; + set(premiumAlertByThreadAtom, next); +}); diff --git a/surfsense_web/components/assistant-ui/thread.tsx 
b/surfsense_web/components/assistant-ui/thread.tsx index cf99598f1..06f25f5fb 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -37,10 +37,13 @@ import { toggleToolAtom, } from "@/atoms/agent-tools/agent-tools.atoms"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; -import { - mentionedDocumentsAtom, -} from "@/atoms/chat/mentioned-documents.atom"; +import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; +import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; +import { + clearPremiumAlertForThreadAtom, + premiumAlertByThreadAtom, +} from "@/atoms/chat/premium-alert.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms"; @@ -134,6 +137,9 @@ const ThreadContent: FC = () => { style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} > + !thread.isEmpty}> + + !thread.isEmpty}> @@ -143,6 +149,38 @@ const ThreadContent: FC = () => { ); }; +const PremiumQuotaPinnedAlert: FC = () => { + const currentThreadState = useAtomValue(currentThreadAtom); + const alertsByThread = useAtomValue(premiumAlertByThreadAtom); + const clearPremiumAlertForThread = useSetAtom(clearPremiumAlertForThreadAtom); + + const currentThreadId = currentThreadState?.id; + if (!currentThreadId) return null; + + const alert = alertsByThread[currentThreadId]; + if (!alert) return null; + + return ( +
+
+ +
+

Premium quota exhausted

+

{alert.message}

+
+ +
+
+ ); +}; + const ThreadScrollToBottom: FC = () => { return ( From c110f5b9551d8593a2e9a207282153c30003da86 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Wed, 29 Apr 2026 07:20:31 -0700 Subject: [PATCH 06/68] feat: improved agent streaming --- surfsense_backend/.env.example | 8 + .../versions/134_relax_revision_fks.py | 139 +++ .../135_action_log_correlation_ids.py | 82 ++ .../versions/136_new_chat_message_turn_id.py | 52 ++ .../137_unique_reverse_of_in_action_log.py | 74 ++ .../app/agents/new_chat/chat_deepagent.py | 72 +- .../app/agents/new_chat/feature_flags.py | 14 + .../app/agents/new_chat/filesystem_state.py | 69 +- .../agents/new_chat/middleware/action_log.py | 77 +- .../agents/new_chat/middleware/filesystem.py | 421 ++++++++- .../new_chat/middleware/kb_persistence.py | 870 +++++++++++++++++- .../middleware/kb_postgres_backend.py | 109 ++- .../new_chat/middleware/knowledge_tree.py | 40 +- .../middleware/local_folder_backend.py | 68 ++ .../multi_root_local_folder_backend.py | 28 + .../app/agents/new_chat/state_reducers.py | 4 + .../app/agents/new_chat/subagents/config.py | 2 + .../app/agents/new_chat/tools/hitl.py | 42 + surfsense_backend/app/db.py | 36 +- .../app/routes/agent_action_log_route.py | 9 + .../app/routes/agent_revert_route.py | 386 +++++++- .../app/routes/new_chat_routes.py | 486 +++++++++- surfsense_backend/app/schemas/new_chat.py | 44 + .../app/services/new_streaming_service.py | 86 +- .../app/services/revert_service.py | 440 ++++++++- .../app/tasks/chat/stream_new_chat.py | 324 ++++++- .../unit/agents/new_chat/test_action_log.py | 110 +++ .../new_chat/test_desktop_safety_rules.py | 122 +++ .../agents/new_chat/test_hitl_auto_approve.py | 111 +++ .../agents/new_chat/test_rm_rmdir_cloud.py | 333 +++++++ .../agents/new_chat/test_state_reducers.py | 44 + surfsense_backend/tests/unit/db/__init__.py | 0 .../db/test_relax_revision_fks_migration.py | 83 ++ .../middleware/test_filesystem_middleware.py | 16 + 
.../test_kb_persistence_revisions.py | 309 +++++++ .../unit/middleware/test_knowledge_tree.py | 139 +++ .../middleware/test_local_folder_backend.py | 71 ++ .../tests/unit/routes/__init__.py | 0 .../routes/test_regenerate_from_message_id.py | 143 +++ .../unit/routes/test_revert_turn_route.py | 530 +++++++++++ .../services/test_revert_filesystem_tools.py | 370 ++++++++ .../tests/unit/tasks/__init__.py | 0 .../tests/unit/tasks/chat/__init__.py | 0 .../tasks/chat/test_extract_chunk_parts.py | 185 ++++ .../new-chat/[[...chat_id]]/page.tsx | 580 ++++++++++-- .../atoms/chat/agent-actions.atom.ts | 194 ++++ .../assistant-ui/assistant-message.tsx | 13 + .../assistant-ui/edit-message-dialog.tsx | 106 +++ .../assistant-ui/reasoning-message-part.tsx | 81 ++ .../assistant-ui/revert-turn-button.tsx | 232 +++++ .../assistant-ui/step-separator.tsx | 27 + .../components/assistant-ui/tool-fallback.tsx | 118 ++- .../components/free-chat/free-chat-page.tsx | 52 +- .../public-chat/public-chat-view.tsx | 2 + .../components/public-chat/public-thread.tsx | 2 + surfsense_web/contracts/enums/toolIcons.tsx | 85 ++ .../lib/apis/agent-actions-api.service.ts | 56 ++ surfsense_web/lib/chat/message-utils.ts | 6 +- surfsense_web/lib/chat/streaming-state.ts | 252 ++++- surfsense_web/lib/chat/thread-persistence.ts | 17 +- 60 files changed, 8068 insertions(+), 303 deletions(-) create mode 100644 surfsense_backend/alembic/versions/134_relax_revision_fks.py create mode 100644 surfsense_backend/alembic/versions/135_action_log_correlation_ids.py create mode 100644 surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py create mode 100644 surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py create mode 100644 
surfsense_backend/tests/unit/db/__init__.py create mode 100644 surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py create mode 100644 surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py create mode 100644 surfsense_backend/tests/unit/middleware/test_knowledge_tree.py create mode 100644 surfsense_backend/tests/unit/routes/__init__.py create mode 100644 surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py create mode 100644 surfsense_backend/tests/unit/routes/test_revert_turn_route.py create mode 100644 surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py create mode 100644 surfsense_backend/tests/unit/tasks/__init__.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/__init__.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py create mode 100644 surfsense_web/atoms/chat/agent-actions.atom.ts create mode 100644 surfsense_web/components/assistant-ui/edit-message-dialog.tsx create mode 100644 surfsense_web/components/assistant-ui/reasoning-message-part.tsx create mode 100644 surfsense_web/components/assistant-ui/revert-turn-button.tsx create mode 100644 surfsense_web/components/assistant-ui/step-separator.tsx diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index c1bfcc538..a793f33d1 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -282,6 +282,14 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_ACTION_LOG=false # SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships +# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk +# content (typed reasoning blocks, tool-input deltas) and propagate the +# real tool_call_id to the SSE layer. When OFF, the stream falls back to +# the str-only text path and synthetic "call_" tool-call ids. +# Schema migrations 135/136 ship unconditionally because they are +# forward-compatible. 
+# SURFSENSE_ENABLE_STREAM_PARITY_V2=false + # Plugins # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names diff --git a/surfsense_backend/alembic/versions/134_relax_revision_fks.py b/surfsense_backend/alembic/versions/134_relax_revision_fks.py new file mode 100644 index 000000000..99b665426 --- /dev/null +++ b/surfsense_backend/alembic/versions/134_relax_revision_fks.py @@ -0,0 +1,139 @@ +"""134_relax_revision_fks + +Revision ID: 134 +Revises: 133 +Create Date: 2026-04-29 + +Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so +revisions survive the deletes they describe. + +Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE +hard-deleting a document via the ``rm`` tool (and likewise a +``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE`` +the snapshot row is wiped at the exact moment we need it most — revert +then has nothing to read and the operation becomes irreversible. + +Migration: + +* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK + ``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``. +* ``folder_revisions.folder_id``: same treatment. + +The ``search_space_id`` FK on both tables is left unchanged (still +``ON DELETE CASCADE``). When a search space is deleted, all documents, +folders, AND their revisions go together — that's the correct teardown +story. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy import inspect + +from alembic import op + +revision: str = "134" +down_revision: str | None = "133" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def _fk_name(bind, table: str, column: str) -> str | None: + """Return the (single) FK constraint name on ``table.column``, if any.""" + inspector = inspect(bind) + for fk in inspector.get_foreign_keys(table): + cols = fk.get("constrained_columns") or [] + if cols == [column]: + return fk.get("name") + return None + + +def upgrade() -> None: + bind = op.get_bind() + + # --- document_revisions.document_id -> nullable + SET NULL --------------- + fk_name = _fk_name(bind, "document_revisions", "document_id") + if fk_name: + op.drop_constraint(fk_name, "document_revisions", type_="foreignkey") + op.alter_column( + "document_revisions", + "document_id", + existing_type=sa.Integer(), + nullable=True, + ) + op.create_foreign_key( + "document_revisions_document_id_fkey", + "document_revisions", + "documents", + ["document_id"], + ["id"], + ondelete="SET NULL", + ) + + # --- folder_revisions.folder_id -> nullable + SET NULL ------------------- + fk_name = _fk_name(bind, "folder_revisions", "folder_id") + if fk_name: + op.drop_constraint(fk_name, "folder_revisions", type_="foreignkey") + op.alter_column( + "folder_revisions", + "folder_id", + existing_type=sa.Integer(), + nullable=True, + ) + op.create_foreign_key( + "folder_revisions_folder_id_fkey", + "folder_revisions", + "folders", + ["folder_id"], + ["id"], + ondelete="SET NULL", + ) + + +def downgrade() -> None: + bind = op.get_bind() + + # Reinstating NOT NULL + CASCADE requires draining orphan rows first + # (any revision whose parent doc/folder has already been deleted). 
+ op.execute("DELETE FROM document_revisions WHERE document_id IS NULL") + op.execute("DELETE FROM folder_revisions WHERE folder_id IS NULL") + + # --- document_revisions.document_id -> NOT NULL + CASCADE --------------- + fk_name = _fk_name(bind, "document_revisions", "document_id") + if fk_name: + op.drop_constraint(fk_name, "document_revisions", type_="foreignkey") + op.alter_column( + "document_revisions", + "document_id", + existing_type=sa.Integer(), + nullable=False, + ) + op.create_foreign_key( + "document_revisions_document_id_fkey", + "document_revisions", + "documents", + ["document_id"], + ["id"], + ondelete="CASCADE", + ) + + # --- folder_revisions.folder_id -> NOT NULL + CASCADE ------------------- + fk_name = _fk_name(bind, "folder_revisions", "folder_id") + if fk_name: + op.drop_constraint(fk_name, "folder_revisions", type_="foreignkey") + op.alter_column( + "folder_revisions", + "folder_id", + existing_type=sa.Integer(), + nullable=False, + ) + op.create_foreign_key( + "folder_revisions_folder_id_fkey", + "folder_revisions", + "folders", + ["folder_id"], + ["id"], + ondelete="CASCADE", + ) diff --git a/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py b/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py new file mode 100644 index 000000000..9ae368b81 --- /dev/null +++ b/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py @@ -0,0 +1,82 @@ +"""135_action_log_correlation_ids + +Revision ID: 135 +Revises: 134 +Create Date: 2026-04-29 + +Action-log correlation-id cleanup. + +Background +---------- +``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes +the LangChain ``tool_call.id`` into that column today (see +``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch`` +joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from +``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was +never persisted. 
+ +This migration introduces two new, correctly-named columns: + +* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held) +* ``chat_turn_id`` (the per-turn correlation id from + ``configurable.turn_id`` — used by the per-turn ``revert-turn`` route). + +Backfill copies the current ``turn_id`` values into ``tool_call_id`` so +existing joins keep working. The old ``turn_id`` column is left in place +for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware`` +keeps writing it (= ``tool_call_id``) for the same reason. + +Indexes +------- + +* ``ix_agent_action_log_tool_call_id`` — required by + ``_find_action_ids_batch`` (was on ``turn_id``). +* ``ix_agent_action_log_chat_turn_id`` — required by the + ``revert-turn/{chat_turn_id}`` query. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "135" +down_revision: str | None = "134" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "agent_action_log", + sa.Column("tool_call_id", sa.String(length=64), nullable=True), + ) + op.add_column( + "agent_action_log", + sa.Column("chat_turn_id", sa.String(length=64), nullable=True), + ) + + op.create_index( + "ix_agent_action_log_tool_call_id", + "agent_action_log", + ["tool_call_id"], + ) + op.create_index( + "ix_agent_action_log_chat_turn_id", + "agent_action_log", + ["chat_turn_id"], + ) + + op.execute( + "UPDATE agent_action_log SET tool_call_id = turn_id WHERE tool_call_id IS NULL" + ) + + +def downgrade() -> None: + op.drop_index("ix_agent_action_log_chat_turn_id", table_name="agent_action_log") + op.drop_index("ix_agent_action_log_tool_call_id", table_name="agent_action_log") + op.drop_column("agent_action_log", "chat_turn_id") + op.drop_column("agent_action_log", "tool_call_id") diff --git 
a/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py b/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py new file mode 100644 index 000000000..8d4350424 --- /dev/null +++ b/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py @@ -0,0 +1,52 @@ +"""136_new_chat_message_turn_id + +Revision ID: 136 +Revises: 135 +Create Date: 2026-04-29 + +Persist the per-turn correlation id on each chat message. + +Background +---------- +LangGraph's checkpointer stores user-provided ``configurable.turn_id`` +in checkpoint metadata (see +``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To +support edit-from-arbitrary-position, the regenerate route needs to map +a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without +this column the mapping doesn't exist anywhere, so regenerate would +have to hardcode the "last 2 messages" rewind heuristic. + +This migration adds a nullable ``turn_id`` column to ``new_chat_messages`` +plus an index. Legacy rows have NULL — the regenerate route degrades +gracefully to the reload-last-two heuristic for those. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "136" +down_revision: str | None = "135" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "new_chat_messages", + sa.Column("turn_id", sa.String(length=64), nullable=True), + ) + op.create_index( + "ix_new_chat_messages_turn_id", + "new_chat_messages", + ["turn_id"], + ) + + +def downgrade() -> None: + op.drop_index("ix_new_chat_messages_turn_id", table_name="new_chat_messages") + op.drop_column("new_chat_messages", "turn_id") diff --git a/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py b/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py new file mode 100644 index 000000000..d606a00f9 --- /dev/null +++ b/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py @@ -0,0 +1,74 @@ +"""137_unique_reverse_of_in_action_log + +Revision ID: 137 +Revises: 136 +Create Date: 2026-04-29 + +Protect ``agent_action_log.reverse_of`` against double inserts. Two +concurrent revert calls (single-action route + the per-turn batch +route, or two batch routes racing) both pass the +``_was_already_reverted`` SELECT and both insert their own +``_revert:*`` rows. The application-level idempotency check is racy +because there's no DB constraint backing it. + +This migration adds a partial unique index on ``reverse_of`` (PostgreSQL +``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises +``IntegrityError`` and the route can translate it to ``"already_reverted"`` +deterministically. 
+ +The plain ``UniqueConstraint`` flavour can't be used because most +existing rows have ``reverse_of = NULL`` (only revert rows fill it), +and Postgres does treat NULL as distinct in unique indexes — but a +partial index is the cleanest expression of intent and works even on +older Postgres releases that distinguish NULL handling. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op + +revision: str = "137" +down_revision: str | None = "136" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +_INDEX_NAME = "ux_agent_action_log_reverse_of" + + +def upgrade() -> None: + # Defensively de-dup any pre-existing double-revert rows before + # adding the unique index. Keeps the OLDEST row (smallest id) and + # NULLs out the duplicates' ``reverse_of`` so they survive as audit + # trail but no longer claim to be the canonical revert. We do NOT + # delete them — operators can still inspect them via /actions. 
+ op.execute( + """ + WITH dups AS ( + SELECT id, + reverse_of, + ROW_NUMBER() OVER ( + PARTITION BY reverse_of ORDER BY id ASC + ) AS rn + FROM agent_action_log + WHERE reverse_of IS NOT NULL + ) + UPDATE agent_action_log + SET reverse_of = NULL + WHERE id IN (SELECT id FROM dups WHERE rn > 1) + """ + ) + + op.create_index( + _INDEX_NAME, + "agent_action_log", + ["reverse_of"], + unique=True, + postgresql_where="reverse_of IS NOT NULL", + ) + + +def downgrade() -> None: + op.drop_index(_INDEX_NAME, table_name="agent_action_log") diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index bfb94ba2d..fdd72ea92 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -724,7 +724,8 @@ def _build_compiled_agent_blocking( repair_mw = None if flags.enable_tool_call_repair and not flags.disable_new_agent_stack: registered_names: set[str] = {t.name for t in tools} - # Tools owned by the standard deepagents middleware stack. + # Tools owned by the standard deepagents middleware stack and the + # SurfSense filesystem extension. registered_names |= { "write_todos", "ls", @@ -735,6 +736,14 @@ def _build_compiled_agent_blocking( "grep", "execute", "task", + "mkdir", + "cd", + "pwd", + "move_file", + "rm", + "rmdir", + "list_tree", + "execute_code", } repair_mw = ToolCallNameRepairMiddleware( registered_tool_names=registered_names, @@ -763,25 +772,51 @@ def _build_compiled_agent_blocking( # on every safe read-only call (``ls``, ``read_file``, ``grep``, # ``glob``, ``web_search`` …) and, on resume, replay the previous # reject decision into innocent calls. - # 2. ``connector_synthesized`` — deny rules for tools whose required - # connector is not connected to this space. Overrides #1. - # 3. (future) user-defined rules from ``agent_permission_rules`` table - # via the Agent Permissions UI. Loaded last so they override both. + # 2. 
``desktop_safety`` — ``ask`` for destructive filesystem ops when + # the agent is operating against the user's real disk. Cloud mode + # has full revision-based revert via ``revert_service``, but + # desktop mode hits disk immediately with no undo, so an + # accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` / + # ``write_file`` is unrecoverable. This layer is forced on in + # desktop mode regardless of ``enable_permission`` because the + # safety net is non-negotiable. + # 3. ``connector_synthesized`` — deny rules for tools whose required + # connector is not connected to this space. Overrides #1/#2. + # 4. (future) user-defined rules from ``agent_permission_rules`` table + # via the Agent Permissions UI. Loaded last so they override all. permission_mw: PermissionMiddleware | None = None - if flags.enable_permission and not flags.disable_new_agent_stack: - synthesized = _synthesize_connector_deny_rules( - available_connectors=available_connectors, - enabled_tool_names={t.name for t in tools}, - ) - permission_mw = PermissionMiddleware( - rulesets=[ + is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER + permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack + # Build the middleware whenever it has work to do: either the user + # opted into the rule engine, OR we're in desktop mode and need the + # safety rules unconditionally. 
+ if permission_enabled or is_desktop_fs: + rulesets: list[Ruleset] = [ + Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ), + ] + if is_desktop_fs: + rulesets.append( Ruleset( - rules=[Rule(permission="*", pattern="*", action="allow")], - origin="surfsense_defaults", - ), - Ruleset(rules=synthesized, origin="connector_synthesized"), - ], - ) + rules=[ + Rule(permission="rm", pattern="*", action="ask"), + Rule(permission="rmdir", pattern="*", action="ask"), + Rule(permission="move_file", pattern="*", action="ask"), + Rule(permission="edit_file", pattern="*", action="ask"), + Rule(permission="write_file", pattern="*", action="ask"), + ], + origin="desktop_safety", + ) + ) + if permission_enabled: + synthesized = _synthesize_connector_deny_rules( + available_connectors=available_connectors, + enabled_tool_names={t.name for t in tools}, + ) + rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized")) + permission_mw = PermissionMiddleware(rulesets=rulesets) # ActionLogMiddleware. Off by default until the ``agent_action_log`` # table is migrated. 
When enabled, persists one row per tool call @@ -938,6 +973,7 @@ def _build_compiled_agent_blocking( search_space_id=search_space_id, created_by_id=user_id, filesystem_mode=filesystem_mode, + thread_id=thread_id, ) if filesystem_mode == FilesystemMode.CLOUD else None, diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index 55525abc5..f58bf0dd7 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false + SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events Master kill-switch (overrides everything else): @@ -86,6 +87,15 @@ class AgentFeatureFlags: False # Backend ships before UI; route returns 503 until this flips ) + # Streaming parity v2 — opt in to LangChain's structured + # ``AIMessageChunk`` content (typed reasoning blocks, tool-input + # deltas) and propagate the real ``tool_call_id`` to the SSE layer. + # When OFF the ``stream_new_chat`` task falls back to the str-only + # text path and the synthetic ``call_`` tool-call id (no + # ``langchainToolCallId`` propagation). Schema migrations 135/136 + # ship unconditionally because they're forward-compatible. 
+ enable_stream_parity_v2: bool = False + # Plugins enable_plugin_loader: bool = False @@ -139,6 +149,10 @@ class AgentFeatureFlags: # Snapshot / revert enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), + # Streaming parity v2 + enable_stream_parity_v2=_env_bool( + "SURFSENSE_ENABLE_STREAM_PARITY_V2", False + ), # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), # Observability diff --git a/surfsense_backend/app/agents/new_chat/filesystem_state.py b/surfsense_backend/app/agents/new_chat/filesystem_state.py index 18952ed6f..f54ada76e 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_state.py +++ b/surfsense_backend/app/agents/new_chat/filesystem_state.py @@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics: * ``cwd`` — current working directory (per-thread checkpointed). * ``staged_dirs`` — pending mkdir requests (cloud only). +* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs. * ``pending_moves`` — pending move_file requests (cloud only). +* ``pending_deletes`` — pending ``rm`` requests (cloud only). +* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only). * ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads. * ``dirty_paths`` — paths whose state file content differs from DB. +* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for + dirty paths; used to bind the per-path snapshot to an action_id. * ``kb_priority`` — top-K priority hints rendered into a system message. * ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting. * ``kb_anon_doc`` — Redis-loaded anonymous document (if any). 
@@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import ( ) -class PendingMove(TypedDict): - """A staged move_file operation pending end-of-turn commit.""" +class PendingMove(TypedDict, total=False): + """A staged move_file operation pending end-of-turn commit. + + ``tool_call_id`` is optional for backward compatibility with checkpoints + written before the snapshot/revert pipeline was wired up; new entries + always include it so the persistence body can resolve an action_id. + """ source: str dest: str overwrite: bool + tool_call_id: str + + +class PendingDelete(TypedDict, total=False): + """A staged ``rm`` or ``rmdir`` operation pending end-of-turn commit. + + ``tool_call_id`` is required for new entries (it's the binding key used + by :class:`KnowledgeBasePersistenceMiddleware` to find the matching + :class:`AgentActionLog` row and bind the snapshot to it). Marked + ``total=False`` only to tolerate older checkpoint payloads. + """ + + path: str + tool_call_id: str class KbPriorityEntry(TypedDict, total=False): @@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState): staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]] """mkdir paths staged for end-of-turn folder creation (cloud only).""" + staged_dir_tool_calls: NotRequired[ + Annotated[dict[str, str], _dict_merge_with_tombstones_reducer] + ] + """``path -> tool_call_id`` sidecar for ``staged_dirs``. + + Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the + :class:`FolderRevision` snapshot to the originating ``mkdir`` action. + Kept separate from ``staged_dirs`` (which stays a unique-string list) + to avoid breaking ``_add_unique_reducer`` semantics. + """ + pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]] """move_file ops staged for end-of-turn commit (cloud only).""" + pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]] + """``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only). 
+ + Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path + uniqueness is enforced inside the commit body, not the reducer (we keep + ``tool_call_id`` per occurrence so snapshot binding works). + """ + + pending_dir_deletes: NotRequired[ + Annotated[list[PendingDelete], _list_append_reducer] + ] + """``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only). + + Same shape as :data:`pending_deletes`. Commit body re-verifies the + folder is empty (in-DB AND with this turn's pending changes accounted + for) before issuing the DELETE. + """ + doc_id_by_path: NotRequired[ Annotated[dict[str, int], _dict_merge_with_tombstones_reducer] ] @@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState): dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]] """Paths whose ``state["files"]`` content has been modified this turn.""" + dirty_path_tool_calls: NotRequired[ + Annotated[dict[str, str], _dict_merge_with_tombstones_reducer] + ] + """``path -> latest tool_call_id`` sidecar for ``dirty_paths``. + + The persistence body coalesces multiple writes/edits to the same path + into one snapshot per turn. This map captures the most-recent + ``tool_call_id`` so the resulting :class:`DocumentRevision` is bound + to the latest action_id (the one the user is most likely to revert). 
+ """ + kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]] """Top-K priority hints rendered as a system message before the user turn.""" @@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState): __all__ = [ "KbAnonDoc", "KbPriorityEntry", + "PendingDelete", "PendingMove", "SurfSenseFilesystemState", ] diff --git a/surfsense_backend/app/agents/new_chat/middleware/action_log.py b/surfsense_backend/app/agents/new_chat/middleware/action_log.py index 3675064e8..716a1616c 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/action_log.py +++ b/surfsense_backend/app/agents/new_chat/middleware/action_log.py @@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any from langchain.agents.middleware import AgentMiddleware +from langchain_core.callbacks import adispatch_custom_event from langchain_core.messages import ToolMessage from app.agents.new_chat.feature_flags import get_flags @@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware): result=result, ) + tool_call_id = _resolve_tool_call_id(request) + chat_turn_id = _resolve_chat_turn_id(request) + row = AgentActionLog( thread_id=self._thread_id, user_id=self._user_id, search_space_id=self._search_space_id, - turn_id=_resolve_turn_id(request), + # ``turn_id`` is the deprecated alias of ``tool_call_id`` + # kept for one release for safe rollback. New consumers + # should read ``tool_call_id`` directly. 
+ turn_id=tool_call_id, + tool_call_id=tool_call_id, + chat_turn_id=chat_turn_id, message_id=_resolve_message_id(request), tool_name=tool_name, args=args_payload, @@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware): async with shielded_async_session() as session: session.add(row) await session.commit() + row_id = int(row.id) if row.id is not None else None + row_created_at = row.created_at except Exception: logger.warning( "ActionLogMiddleware failed to persist action log row", exc_info=True, ) + return + + # Surface a side-channel SSE event so the chat tool card can + # render a Revert button immediately after the row is durable. + # ``stream_new_chat`` translates this into a + # ``data-action-log`` SSE event. We DO NOT include the + # ``reverse_descriptor`` payload here; only a presence flag. + try: + await adispatch_custom_event( + "action_log", + { + "id": row_id, + "lc_tool_call_id": tool_call_id, + "chat_turn_id": chat_turn_id, + "tool_name": tool_name, + "reversible": bool(reversible), + "reverse_descriptor_present": reverse_descriptor is not None, + "created_at": row_created_at.isoformat() + if row_created_at + else None, + "error": error_payload is not None, + }, + ) + except Exception: + logger.debug( + "ActionLogMiddleware failed to dispatch action_log event", + exc_info=True, + ) def _render_reverse( self, @@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None: } -def _resolve_turn_id(request: Any) -> str | None: +def _resolve_tool_call_id(request: Any) -> str | None: + """Return the LangChain ``tool_call.id`` for this request, if any.""" try: call = getattr(request, "tool_call", None) or {} if isinstance(call, dict): @@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None: return None +# Deprecated alias kept for one release. Old callers and tests treated +# ``turn_id`` as if it carried the LangChain tool_call id; the new column +# lives under ``tool_call_id``. 
Both resolve to the same value today. +_resolve_turn_id = _resolve_tool_call_id + + +def _resolve_chat_turn_id(request: Any) -> str | None: + """Return ``configurable.turn_id`` for this request, if accessible. + + ``ToolRuntime.config`` is exposed by LangGraph (see + ``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id + lives at ``runtime.config["configurable"]["turn_id"]``. + """ + try: + runtime = getattr(request, "runtime", None) + if runtime is None: + return None + config = getattr(runtime, "config", None) + if not isinstance(config, dict): + return None + configurable = config.get("configurable") + if not isinstance(configurable, dict): + return None + value = configurable.get("turn_id") + if isinstance(value, str) and value: + return value + except Exception: # pragma: no cover - defensive + pass + return None + + def _resolve_message_id(request: Any) -> str | None: """Tool-call IDs serve as best-available message correlator at this layer.""" - return _resolve_turn_id(request) + return _resolve_tool_call_id(request) def _resolve_result_id(result: Any) -> str | None: diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 62316d69e..c46eb98a5 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`). - cd(path): change the current working directory. - pwd(): print the current working directory. - move_file(source, dest): move/rename a file under `/documents/`. +- rm(path): delete a single file under `/documents/` (no `-r`). +- rmdir(path): delete an empty directory under `/documents/`. - list_tree(path, max_depth, page_size): recursively list files/folders. ## Persistence Rules @@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`). 
`/documents/temp_scratch.md`) are **discarded** at end of turn — use this prefix for any scratch/working content you do NOT want saved. - All other paths (outside `/documents/` and not `temp_*`) are rejected. -- mkdir/move_file are staged this turn and committed at end of turn alongside - any new/edited documents. +- mkdir/move_file/rm/rmdir are staged this turn and committed at end of + turn alongside any new/edited documents. Snapshot/revert is enabled + for every destructive operation when action logging is on. ## Reading Documents Efficiently @@ -176,6 +179,8 @@ directory (`cwd`). - cd(path): change the current working directory. - pwd(): print the current working directory. - move_file(source, dest): move/rename a file. +- rm(path): delete a single file from disk (no `-r`). NOT reversible. +- rmdir(path): delete an empty directory from disk. NOT reversible. - list_tree(path, max_depth, page_size): recursively list files/folders. ## Workflow Tips @@ -184,6 +189,8 @@ directory (`cwd`). - For large trees, prefer `list_tree` then `grep` then `read_file` over brute-force directory traversal. - Cross-mount moves are not supported. +- Desktop deletes hit disk immediately and cannot be undone via the + agent's revert flow — confirm before calling `rm`/`rmdir`. """ ) @@ -355,6 +362,42 @@ Notes: - Parent folders are created as needed. """ +_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`. + +Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion +for end-of-turn commit; the row is removed only after the agent's turn +finishes successfully. + +Args: +- path: absolute or relative file path. Cannot point at a directory — use + `rmdir` for empty folders. Cannot target the root or `/documents`. + +Notes: +- The action is reversible via the per-action revert flow when action + logging is enabled. +- The anonymous uploaded document is read-only and cannot be deleted. 
+""" + +_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`. + +Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive +deletion (`rm -r`) is intentionally NOT supported — clear contents with +`rm` first. + +Args: +- path: absolute or relative directory path. Cannot target the root, + `/documents`, the current cwd, or any ancestor of cwd (use `cd` to + move out first). + +Notes: +- Emptiness is evaluated against the post-staged view, so a same-turn + `rm /a/x.md` followed by `rmdir /a` is fine. +- If the directory was added in this same turn via `mkdir` and never + committed, the staged mkdir is dropped instead of issuing a delete. +- The action is reversible via the per-action revert flow when action + logging is enabled. +""" + # --- desktop-only ---------------------------------------------------------- _DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. @@ -421,6 +464,28 @@ Notes: - Parent folders are created as needed. """ +_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk. + +Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits +disk immediately. Desktop deletes are NOT reversible via the agent's +revert flow. + +Args: +- path: absolute mount-prefixed file path. Cannot point at a directory — + use `rmdir` for empty folders. +""" + +_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk. + +Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive +deletion is NOT supported. The deletion hits disk immediately and is +NOT reversible via the agent's revert flow. + +Args: +- path: absolute mount-prefixed directory path. Cannot target the mount + root or any directory containing files/subfolders. 
+""" + def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: """Pick the active-mode description for every filesystem tool.""" @@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: "mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION, "cd": SURFSENSE_CD_TOOL_DESCRIPTION, "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + "rm": _CLOUD_RM_TOOL_DESCRIPTION, + "rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION, } return { "ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION, @@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: "mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION, "cd": SURFSENSE_CD_TOOL_DESCRIPTION, "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + "rm": _DESKTOP_RM_TOOL_DESCRIPTION, + "rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION, } @@ -476,6 +545,21 @@ def _basename(path: str) -> str: return path.rsplit("/", 1)[-1] +def _is_ancestor_of(candidate: str, target: str) -> bool: + """True iff ``candidate`` is a strict ancestor directory of ``target``. + + ``target`` itself is NOT considered an ancestor (use equality for that). + Both paths are assumed to be canonicalised, absolute, and free of + trailing slashes (except the root ``/``). 
+ """ + if not candidate.startswith("/") or not target.startswith("/"): + return False + if candidate == target: + return False + prefix = candidate.rstrip("/") + "/" + return target.startswith(prefix) + + class SurfSenseFilesystemMiddleware(FilesystemMiddleware): """SurfSense-specific filesystem middleware (cloud + desktop).""" @@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): self.tools.append(self._create_cd_tool()) self.tools.append(self._create_pwd_tool()) self.tools.append(self._create_move_file_tool()) + self.tools.append(self._create_rm_tool()) + self.tools.append(self._create_rmdir_tool()) self.tools.append(self._create_list_tree_tool()) if self._sandbox_available: self.tools.append(self._create_execute_code_tool()) @@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): } if self._is_cloud(): update["dirty_paths"] = [path] + update["dirty_path_tool_calls"] = {path: runtime.tool_call_id} return Command(update=update) def sync_write_file( @@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): } if self._is_cloud(): update["dirty_paths"] = [path] + update["dirty_path_tool_calls"] = {path: runtime.tool_call_id} if doc_id_to_attach is not None: update["doc_id_by_path"] = {path: doc_id_to_attach} return Command(update=update) @@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): return Command( update={ "staged_dirs": [validated], + "staged_dir_tool_calls": { + validated: runtime.tool_call_id, + }, "messages": [ ToolMessage( content=( @@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): files_update: dict[str, Any] = {source: None, dest: source_file_data} update: dict[str, Any] = { "files": files_update, - "pending_moves": [{"source": source, "dest": dest, "overwrite": False}], + "pending_moves": [ + { + "source": source, + "dest": dest, + "overwrite": False, + "tool_call_id": runtime.tool_call_id, + } + ], "messages": [ 
ToolMessage( content=( @@ -1396,6 +1494,323 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): update["dirty_paths"] = new_dirty return Command(update=update) + # ------------------------------------------------------------------ tool: rm + + def _create_rm_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION + ) + + async def async_rm( + path: Annotated[ + str, + "Absolute or relative path to the file to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + if not path or not path.strip(): + return "Error: path is required." + + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + if self._is_cloud(): + if validated in ("/", DOCUMENTS_ROOT): + return f"Error: refusing to rm '{validated}'." + if not validated.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud rm must target a path under /documents/ " + f"(got '{validated}')." + ) + + anon = runtime.state.get("kb_anon_doc") or {} + if isinstance(anon, dict) and str(anon.get("path") or "") == validated: + return "Error: the anonymous uploaded document is read-only." + + # Refuse if the path looks like a directory. + staged_dirs = list(runtime.state.get("staged_dirs") or []) + if validated in staged_dirs: + return ( + f"Error: '{validated}' is a directory. Use rmdir for " + "empty directories." + ) + pending_dir_deletes = list( + runtime.state.get("pending_dir_deletes") or [] + ) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_dir_deletes + ): + return f"Error: '{validated}' is already queued for rmdir." + + backend = self._get_backend(runtime) + if isinstance(backend, KBPostgresBackend): + # Detect "is a directory" via `ls`: if the path lists + # children we know it's a folder. Otherwise we still + # need to confirm it's a real file before staging. 
+ children = await backend.als_info(validated) + if children: + return ( + f"Error: '{validated}' is a directory. Use rmdir for " + "empty directories." + ) + + # Already queued for delete this turn? + pending_deletes = list(runtime.state.get("pending_deletes") or []) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_deletes + ): + return f"'{validated}' is already queued for deletion." + + # Resolve doc_id (best-effort): file in state or DB. + files_state = runtime.state.get("files") or {} + doc_id_by_path = runtime.state.get("doc_id_by_path") or {} + resolved_doc_id: int | None = doc_id_by_path.get(validated) + if ( + validated not in files_state + and resolved_doc_id is None + and isinstance(backend, KBPostgresBackend) + ): + loaded = await backend._load_file_data(validated) + if loaded is None: + return f"Error: file '{validated}' not found." + _, resolved_doc_id = loaded + + files_update: dict[str, Any] = {validated: None} + update: dict[str, Any] = { + "pending_deletes": [ + { + "path": validated, + "tool_call_id": runtime.tool_call_id, + } + ], + "files": files_update, + "doc_id_by_path": {validated: None}, + "messages": [ + ToolMessage( + content=( + f"Staged delete of '{validated}' (will commit at " + "end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + + # Drop the path from dirty_paths so a same-turn write+rm + # doesn't recreate the doc at commit time. + dirty_paths = list(runtime.state.get("dirty_paths") or []) + if validated in dirty_paths: + new_dirty: list[Any] = [_CLEAR] + for entry in dirty_paths: + if entry != validated: + new_dirty.append(entry) + update["dirty_paths"] = new_dirty + update["dirty_path_tool_calls"] = {validated: None} + + return Command(update=update) + + # Desktop mode — hit disk immediately. + backend = self._get_backend(runtime) + adelete = getattr(backend, "adelete_file", None) + if not callable(adelete): + return "Error: rm is not supported by the active backend." 
+ res: WriteResult = await adelete(validated) + if res.error: + return res.error + update_desktop: dict[str, Any] = { + "files": {validated: None}, + "messages": [ + ToolMessage( + content=f"Deleted file '{res.path or validated}'", + tool_call_id=runtime.tool_call_id, + ) + ], + } + return Command(update=update_desktop) + + def sync_rm( + path: Annotated[ + str, + "Absolute or relative path to the file to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_rm(path, runtime)) + + return StructuredTool.from_function( + name="rm", + description=tool_description, + func=sync_rm, + coroutine=async_rm, + ) + + # ------------------------------------------------------------------ tool: rmdir + + def _create_rmdir_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION + ) + + async def async_rmdir( + path: Annotated[ + str, + "Absolute or relative path of the empty directory to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + if not path or not path.strip(): + return "Error: path is required." + + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + if self._is_cloud(): + if validated in ("/", DOCUMENTS_ROOT): + return f"Error: refusing to rmdir '{validated}'." + if not validated.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud rmdir must target a path under /documents/ " + f"(got '{validated}')." + ) + + cwd = self._current_cwd(runtime) + if validated == cwd or _is_ancestor_of(validated, cwd): + return ( + f"Error: cannot rmdir '{validated}' because the current " + "cwd is at or under it. cd out first." 
+ ) + + staged_dirs = list(runtime.state.get("staged_dirs") or []) + pending_dir_deletes = list( + runtime.state.get("pending_dir_deletes") or [] + ) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_dir_deletes + ): + return f"'{validated}' is already queued for deletion." + + backend = self._get_backend(runtime) + + # The path must currently exist either in DB folder paths or + # in staged_dirs. We rely on KBPostgresBackend.als_info (which + # already accounts for pending deletes/moves) to evaluate + # both existence and emptiness against the post-staged view. + exists_in_staged = validated in staged_dirs + children: list[Any] = [] + if isinstance(backend, KBPostgresBackend): + children = list(await backend.als_info(validated)) + + # Detect "is a file" — if als_info returns no children but + # the path is actually a file, we should reject. We use + # _load_file_data to disambiguate file vs missing folder. + if ( + isinstance(backend, KBPostgresBackend) + and not children + and not exists_in_staged + ): + loaded = await backend._load_file_data(validated) + if loaded is not None: + return ( + f"Error: '{validated}' is a file. Use rm to delete files." + ) + # Confirm folder exists in DB by checking the parent listing. + parent = posixpath.dirname(validated) or "/" + parent_listing = await backend.als_info(parent) + parent_has_dir = any( + info.get("path") == validated and info.get("is_dir") + for info in parent_listing + ) + if not parent_has_dir: + return f"Error: directory '{validated}' not found." + + if children: + return ( + f"Error: directory '{validated}' is not empty. " + "Remove contents first." + ) + + # Same-turn mkdir un-stage: drop the staged_dirs entry + # entirely and skip queuing a DB delete (nothing was ever + # committed). 
+ if exists_in_staged: + rest = [d for d in staged_dirs if d != validated] + return Command( + update={ + "staged_dirs": [_CLEAR, *rest], + "staged_dir_tool_calls": {validated: None}, + "messages": [ + ToolMessage( + content=(f"Un-staged directory '{validated}'."), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + return Command( + update={ + "pending_dir_deletes": [ + { + "path": validated, + "tool_call_id": runtime.tool_call_id, + } + ], + "messages": [ + ToolMessage( + content=( + f"Staged rmdir of '{validated}' (will commit " + "at end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + # Desktop mode — hit disk immediately. + backend = self._get_backend(runtime) + armdir = getattr(backend, "armdir", None) + if not callable(armdir): + return "Error: rmdir is not supported by the active backend." + res: WriteResult = await armdir(validated) + if res.error: + return res.error + return Command( + update={ + "messages": [ + ToolMessage( + content=f"Deleted directory '{res.path or validated}'", + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + def sync_rmdir( + path: Annotated[ + str, + "Absolute or relative path of the empty directory to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_rmdir(path, runtime)) + + return StructuredTool.from_function( + name="rmdir", + description=tool_description, + func=sync_rmdir, + coroutine=async_rmdir, + ) + # ------------------------------------------------------------------ tool: list_tree def _create_list_tree_tool(self) -> BaseTool: diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py index 378b83950..d577441dd 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py @@ -1,16 +1,29 @@ """End-of-turn persistence for 
the cloud-mode SurfSense filesystem. This middleware runs ``aafter_agent`` once per turn (cloud only). It commits -all staged folder creations, file moves, and content writes/edits to -Postgres in a single ordered pass: +all staged folder creations, file moves, content writes/edits, file deletes +(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered +pass: 1. Materialize ``staged_dirs`` into ``Folder`` rows. 2. Apply ``pending_moves`` in order (chained moves resolved via ``doc_id_by_path``). 3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move - sequences commit at the final path. + sequences commit at the final path. Paths queued for ``rm`` this turn + are dropped here so a write+rm sequence doesn't recreate the doc. 4. Commit content writes / edits for ``/documents/*`` paths, skipping ``temp_*`` basenames. +5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory + deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works. +6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against + the post-step-5 DB state. + +When ``flags.enable_action_log`` is on every destructive op also writes a +``DocumentRevision`` / ``FolderRevision`` snapshot bound to the +originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir`` +share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails +the DELETE rolls back and we surface the error rather than silently +making the data irreversible. 
The commit body is exposed as a free function ``commit_staged_filesystem_state`` so the optional stream-task fallback (``stream_new_chat.py``) can call the @@ -25,12 +38,13 @@ from typing import Any from fractional_indexing import generate_key_between from langchain.agents.middleware import AgentMiddleware, AgentState -from langchain_core.callbacks import dispatch_custom_event +from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event from langgraph.runtime import Runtime -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState from app.agents.new_chat.path_resolver import ( @@ -41,10 +55,13 @@ from app.agents.new_chat.path_resolver import ( ) from app.agents.new_chat.state_reducers import _CLEAR from app.db import ( + AgentActionLog, Chunk, Document, + DocumentRevision, DocumentType, Folder, + FolderRevision, shielded_async_session, ) from app.indexing_pipeline.document_chunker import chunk_text @@ -123,6 +140,47 @@ async def _ensure_folder_hierarchy( return parent_id +async def _resolve_folder_id( + session: AsyncSession, + *, + search_space_id: int, + folder_parts: list[str], +) -> int | None: + """Look up an existing folder chain without creating anything. + + Returns ``None`` if any segment is missing. Used by ``rmdir`` snapshot + capture and by parent-folder lookup at ``rmdir`` commit time. 
+ """ + if not folder_parts: + return None + parent_id: int | None = None + for raw in folder_parts: + name = safe_folder_segment(str(raw)) + query = select(Folder).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + ) + query = ( + query.where(Folder.parent_id.is_(None)) + if parent_id is None + else query.where(Folder.parent_id == parent_id) + ) + result = await session.execute(query) + folder = result.scalar_one_or_none() + if folder is None: + return None + parent_id = folder.id + return parent_id + + +def _split_folder_path(folder_path: str) -> list[str]: + """Return the folder segments under ``/documents/`` for a path.""" + if not folder_path.startswith(DOCUMENTS_ROOT): + return [] + rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/") + return [p for p in rel.split("/") if p] + + # --------------------------------------------------------------------------- # Document helpers # --------------------------------------------------------------------------- @@ -331,6 +389,298 @@ async def _apply_move( return {"id": document.id, "source": source, "dest": dest, "title": new_title} +# --------------------------------------------------------------------------- +# Action log binding helpers +# --------------------------------------------------------------------------- + + +async def _find_action_ids_batch( + session: AsyncSession, + *, + thread_id: int | None, + tool_call_ids: set[str], +) -> dict[str, int]: + """Resolve ``tool_call_id -> AgentActionLog.id`` in a single query. + + Returns an empty dict when ``thread_id`` or ``tool_call_ids`` are + missing — callers treat that as "no binding available" and write the + revision with ``agent_action_id = NULL``. 
+ """ + if thread_id is None or not tool_call_ids: + return {} + rows = await session.execute( + select(AgentActionLog.id, AgentActionLog.tool_call_id).where( + AgentActionLog.thread_id == thread_id, + AgentActionLog.tool_call_id.in_(list(tool_call_ids)), + ) + ) + mapping: dict[str, int] = {} + for row in rows.all(): + if row.tool_call_id and row.id: + mapping[str(row.tool_call_id)] = int(row.id) + return mapping + + +async def _mark_action_reversible( + session: AsyncSession, + *, + action_id: int | None, +) -> None: + """Flip ``agent_action_log.reversible = TRUE`` for ``action_id``. + + Best-effort: caller may invoke from inside a SAVEPOINT and treat + failure as a soft demotion (snapshot persists, just no Revert button). + + Callers should also call ``_dispatch_reversibility_update`` (defined + below) AFTER the enclosing SAVEPOINT block exits successfully so the + chat tool card can light up its Revert button without + re-fetching ``GET /threads/.../actions``. Dispatching from inside the + SAVEPOINT would risk emitting "reversible=true" for rows whose + update gets rolled back if the surrounding destructive op fails. + """ + if action_id is None: + return + await session.execute( + update(AgentActionLog) + .where(AgentActionLog.id == action_id) + .values(reversible=True) + ) + + +async def _dispatch_reversibility_update(action_id: int | None) -> None: + """Best-effort dispatch of an ``action_log_updated`` custom event. + + Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so + the chat tool card can flip its Revert button live. Defensive: + failures are logged at debug level and swallowed; the + REST endpoint ``GET /threads/.../actions`` is still authoritative. + + .. warning:: + Inside :func:`commit_staged_filesystem_state` we DEFER all + dispatches until the outer ``session.commit()`` succeeds — see + the ``deferred_dispatches`` queue in that function. 
Dispatching + from inside a SAVEPOINT block while the outer transaction is + still pending would emit ``reversible=true`` for rows whose + snapshots get rolled back if the outer commit fails. Direct + callers (e.g. the optional stream-task fallback) that own the + full session lifetime can still call this helper inline. + """ + if action_id is None: + return + try: + await adispatch_custom_event( + "action_log_updated", + {"id": int(action_id), "reversible": True}, + ) + except Exception: + logger.debug( + "kb_persistence.aafter_agent failed to dispatch action_log_updated", + exc_info=True, + ) + + +# --------------------------------------------------------------------------- +# Snapshot helpers +# --------------------------------------------------------------------------- +# +# Best-effort helpers swallow and log errors so a snapshot failure can never +# break the underlying op for non-destructive tools (write/edit/move/mkdir). +# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the +# destructive DELETE — failure aborts the savepoint and leaves the doc / +# folder intact, so revertible ops never become irreversible silently. 
+ + +def _doc_revision_payload( + doc: Document, + *, + chunks_before: list[dict[str, str]] | None = None, +) -> dict[str, Any]: + """Pre-mutation field map for ``DocumentRevision``.""" + metadata = dict(doc.document_metadata or {}) + return { + "content_before": doc.content, + "title_before": doc.title, + "folder_id_before": doc.folder_id, + "chunks_before": chunks_before, + "metadata_before": metadata or None, + } + + +async def _load_chunks_for_snapshot( + session: AsyncSession, *, doc_id: int +) -> list[dict[str, str]]: + rows = await session.execute( + select(Chunk.content).where(Chunk.document_id == doc_id).order_by(Chunk.id) + ) + return [{"content": row.content} for row in rows.all() if row.content is not None] + + +async def _snapshot_document_pre_write( + session: AsyncSession, + *, + doc: Document, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort snapshot ahead of an in-place ``write_file``/``edit_file``. + + When ``deferred_dispatches`` is provided, on success the action id + is APPENDED to it and the SSE dispatch is left to the caller (so it + can be flushed only after the outer ``session.commit()`` succeeds). 
+ """ + try: + async with session.begin_nested(): + chunks = await _load_chunks_for_snapshot(session, doc_id=doc.id) + payload = _doc_revision_payload(doc, chunks_before=chunks) + rev = DocumentRevision( + document_id=doc.id, + search_space_id=search_space_id, + created_by_turn_id=turn_id, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-write snapshot for doc=%s failed: %s", + doc.id, + exc, + ) + return None + + +async def _snapshot_document_pre_create( + session: AsyncSession, + *, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort placeholder revision for a fresh ``write_file`` create. + + ``document_id`` is patched in by the caller after the new doc is + flushed and gets an ID; the placeholder lets us bind the action_id + even though no parent row exists yet. 
+ """ + try: + async with session.begin_nested(): + rev = DocumentRevision( + document_id=None, + search_space_id=search_space_id, + content_before=None, + title_before=None, + folder_id_before=None, + chunks_before=None, + metadata_before=None, + created_by_turn_id=turn_id, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning("kb_persistence: pre-create snapshot failed: %s", exc) + return None + + +async def _snapshot_document_pre_move( + session: AsyncSession, + *, + doc: Document, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort snapshot ahead of a ``move_file``.""" + try: + async with session.begin_nested(): + payload = _doc_revision_payload(doc, chunks_before=None) + rev = DocumentRevision( + document_id=doc.id, + search_space_id=search_space_id, + created_by_turn_id=turn_id, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-move snapshot for doc=%s failed: %s", + doc.id, + exc, + ) + return None + + +async def _snapshot_folder_pre_mkdir( + session: AsyncSession, + *, + folder: Folder, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> 
int | None: + """Best-effort placeholder for an ``mkdir`` (revert deletes the folder). + + The "before" state is "did not exist", so all ``*_before`` fields are + NULL — revert routes by ``tool_name == "mkdir"`` and DELETEs. + """ + try: + async with session.begin_nested(): + rev = FolderRevision( + folder_id=folder.id, + search_space_id=search_space_id, + name_before=None, + parent_id_before=None, + position_before=None, + created_by_turn_id=turn_id, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-mkdir snapshot for folder=%s failed: %s", + folder.id, + exc, + ) + return None + + # --------------------------------------------------------------------------- # Commit body # --------------------------------------------------------------------------- @@ -342,12 +692,20 @@ async def commit_staged_filesystem_state( search_space_id: int, created_by_id: str | None, filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, + thread_id: int | None = None, dispatch_events: bool = True, ) -> dict[str, Any] | None: """Commit all staged filesystem changes; return the state delta for reducers. Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and the optional stream-task fallback. + + When ``flags.enable_action_log`` is on every destructive op also writes + a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the + originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot + durability is best-effort for non-destructive ops and STRICT for + ``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot + failure aborts the delete). 
""" if filesystem_mode != FilesystemMode.CLOUD: return None @@ -360,8 +718,20 @@ async def commit_staged_filesystem_state( files: dict[str, Any] = state_dict.get("files") or {} staged_dirs: list[str] = list(state_dict.get("staged_dirs") or []) + staged_dir_tool_calls: dict[str, str] = dict( + state_dict.get("staged_dir_tool_calls") or {} + ) pending_moves: list[dict[str, Any]] = list(state_dict.get("pending_moves") or []) + pending_deletes: list[dict[str, Any]] = list( + state_dict.get("pending_deletes") or [] + ) + pending_dir_deletes: list[dict[str, Any]] = list( + state_dict.get("pending_dir_deletes") or [] + ) dirty_paths: list[str] = list(state_dict.get("dirty_paths") or []) + dirty_path_tool_calls: dict[str, str] = dict( + state_dict.get("dirty_path_tool_calls") or {} + ) doc_id_by_path: dict[str, int] = dict(state_dict.get("doc_id_by_path") or {}) kb_anon_doc = state_dict.get("kb_anon_doc") @@ -374,32 +744,112 @@ async def commit_staged_filesystem_state( return { "dirty_paths": [_CLEAR], "staged_dirs": [_CLEAR], + "staged_dir_tool_calls": {_CLEAR: True}, "pending_moves": [_CLEAR], + "pending_deletes": [_CLEAR], + "pending_dir_deletes": [_CLEAR], + "dirty_path_tool_calls": {_CLEAR: True}, "files": dict.fromkeys(temp_paths), } - if not (staged_dirs or pending_moves or dirty_paths): + if not ( + staged_dirs + or pending_moves + or dirty_paths + or pending_deletes + or pending_dir_deletes + ): return None + flags = get_flags() + snapshot_enabled = flags.enable_action_log + + # De-duplicate pending deletes per-path while preserving the latest + # tool_call_id (the one the user is most likely to revert via the UI). 
+ file_delete_paths: dict[str, str] = {} + for entry in pending_deletes: + if not isinstance(entry, dict): + continue + path = str(entry.get("path") or "") + if path: + file_delete_paths[path] = str(entry.get("tool_call_id") or "") + dir_delete_paths: dict[str, str] = {} + for entry in pending_dir_deletes: + if not isinstance(entry, dict): + continue + path = str(entry.get("path") or "") + if path: + dir_delete_paths[path] = str(entry.get("tool_call_id") or "") + committed_creates: list[dict[str, Any]] = [] committed_updates: list[dict[str, Any]] = [] + committed_deletes: list[dict[str, Any]] = [] + committed_folder_deletes: list[dict[str, Any]] = [] discarded: list[str] = [] applied_moves: list[dict[str, Any]] = [] doc_id_path_tombstones: dict[str, int | None] = {} tree_changed = False + # Reversibility-flip dispatches are deferred until AFTER the outer + # ``session.commit()`` succeeds. Dispatching from inside the + # SAVEPOINT chain while the outer transaction is still pending + # would emit ``reversible=true`` for rows whose snapshots get rolled + # back if the final commit raises. Snapshot helpers append on + # success; we drain this list after commit and silently abandon it + # on rollback so the UI stays consistent with durable state. + deferred_dispatches: list[int] = [] try: async with shielded_async_session() as session: + # ------------------------------------------------------------------ + # Resolve action-id bindings up front. One SELECT per turn for all + # tool_call_ids, NOT one per op — important because a turn that + # touches 50 paths would otherwise issue 50 lookups. 
+ # ------------------------------------------------------------------ + action_id_by_call: dict[str, int] = {} + if snapshot_enabled and thread_id is not None: + tool_call_ids: set[str] = set() + tool_call_ids.update( + tcid for tcid in staged_dir_tool_calls.values() if tcid + ) + for move in pending_moves: + tcid = str(move.get("tool_call_id") or "") + if tcid: + tool_call_ids.add(tcid) + tool_call_ids.update( + tcid for tcid in dirty_path_tool_calls.values() if tcid + ) + tool_call_ids.update( + tcid for tcid in file_delete_paths.values() if tcid + ) + tool_call_ids.update(tcid for tcid in dir_delete_paths.values() if tcid) + action_id_by_call = await _find_action_ids_batch( + session, + thread_id=thread_id, + tool_call_ids=tool_call_ids, + ) + + def _action_id_for(tool_call_id: str | None) -> int | None: + if not snapshot_enabled or not tool_call_id: + return None + return action_id_by_call.get(str(tool_call_id)) + + turn_id_for_revision = ( + next(iter(action_id_by_call), None) if action_id_by_call else None + ) + + # ------------------------------------------------------------------ + # 1. staged_dirs -> Folder rows. Snapshot post-flush so the new + # folder_id is available for the FK. 
+ # ------------------------------------------------------------------ for folder_path in staged_dirs: if not isinstance(folder_path, str): continue if not folder_path.startswith(DOCUMENTS_ROOT): continue - rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/") - folder_parts_full = [p for p in rel.split("/") if p] + folder_parts_full = _split_folder_path(folder_path) if not folder_parts_full: continue - await _ensure_folder_hierarchy( + folder_id = await _ensure_folder_hierarchy( session, search_space_id=search_space_id, created_by_id=created_by_id, @@ -407,7 +857,61 @@ async def commit_staged_filesystem_state( ) tree_changed = True + if snapshot_enabled and folder_id is not None: + tcid = staged_dir_tool_calls.get(folder_path) + action_id = _action_id_for(tcid) + if action_id is not None: + # Re-read the folder for the snapshot. + result = await session.execute( + select(Folder).where(Folder.id == folder_id) + ) + folder_row = result.scalar_one_or_none() + if folder_row is not None: + await _snapshot_folder_pre_mkdir( + session, + folder=folder_row, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) + + # ------------------------------------------------------------------ + # 2. pending_moves. Snapshot pre-move (in-place restore on revert). + # ------------------------------------------------------------------ for move in pending_moves: + source = str(move.get("source") or "") + if snapshot_enabled and source: + tcid = str(move.get("tool_call_id") or "") + action_id = _action_id_for(tcid) + if action_id is not None: + # Resolve the doc to snapshot BEFORE we mutate it. 
+ doc_id_pre = doc_id_by_path.get(source) + document_pre: Document | None = None + if doc_id_pre is not None: + res_pre = await session.execute( + select(Document).where( + Document.id == doc_id_pre, + Document.search_space_id == search_space_id, + ) + ) + document_pre = res_pre.scalar_one_or_none() + if document_pre is None: + document_pre = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=source, + ) + if document_pre is not None: + await _snapshot_document_pre_move( + session, + doc=document_pre, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) + applied = await _apply_move( session, search_space_id=search_space_id, @@ -431,8 +935,13 @@ async def commit_staged_filesystem_state( path = move_alias[path] return path + # ------------------------------------------------------------------ + # 3. dirty_paths -> writes/edits. Skip any path queued for ``rm`` + # this turn so a write+rm sequence doesn't recreate the doc. + # ------------------------------------------------------------------ kb_dirty_seen: set[str] = set() kb_dirty: list[str] = [] + kb_dirty_origin: dict[str, str] = {} for raw in dirty_paths: if not isinstance(raw, str): continue @@ -441,8 +950,12 @@ async def commit_staged_filesystem_state( continue if final in kb_dirty_seen: continue + if final in file_delete_paths: + discarded.append(final) + continue kb_dirty_seen.add(final) kb_dirty.append(final) + kb_dirty_origin[final] = raw for path in kb_dirty: basename = _basename(path) @@ -454,6 +967,15 @@ async def commit_staged_filesystem_state( continue content = "\n".join(file_data.get("content") or []) doc_id = doc_id_by_path.get(path) + # Path ↔ tool_call_id binding: the dirty_paths list dedupes via + # _add_unique_reducer, so we look up the latest tool_call_id by + # path (or by the un-renamed origin). 
+ origin = kb_dirty_origin.get(path, path) + tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get( + origin + ) + action_id = _action_id_for(tcid) + if doc_id is None: # The in-memory ``doc_id_by_path`` is per-thread and starts # empty in every new chat. If the agent writes to a path @@ -470,6 +992,23 @@ async def commit_staged_filesystem_state( doc_id = existing.id doc_id_by_path[path] = existing.id if doc_id is not None: + if snapshot_enabled and action_id is not None: + result_doc = await session.execute( + select(Document).where( + Document.id == doc_id, + Document.search_space_id == search_space_id, + ) + ) + existing_doc = result_doc.scalar_one_or_none() + if existing_doc is not None: + await _snapshot_document_pre_write( + session, + doc=existing_doc, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) updated = await _update_document( session, doc_id=doc_id, @@ -492,12 +1031,21 @@ async def commit_staged_filesystem_state( } ) else: - # Wrap each create in a SAVEPOINT so a residual - # ``IntegrityError`` (e.g. a deployment that hasn't run - # migration 133 yet, where ``documents.content_hash`` - # still carries its legacy global UNIQUE constraint) - # rolls back only this one create instead of poisoning - # the whole turn's transaction. + # Fresh create. Wrap each create in a SAVEPOINT so a + # residual ``IntegrityError`` (e.g. a deployment that + # hasn't run migration 133 yet, where + # ``documents.content_hash`` still carries its legacy + # global UNIQUE constraint) rolls back only this one + # create instead of poisoning the whole turn. 
+ placeholder_revision_id: int | None = None + if snapshot_enabled and action_id is not None: + placeholder_revision_id = await _snapshot_document_pre_create( + session, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) try: async with session.begin_nested(): new_doc = await _create_document( @@ -511,14 +1059,16 @@ async def commit_staged_filesystem_state( logger.warning( "kb_persistence: skipping %s create: %s", path, exc ) + # Roll back the placeholder revision since the create + # never happened. + if placeholder_revision_id is not None: + await session.execute( + delete(DocumentRevision).where( + DocumentRevision.id == placeholder_revision_id + ) + ) continue except IntegrityError as exc: - # The path-uniqueness check above already protected - # against ``unique_identifier_hash`` collisions, so - # the most likely culprit is the legacy - # ``ix_documents_content_hash`` UNIQUE constraint - # that migration 133 drops. Log loudly so operators - # know to run the migration; do NOT silently swallow. msg = str(exc.orig) if exc.orig is not None else str(exc) logger.error( "kb_persistence: IntegrityError creating %s: %s. " @@ -528,8 +1078,20 @@ async def commit_staged_filesystem_state( path, msg, ) + if placeholder_revision_id is not None: + await session.execute( + delete(DocumentRevision).where( + DocumentRevision.id == placeholder_revision_id + ) + ) continue doc_id_by_path[path] = new_doc.id + if placeholder_revision_id is not None: + await session.execute( + update(DocumentRevision) + .where(DocumentRevision.id == placeholder_revision_id) + .values(document_id=new_doc.id) + ) committed_creates.append( { "id": new_doc.id, @@ -545,13 +1107,234 @@ async def commit_staged_filesystem_state( ) tree_changed = True + # ------------------------------------------------------------------ + # 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE + # share a SAVEPOINT. 
If the snapshot insert fails, the DELETE + # rolls back too and we surface the error rather than silently + # making the data irreversible. + # ------------------------------------------------------------------ + for raw_path, tcid in file_delete_paths.items(): + final = _final_path(raw_path) + if not final.startswith(DOCUMENTS_ROOT + "/"): + continue + action_id = _action_id_for(tcid) + + # Resolve the doc. + doc_id_for_delete = doc_id_by_path.get(final) + document_to_delete: Document | None = None + if doc_id_for_delete is not None: + result = await session.execute( + select(Document).where( + Document.id == doc_id_for_delete, + Document.search_space_id == search_space_id, + ) + ) + document_to_delete = result.scalar_one_or_none() + if document_to_delete is None: + document_to_delete = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=final, + ) + if document_to_delete is None: + logger.info( + "kb_persistence: skipping rm %s (target not found)", final + ) + continue + + doc_pk = document_to_delete.id + doc_title = document_to_delete.title + doc_folder_id = document_to_delete.folder_id + + try: + async with session.begin_nested(): + # Strict: snapshot first; failure aborts the delete. + if snapshot_enabled and action_id is not None: + chunks = await _load_chunks_for_snapshot( + session, doc_id=doc_pk + ) + payload = _doc_revision_payload( + document_to_delete, chunks_before=chunks + ) + rev = DocumentRevision( + document_id=doc_pk, + search_space_id=search_space_id, + created_by_turn_id=tcid, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + await session.execute( + delete(Document).where(Document.id == doc_pk) + ) + except Exception as exc: + logger.exception( + "kb_persistence: strict rm SAVEPOINT for path=%s failed: %s", + final, + exc, + ) + continue + + # B1 — SAVEPOINT released. 
Defer the reversibility-flip + # dispatch until AFTER the outer commit succeeds so we + # never tell the UI a row is reversible if its snapshot + # gets rolled back. + if snapshot_enabled and action_id is not None: + deferred_dispatches.append(int(action_id)) + + doc_id_by_path.pop(final, None) + doc_id_path_tombstones[final] = None + committed_deletes.append( + { + "id": doc_pk, + "title": doc_title, + "documentType": DocumentType.NOTE.value, + "searchSpaceId": search_space_id, + "folderId": doc_folder_id, + "createdById": str(created_by_id) if created_by_id else None, + "virtualPath": final, + } + ) + tree_changed = True + + # ------------------------------------------------------------------ + # 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final + # emptiness check (after step 4's deletes have run, an "empty + # mid-turn" directory really IS empty in DB now). + # ------------------------------------------------------------------ + for raw_path, tcid in dir_delete_paths.items(): + final = _final_path(raw_path) + if not final.startswith(DOCUMENTS_ROOT + "/"): + continue + action_id = _action_id_for(tcid) + + folder_parts = _split_folder_path(final) + if not folder_parts: + continue + folder_id = await _resolve_folder_id( + session, + search_space_id=search_space_id, + folder_parts=folder_parts, + ) + if folder_id is None: + logger.info( + "kb_persistence: skipping rmdir %s (folder not found)", final + ) + continue + + # Re-check emptiness against in-DB state. 
+ docs_in_folder = await session.execute( + select(Document.id) + .where(Document.folder_id == folder_id) + .where(Document.search_space_id == search_space_id) + .limit(1) + ) + if docs_in_folder.scalar_one_or_none() is not None: + logger.warning( + "kb_persistence: refusing rmdir %s — non-empty at commit time", + final, + ) + continue + child_folders = await session.execute( + select(Folder.id) + .where(Folder.parent_id == folder_id) + .where(Folder.search_space_id == search_space_id) + .limit(1) + ) + if child_folders.scalar_one_or_none() is not None: + logger.warning( + "kb_persistence: refusing rmdir %s — has child folders " + "at commit time", + final, + ) + continue + + folder_to_delete_res = await session.execute( + select(Folder).where(Folder.id == folder_id) + ) + folder_to_delete = folder_to_delete_res.scalar_one_or_none() + if folder_to_delete is None: + continue + + folder_pk = folder_to_delete.id + folder_name = folder_to_delete.name + folder_parent_id = folder_to_delete.parent_id + folder_position = folder_to_delete.position + + try: + async with session.begin_nested(): + if snapshot_enabled and action_id is not None: + rev = FolderRevision( + folder_id=folder_pk, + search_space_id=search_space_id, + name_before=folder_name, + parent_id_before=folder_parent_id, + position_before=folder_position, + created_by_turn_id=tcid, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + await session.execute( + delete(Folder).where(Folder.id == folder_pk) + ) + except Exception as exc: + logger.exception( + "kb_persistence: strict rmdir SAVEPOINT for path=%s failed: %s", + final, + exc, + ) + continue + + # B1 — SAVEPOINT released. Defer the reversibility-flip + # dispatch until AFTER the outer commit succeeds so we + # never tell the UI a row is reversible if its snapshot + # gets rolled back. 
+ if snapshot_enabled and action_id is not None: + deferred_dispatches.append(int(action_id)) + + committed_folder_deletes.append( + { + "id": folder_pk, + "name": folder_name, + "searchSpaceId": search_space_id, + "parentId": folder_parent_id, + "virtualPath": final, + } + ) + tree_changed = True + await session.commit() except Exception: # pragma: no cover - rollback safety net logger.exception( "kb_persistence: commit failed (search_space=%s)", search_space_id ) + # Outer commit raised — every SAVEPOINT-released change above + # (snapshots + reversibility flips) is now rolled back. Drop + # the deferred SSE dispatches so the UI stays consistent with + # durable state. + deferred_dispatches.clear() return None + # Outer commit succeeded; flush deferred reversibility-flip + # dispatches now so the chat tool card can light up its Revert + # button without re-fetching ``GET /threads/.../actions``. De-dup + # to avoid emitting the same id twice (e.g. write-then-rm in the + # same turn dispatches once for each snapshot site). 
+ if deferred_dispatches and dispatch_events: + for action_id in dict.fromkeys(deferred_dispatches): + try: + await _dispatch_reversibility_update(action_id) + except Exception: + logger.debug( + "kb_persistence: deferred reversibility dispatch failed for action_id=%s", + action_id, + exc_info=True, + ) + if dispatch_events: for payload in committed_creates: try: @@ -567,11 +1350,34 @@ async def commit_staged_filesystem_state( logger.exception( "kb_persistence: failed to dispatch document_updated event" ) + for payload in committed_deletes: + try: + dispatch_custom_event("document_deleted", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch document_deleted event" + ) + for payload in committed_folder_deletes: + try: + dispatch_custom_event("folder_deleted", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch folder_deleted event" + ) temp_paths = [ p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) ] + # Tombstone every committed-delete path so a stale ``state["files"]`` entry + # (which als_info would otherwise interpret as content) cannot survive into + # the next turn and make a now-empty folder look non-empty. 
+ deleted_file_paths = [ + str(payload.get("virtualPath") or "") + for payload in committed_deletes + if payload.get("virtualPath") + ] + doc_id_update: dict[str, int | None] = {**doc_id_path_tombstones} for payload in committed_creates: doc_id_update[str(payload.get("virtualPath") or "")] = int(payload["id"]) @@ -579,23 +1385,38 @@ async def commit_staged_filesystem_state( delta: dict[str, Any] = { "dirty_paths": [_CLEAR], "staged_dirs": [_CLEAR], + "staged_dir_tool_calls": {_CLEAR: True}, "pending_moves": [_CLEAR], + "pending_deletes": [_CLEAR], + "pending_dir_deletes": [_CLEAR], + "dirty_path_tool_calls": {_CLEAR: True}, } + files_delta: dict[str, Any] = {} if temp_paths: - delta["files"] = dict.fromkeys(temp_paths) + files_delta.update(dict.fromkeys(temp_paths)) + for path in deleted_file_paths: + files_delta[path] = None + if files_delta: + delta["files"] = files_delta if doc_id_update: delta["doc_id_by_path"] = doc_id_update if tree_changed: delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1 + # Avoid 'unused' lint when turn_id_for_revision was only useful for + # diagnostic purposes inside the SAVEPOINT chain above. 
+ _ = turn_id_for_revision + logger.info( "kb_persistence: commit (search_space=%s) creates=%d updates=%d " - "moves=%d staged_dirs=%d discarded=%d", + "moves=%d staged_dirs=%d deletes=%d folder_deletes=%d discarded=%d", search_space_id, len(committed_creates), len(committed_updates), len(applied_moves), len(staged_dirs), + len(committed_deletes), + len(committed_folder_deletes), len(discarded), ) return delta @@ -618,10 +1439,12 @@ class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type- search_space_id: int, created_by_id: str | None, filesystem_mode: FilesystemMode, + thread_id: int | None = None, ) -> None: self.search_space_id = search_space_id self.created_by_id = created_by_id self.filesystem_mode = filesystem_mode + self.thread_id = thread_id async def aafter_agent( # type: ignore[override] self, @@ -636,6 +1459,7 @@ class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type- search_space_id=self.search_space_id, created_by_id=self.created_by_id, filesystem_mode=self.filesystem_mode, + thread_id=self.thread_id, ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py index ddb2d4af1..7cf3bf8cd 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py @@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol): def _pending_moves(self) -> list[dict[str, Any]]: return list(self.state.get("pending_moves") or []) + def _pending_deletes(self) -> list[dict[str, Any]]: + return list(self.state.get("pending_deletes") or []) + + def _pending_dir_deletes(self) -> list[dict[str, Any]]: + return list(self.state.get("pending_dir_deletes") or []) + def _kb_anon_doc(self) -> dict[str, Any] | None: anon = self.state.get("kb_anon_doc") return anon if isinstance(anon, dict) else None @@ -140,18 +146,28 @@ class 
KBPostgresBackend(BackendProtocol): return path return path.rstrip("/") if path != "/" else path - def _moved_view_paths( + def _pending_filesystem_view( self, existing: dict[str, dict[str, Any]], - ) -> tuple[set[str], dict[str, str]]: - """Apply ``pending_moves`` to a path set and return ``(removed, alias)``. + ) -> tuple[set[str], dict[str, str], set[str]]: + """Compute removed/aliased/dir-suppressed paths from staged ops. - Removed paths should disappear from listings; ``alias[source] = dest`` - means a virtual entry should appear at ``dest`` even if no DB row is - yet there. + Returns ``(removed, alias, deleted_dirs)`` where: + + * ``removed`` — paths to drop from listings (sources of pending moves + AND paths queued for ``rm``). + * ``alias`` — ``{source: dest}`` for pending moves; the dest should + appear as a virtual entry even when no DB row is at that path yet. + * ``deleted_dirs`` — folder paths queued for ``rmdir``; their entire + subtree (descendants) is suppressed from listings/glob/grep. + + Entries in ``existing`` (the ``files`` state cache) keyed by a + removed path are popped so a same-turn delete-after-write doesn't + leave a stale virtual file in listings. 
""" removed: set[str] = set() alias: dict[str, str] = {} + deleted_dirs: set[str] = set() for move in self._pending_moves(): src = move.get("source") dst = move.get("dest") @@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol): removed.add(src) alias[src] = dst existing.pop(src, None) - return removed, alias + for entry in self._pending_deletes(): + path = entry.get("path") if isinstance(entry, dict) else None + if not path: + continue + removed.add(path) + existing.pop(path, None) + for entry in self._pending_dir_deletes(): + path = entry.get("path") if isinstance(entry, dict) else None + if not path: + continue + deleted_dirs.add(path) + return removed, alias, deleted_dirs + + @staticmethod + def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool: + """Return True iff ``path`` is at-or-under any directory in ``deleted_dirs``.""" + return any(path == d or _is_under(path, d) for d in deleted_dirs) # ------------------------------------------------------------------ ls/read @@ -189,7 +221,7 @@ class KBPostgresBackend(BackendProtocol): seen.add(anon_path) files = self._state_files() - moved_removed, moved_alias = self._moved_view_paths(files) + moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files) if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": try: @@ -203,7 +235,12 @@ class KBPostgresBackend(BackendProtocol): for info in db_infos: p = info.get("path", "") - if not p or p in seen or p in moved_removed: + if ( + not p + or p in seen + or p in moved_removed + or self._is_dir_suppressed(p, deleted_dirs) + ): continue infos.append(info) seen.add(p) @@ -212,6 +249,8 @@ class KBPostgresBackend(BackendProtocol): if src not in seen: if not _is_under(dst, normalized): continue + if self._is_dir_suppressed(dst, deleted_dirs): + continue rel = ( dst[len(normalized) :].lstrip("/") if normalized != "/" @@ -247,6 +286,8 @@ class KBPostgresBackend(BackendProtocol): continue if not _is_under(staged, normalized): continue + 
if self._is_dir_suppressed(staged, deleted_dirs): + continue rel = ( staged[len(normalized) :].lstrip("/") if normalized != "/" @@ -265,14 +306,26 @@ class KBPostgresBackend(BackendProtocol): for sub in sorted(subdir_paths): if sub in seen: continue + if self._is_dir_suppressed(sub, deleted_dirs): + continue infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at="")) seen.add(sub) for path_key, fd in files.items(): if not isinstance(path_key, str) or path_key in seen: continue + # Tombstones (None values) are deletion markers from `rm`. The + # deepagents reducer normally pops them, but a stale tombstone + # surviving a checkpoint must NOT be reported as a child here — + # otherwise rmdir mistakenly sees the deleted file as content. + if fd is None: + continue if not _is_under(path_key, normalized) or path_key == normalized: continue + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): + continue if normalized == "/": rel = path_key.lstrip("/") else: @@ -550,10 +603,12 @@ class KBPostgresBackend(BackendProtocol): seen: set[str] = set() files = self._state_files() - moved_removed, _ = self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) regex = re.compile(fnmatch.translate(pattern)) for path_key, fd in files.items(): - if path_key in moved_removed: + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): continue if not _is_under(path_key, normalized): continue @@ -595,7 +650,11 @@ class KBPostgresBackend(BackendProtocol): folder_id=row.folder_id, index=index, ) - if candidate in seen or candidate in moved_removed: + if ( + candidate in seen + or candidate in moved_removed + or self._is_dir_suppressed(candidate, deleted_dirs) + ): continue if not _is_under(candidate, normalized): continue @@ -634,10 +693,12 @@ class KBPostgresBackend(BackendProtocol): matches: list[GrepMatch] = [] files = self._state_files() - moved_removed, _ = 
self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) glob_re = re.compile(fnmatch.translate(glob)) if glob else None for path_key, fd in files.items(): - if path_key in moved_removed: + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): continue if not _is_under(path_key, normalized): continue @@ -695,7 +756,11 @@ class KBPostgresBackend(BackendProtocol): ) for doc_id, chunk_id, content in chunk_buffer: candidate = doc_id_to_path.get(doc_id) - if not candidate or candidate in moved_removed: + if ( + not candidate + or candidate in moved_removed + or self._is_dir_suppressed(candidate, deleted_dirs) + ): continue if not _is_under(candidate, normalized): continue @@ -769,7 +834,7 @@ class KBPostgresBackend(BackendProtocol): return {"entries": [], "truncated": False} files = self._state_files() - moved_removed, _ = self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) anon = self._kb_anon_doc() anon_path = str(anon.get("path") or "") if anon else "" @@ -795,6 +860,8 @@ class KBPostgresBackend(BackendProtocol): for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]): if not _is_under(fpath, normalized): continue + if self._is_dir_suppressed(fpath, deleted_dirs): + continue depth = _depth_of(fpath) if max_depth is not None and depth > max_depth: continue @@ -811,6 +878,8 @@ class KBPostgresBackend(BackendProtocol): for staged in self._staged_dirs(): if not _is_under(staged, normalized): continue + if self._is_dir_suppressed(staged, deleted_dirs): + continue depth = _depth_of(staged) if max_depth is not None and depth > max_depth: continue @@ -835,7 +904,9 @@ class KBPostgresBackend(BackendProtocol): folder_id=row.folder_id, index=index, ) - if candidate in moved_removed: + if candidate in moved_removed or self._is_dir_suppressed( + candidate, deleted_dirs + ): continue if not _is_under(candidate, normalized): continue @@ 
-875,6 +946,10 @@ class KBPostgresBackend(BackendProtocol): continue if not _is_under(path_key, normalized): continue + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): + continue if any(e["path"] == path_key for e in entries): continue if not ( diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py index 467d19747..e67be8221 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py @@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] ) all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT])) + # Pre-compute which folders have at least one descendant (folder or doc). + # A folder is "empty" iff no path in `all_paths` is strictly under it. + # Used to emit an explicit "(empty)" marker so the LLM doesn't have to + # infer emptiness from indentation alone. + non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths) + lines: list[str] = [] for path in all_paths: depth = ( @@ -214,7 +220,10 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents" ) if is_dir: - lines.append(f"{indent}{display}/") + if path != DOCUMENTS_ROOT and path not in non_empty_folders: + lines.append(f"{indent}{display}/ (empty)") + else: + lines.append(f"{indent}{display}/") else: lines.append(f"{indent}{display}") if len(lines) >= self.max_entries: @@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] return self._format_root_summary(folder_paths, doc_paths) + @staticmethod + def _compute_non_empty_folders( + folder_paths: list[str], doc_paths: list[str] + ) -> set[str]: + """Return the set of folder paths that contain at least one descendant. 
+ + A folder is "non-empty" if any document path or any other folder path + is strictly under it. Documents mark every ancestor folder non-empty; + a sub-folder marks its ancestors up to the first unlisted one (so in a + chain of nested folders only the innermost one reads ``(empty)``). + """ + non_empty: set[str] = set() + folder_set = set(folder_paths) + + for doc_path in doc_paths: + parent = doc_path.rsplit("/", 1)[0] + while parent and parent != DOCUMENTS_ROOT: + if parent in folder_set: + non_empty.add(parent) + parent = parent.rsplit("/", 1)[0] + + for child in folder_paths: + parent = child.rsplit("/", 1)[0] + while parent and parent != DOCUMENTS_ROOT and parent in folder_set: + non_empty.add(parent) + parent = parent.rsplit("/", 1)[0] + + return non_empty + def _format_root_summary( self, folder_paths: list[str], doc_paths: list[str] ) -> str: diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py index 565fcb48b..4db9943cb 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -360,6 +360,74 @@ class LocalFolderBackend: self.move, source_path, destination_path, overwrite ) + def delete_file(self, file_path: str) -> WriteResult: + """Hard-delete a single file under root. + + Refuses directories, root, and missing paths. Roughly mirrors POSIX + ``rm path``; ``-r`` recursion and glob expansion are explicitly + out of scope. + """ + try: + path = self._resolve_virtual(file_path) + except ValueError: + return WriteResult(error=f"Error: Invalid path '{file_path}'") + with self._lock_for(file_path): + if not path.exists(): + return WriteResult(error=f"Error: File '{file_path}' not found") + if path.is_dir(): + return WriteResult( + error=( + f"Error: '{file_path}' is a directory. " + "Use rmdir for empty directories." 
+ ) + ) + try: + os.unlink(path) + except OSError as exc: + return WriteResult( + error=f"Error: failed to delete '{file_path}': {exc}" + ) + return WriteResult(path=file_path, files_update=None) + + async def adelete_file(self, file_path: str) -> WriteResult: + return await asyncio.to_thread(self.delete_file, file_path) + + def rmdir(self, dir_path: str) -> WriteResult: + """Hard-delete an empty directory under root. + + Refuses files, root, missing paths, and non-empty directories. + ``os.rmdir`` is naturally empty-only; we pre-check so the error is + clearer for the agent. + """ + try: + path = self._resolve_virtual(dir_path) + except ValueError: + return WriteResult(error=f"Error: Invalid path '{dir_path}'") + with self._lock_for(dir_path): + if not path.exists(): + return WriteResult(error=f"Error: Directory '{dir_path}' not found") + if not path.is_dir(): + return WriteResult(error=f"Error: '{dir_path}' is not a directory") + try: + next(path.iterdir()) + except StopIteration: + pass + else: + return WriteResult( + error=( + f"Error: directory '{dir_path}' is not empty. " + "Remove its contents first." 
+ ) + ) + try: + os.rmdir(path) + except OSError as exc: + return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}") + return WriteResult(path=dir_path, files_update=None) + + async def armdir(self, dir_path: str) -> WriteResult: + return await asyncio.to_thread(self.rmdir, dir_path) + def edit( self, file_path: str, diff --git a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py index 93eabe6ff..a5add6248 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py @@ -285,6 +285,34 @@ class MultiRootLocalFolderBackend: overwrite, ) + def delete_file(self, file_path: str) -> WriteResult: + try: + mount, local_path = self._split_mount_path(file_path) + except ValueError as exc: + return WriteResult(error=f"Error: {exc}") + result = self._mount_to_backend[mount].delete_file(local_path) + if result.path: + result.path = self._prefix_mount_path(mount, result.path) + return result + + async def adelete_file(self, file_path: str) -> WriteResult: + return await asyncio.to_thread(self.delete_file, file_path) + + def rmdir(self, dir_path: str) -> WriteResult: + try: + mount, local_path = self._split_mount_path(dir_path) + except ValueError as exc: + return WriteResult(error=f"Error: {exc}") + if local_path == "/": + return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'") + result = self._mount_to_backend[mount].rmdir(local_path) + if result.path: + result.path = self._prefix_mount_path(mount, result.path) + return result + + async def armdir(self, dir_path: str) -> WriteResult: + return await asyncio.to_thread(self.rmdir, dir_path) + def edit( self, file_path: str, diff --git a/surfsense_backend/app/agents/new_chat/state_reducers.py b/surfsense_backend/app/agents/new_chat/state_reducers.py index 
ce32406e6..89fc86367 100644 --- a/surfsense_backend/app/agents/new_chat/state_reducers.py +++ b/surfsense_backend/app/agents/new_chat/state_reducers.py @@ -181,9 +181,13 @@ def _initial_filesystem_state() -> dict[str, Any]: return { "cwd": "/documents", "staged_dirs": [], + "staged_dir_tool_calls": {}, "pending_moves": [], + "pending_deletes": [], + "pending_dir_deletes": [], "doc_id_by_path": {}, "dirty_paths": [], + "dirty_path_tool_calls": {}, "kb_priority": [], "kb_matched_chunk_ids": {}, "kb_anon_doc": None, diff --git a/surfsense_backend/app/agents/new_chat/subagents/config.py b/surfsense_backend/app/agents/new_chat/subagents/config.py index b36d35fa0..84ca516e0 100644 --- a/surfsense_backend/app/agents/new_chat/subagents/config.py +++ b/surfsense_backend/app/agents/new_chat/subagents/config.py @@ -84,6 +84,8 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = ( "write_file", "move_file", "mkdir", + "rm", + "rmdir", "update_memory", "update_memory_team", "update_memory_private", diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 8480e57b1..92248c2c9 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -30,6 +30,35 @@ from langgraph.types import interrupt logger = logging.getLogger(__name__) +# Tools that mirror the safety profile of ``write_file`` against the +# SurfSense KB: each call creates ONE artifact in the user's own workspace +# with no external visibility (drafts aren't sent; new files aren't shared +# unless the user shares them later). These are auto-approved by default +# so the agent can compose drafts and seed scratch files without a popup +# on every call. +# +# Members of this set still call ``request_approval`` exactly as before; +# the function returns immediately with ``decision_type="auto_approved"`` +# and the original params untouched. 
This preserves the call-site shape +# (logging, metadata fetching, account fallbacks) so the only behavior +# change is "no interrupt fires". +# +# To re-enable prompting, the future per-search-space rules table +# (``agent_permission_rules``) takes precedence — see the ``# (future)`` +# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`. +DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset( + { + "create_gmail_draft", + "update_gmail_draft", + "create_notion_page", + "create_confluence_page", + "create_google_drive_file", + "create_dropbox_file", + "create_onedrive_file", + } +) + + @dataclass(frozen=True, slots=True) class HITLResult: """Outcome of a human-in-the-loop approval request.""" @@ -119,6 +148,19 @@ def request_approval( logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name) return HITLResult(rejected=False, decision_type="trusted", params=dict(params)) + if tool_name in DEFAULT_AUTO_APPROVED_TOOLS: + # Default policy: low-stakes creation tools (drafts + new-file + # creates) skip HITL because they're as recoverable as a local + # ``write_file`` against the SurfSense KB. The user can still + # delete the artifact in <30s if it's wrong. + logger.info( + "Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL", + tool_name, + ) + return HITLResult( + rejected=False, decision_type="auto_approved", params=dict(params) + ) + approval = interrupt( { "type": action_type, diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 75342a8e1..91d19fb4f 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -689,6 +689,12 @@ class NewChatMessage(BaseModel, TimestampMixin): index=True, ) + # Per-turn correlation id sourced from ``configurable.turn_id`` at + # streaming time (``f"{chat_id}:{ms}"``). Nullable because legacy rows + # predate the column. Used by C1's edit-from-arbitrary-position to map + # a message back to the LangGraph checkpoint that produced its turn. 
+ turn_id = Column(String(64), nullable=True, index=True) + # Relationships thread = relationship("NewChatThread", back_populates="messages") author = relationship("User") @@ -2292,7 +2298,13 @@ class AgentActionLog(BaseModel): nullable=False, index=True, ) + # ``turn_id`` historically held the LangChain ``tool_call.id``. It has + # been renamed to ``tool_call_id`` (with a parallel column kept for one + # release for back-compat). The real chat-turn id lives in + # ``chat_turn_id`` and is sourced from ``configurable.turn_id``. turn_id = Column(String(64), nullable=True, index=True) + tool_call_id = Column(String(64), nullable=True, index=True) + chat_turn_id = Column(String(64), nullable=True, index=True) message_id = Column(String(128), nullable=True, index=True) tool_name = Column(String(255), nullable=False, index=True) args = Column(JSONB, nullable=True) @@ -2318,6 +2330,16 @@ class AgentActionLog(BaseModel): __table_args__ = ( Index("ix_agent_action_log_thread_created", "thread_id", "created_at"), + # Partial unique index enforces "at most one revert per + # original action". Created in migration 137 with + # ``WHERE reverse_of IS NOT NULL`` so non-revert rows + # (the vast majority) are unaffected and NULLs don't collide. + Index( + "ux_agent_action_log_reverse_of", + "reverse_of", + unique=True, + postgresql_where=text("reverse_of IS NOT NULL"), + ), ) @@ -2332,10 +2354,13 @@ class DocumentRevision(BaseModel): __tablename__ = "document_revisions" + # ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the + # hard-delete it describes — without that, ``rm`` would wipe the row + # we'd need to undo it. See migration ``134_relax_revision_fks``. 
document_id = Column( Integer, - ForeignKey("documents.id", ondelete="CASCADE"), - nullable=False, + ForeignKey("documents.id", ondelete="SET NULL"), + nullable=True, index=True, ) search_space_id = Column( @@ -2370,10 +2395,13 @@ class FolderRevision(BaseModel): __tablename__ = "folder_revisions" + # ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the + # hard-delete it describes — without that, ``rmdir`` would wipe the + # row we'd need to undo it. See migration ``134_relax_revision_fks``. folder_id = Column( Integer, - ForeignKey("folders.id", ondelete="CASCADE"), - nullable=False, + ForeignKey("folders.id", ondelete="SET NULL"), + nullable=True, index=True, ) search_space_id = Column( diff --git a/surfsense_backend/app/routes/agent_action_log_route.py b/surfsense_backend/app/routes/agent_action_log_route.py index 458635761..2608aa3b1 100644 --- a/surfsense_backend/app/routes/agent_action_log_route.py +++ b/surfsense_backend/app/routes/agent_action_log_route.py @@ -65,6 +65,13 @@ class AgentActionRead(BaseModel): reverse_of: int | None reverted_by_action_id: int | None is_revert_action: bool + # Correlation ids added in migration 135. ``tool_call_id`` is the + # LangChain tool-call id (joinable to ``data-action-log`` SSE events + # via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id + # from ``configurable.turn_id`` (used by the + # ``revert-turn/{chat_turn_id}`` endpoint). 
+ tool_call_id: str | None = None + chat_turn_id: str | None = None created_at: datetime @@ -172,6 +179,8 @@ async def list_thread_actions( reverse_of=row.reverse_of, reverted_by_action_id=revert_map.get(row.id), is_revert_action=row.reverse_of is not None, + tool_call_id=row.tool_call_id, + chat_turn_id=row.chat_turn_id, created_at=row.created_at, ) for row in rows diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py index 12484ff53..711081b15 100644 --- a/surfsense_backend/app/routes/agent_revert_route.py +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs: 4. Revert dispatch via :func:`app.services.revert_service.revert_action`. 5. Idempotent on retries: if the same action is reverted twice the second call returns 409 ``"already reverted"``. + +This module also hosts the per-turn batch endpoint +``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It +walks every reversible action emitted during a chat turn in reverse +``created_at`` order and reverts each independently. Partial success is the +common case — the response always contains a per-action result list and a +``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a +whole-batch 4xx. """ from __future__ import annotations import logging +from typing import Literal from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.feature_flags import get_flags @@ -97,6 +108,16 @@ async def revert_agent_action( action=action, requester_user_id=str(user.id) if user is not None else None, ) + except IntegrityError: + # Partial unique index ``ux_agent_action_log_reverse_of`` caught + # a concurrent revert. 
Translate to the existing 409 "already + # reverted" contract so racing clients see consistent + # behaviour with the pre-flight TOCTOU check above. + await session.rollback() + raise HTTPException( + status_code=409, + detail="This action has already been reverted.", + ) from None except Exception as err: logger.exception("Revert dispatch raised for action_id=%s", action_id) await session.rollback() @@ -105,7 +126,16 @@ async def revert_agent_action( ) from err if outcome.status == "ok": - await session.commit() + try: + await session.commit() + except IntegrityError: + # Race lost on commit (constraint enforced at flush in some + # configs but at commit in others — defensive). + await session.rollback() + raise HTTPException( + status_code=409, + detail="This action has already been reverted.", + ) from None return { "status": "ok", "message": outcome.message, @@ -122,3 +152,357 @@ async def revert_agent_action( raise HTTPException(status_code=501, detail=outcome.message) # not_reversible raise HTTPException(status_code=409, detail=outcome.message) + + +# --------------------------------------------------------------------------- +# Per-turn revert batch endpoint +# --------------------------------------------------------------------------- + + +PerActionStatus = Literal[ + "reverted", + "already_reverted", + "not_reversible", + "permission_denied", + "failed", + "skipped", +] + + +class RevertTurnActionResult(BaseModel): + """Per-action outcome inside a ``revert-turn`` batch response.""" + + action_id: int + tool_name: str + status: PerActionStatus + message: str | None = None + new_action_id: int | None = None + error: str | None = None + + +class RevertTurnResponse(BaseModel): + """Top-level response for ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + + ``status`` is ``"ok"`` only when every reversible row succeeded. Any + ``failed`` / ``not_reversible`` / ``permission_denied`` entry downgrades + it to ``"partial"``. 
Empty turns (no rows) return ``"ok"`` with an empty + ``results`` list — callers should treat that as a no-op. + + Counter invariant: + ``total == reverted + already_reverted + not_reversible + + permission_denied + failed + skipped`` + + Frontend toasts and the ``RevertTurnButton`` summary rely on this + invariant to display "X of Y reverted, Z could not be undone" without + silently dropping ``permission_denied`` or ``skipped`` rows. + """ + + status: Literal["ok", "partial"] + chat_turn_id: str + total: int + reverted: int + already_reverted: int + not_reversible: int + permission_denied: int = 0 + failed: int = 0 + skipped: int = 0 + results: list[RevertTurnActionResult] + + +def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus: + if outcome.status == "ok": + return "reverted" + if outcome.status == "permission_denied": + return "permission_denied" + # ``not_found`` / ``tool_unavailable`` / ``reverse_not_implemented`` / + # ``not_reversible`` are all surfaced to the caller as "not_reversible" + # — they share the same UX (this row cannot be undone) and only the + # ``message`` differs. + return "not_reversible" + + +async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None: + """Return the id of an existing successful revert row, if any. + + Single-action variant — kept for the post-IntegrityError lookup + path where we already know we lost a race for one specific id. + """ + stmt = select(AgentActionLog.id).where(AgentActionLog.reverse_of == action_id) + result = await session.execute(stmt) + return result.scalars().first() + + +async def _was_already_reverted_batch( + session: AsyncSession, *, action_ids: list[int] +) -> dict[int, int]: + """Batch idempotency probe for the revert-turn loop. + + Replaces N individual ``SELECT id WHERE reverse_of = :id`` queries + (one per row in the turn) with a single ``SELECT id, reverse_of + WHERE reverse_of IN (:ids)``. 
The route still iterates rows in + reverse-chronological order, but the membership check is O(1) per + iteration after this query. For a turn with 30 actions that's 30 + fewer round-trips through asyncpg + a smaller transaction footprint. + + Returns a ``{original_action_id -> revert_action_id}`` map. Missing + keys mean "not yet reverted" — callers should treat them as + eligible for revert. + """ + if not action_ids: + return {} + stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where( + AgentActionLog.reverse_of.in_(action_ids) + ) + result = await session.execute(stmt) + return { + original_id: revert_id + for revert_id, original_id in result.all() + if original_id is not None + } + + +@router.post( + "/threads/{thread_id}/revert-turn/{chat_turn_id}", + response_model=RevertTurnResponse, +) +async def revert_agent_turn( + thread_id: int, + chat_turn_id: str, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> RevertTurnResponse: + """Revert every reversible action emitted during ``chat_turn_id``. + + Walks ``AgentActionLog`` rows for the turn in reverse ``created_at`` + order so dependencies (e.g. ``mkdir`` -> ``write_file`` inside the new + folder) unwind in the right sequence. Each action is reverted in its + own SAVEPOINT so a single failure does not poison the batch. + + Partial success is intentional and returned with HTTP 200. Callers + must inspect ``results[*].status`` to find rows that need attention. + """ + + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_revert_route: + raise HTTPException( + status_code=503, + detail=( + "Revert is not available on this deployment yet. The route " + "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to " + "enable it." 
+ ), + ) + + thread = await load_thread(session, thread_id=thread_id) + if thread is None: + raise HTTPException(status_code=404, detail="Thread not found.") + + # Reverse-chronological so the latest mutation in the turn unwinds + # first. ``id.desc()`` is the deterministic tiebreaker for actions + # written in the same millisecond. + rows_stmt = ( + select(AgentActionLog) + .where( + AgentActionLog.thread_id == thread_id, + AgentActionLog.chat_turn_id == chat_turn_id, + ) + .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc()) + ) + rows = (await session.execute(rows_stmt)).scalars().all() + + requester_user_id = str(user.id) if user is not None else None + results: list[RevertTurnActionResult] = [] + # Counters MUST be exhaustive so the response invariant + # ``total == sum(counters)`` always holds. Frontend toasts and + # ``RevertTurnButton`` rely on this for "X of Y reverted" math. + counts: dict[str, int] = { + "reverted": 0, + "already_reverted": 0, + "not_reversible": 0, + "permission_denied": 0, + "failed": 0, + "skipped": 0, + } + + # Single batched idempotency probe replaces the previous per-row + # SELECT. ``rows`` are filtered in the loop so we pre-collect only + # the original-action ids (skip rows that are themselves + # reverts). + eligible_ids = [r.id for r in rows if r.reverse_of is None] + already_reverted_map = await _was_already_reverted_batch( + session, action_ids=eligible_ids + ) + + for action in rows: + # Skip rows that ARE reverts of an earlier action — reverting a + # revert is meaningless inside a batch (the user wants to wipe + # the original effects, not chase tail). + if action.reverse_of is not None: + counts["skipped"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="skipped", + message="Row is itself a revert action; skipped.", + ) + ) + continue + + # Idempotency: surface "already_reverted" instead of failing. 
+ existing_revert_id = already_reverted_map.get(action.id) + if existing_revert_id is not None: + counts["already_reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ) + ) + continue + + if not can_revert( + requester_user_id=requester_user_id, + action=action, + is_admin=False, + ): + counts["permission_denied"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="permission_denied", + message="You are not allowed to revert this action.", + ) + ) + continue + + # Per-row SAVEPOINT so one failed revert never poisons later + # successful ones. + try: + async with session.begin_nested(): + outcome = await revert_action( + session, + action=action, + requester_user_id=requester_user_id, + ) + if outcome.status != "ok": + raise _OutcomeRollbackError(outcome) + except _OutcomeRollbackError as rollback: + outcome = rollback.outcome + classified = _classify_outcome(outcome) + if classified == "permission_denied": + counts["permission_denied"] += 1 + else: + counts["not_reversible"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status=classified, + message=outcome.message, + ) + ) + continue + except IntegrityError: + # Partial unique index caught a concurrent revert that won + # the race against our pre-flight ``_was_already_reverted`` + # SELECT. Look up the winner so + # we can surface its ``new_action_id`` to the client. 
+ existing_revert_id = await _was_already_reverted( + session, action_id=action.id + ) + counts["already_reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ) + ) + continue + except Exception as err: # pragma: no cover — defensive, logged + logger.exception( + "Unexpected revert failure inside batch for action_id=%s", + action.id, + ) + counts["failed"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="failed", + error=str(err) or err.__class__.__name__, + ) + ) + continue + + counts["reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="reverted", + message=outcome.message, + new_action_id=outcome.new_action_id, + ) + ) + + # Single commit at the end — successful SAVEPOINTs above already + # released; failed ones rolled back to their savepoint. No row leaks + # across the boundary. 
+ try: + await session.commit() + except Exception as err: # pragma: no cover — defensive + logger.exception( + "Final commit for revert-turn failed (thread=%s turn=%s)", + thread_id, + chat_turn_id, + ) + await session.rollback() + raise HTTPException( + status_code=500, + detail="Internal error while finalising revert-turn batch.", + ) from err + + has_partial = ( + counts["failed"] > 0 + or counts["not_reversible"] > 0 + or counts["permission_denied"] > 0 + ) + overall_status: Literal["ok", "partial"] = "partial" if has_partial else "ok" + + return RevertTurnResponse( + status=overall_status, + chat_turn_id=chat_turn_id, + total=len(rows), + reverted=counts["reverted"], + already_reverted=counts["already_reverted"], + not_reversible=counts["not_reversible"], + permission_denied=counts["permission_denied"], + failed=counts["failed"], + skipped=counts["skipped"], + results=results, + ) + + +class _OutcomeRollbackError(Exception): + """Sentinel raised inside the SAVEPOINT to roll back a non-OK outcome. + + ``revert_action`` writes a new ``agent_action_log`` row only on the + happy path, but on the failure paths it sometimes mutates the + ``DocumentRevision``/``Document`` tables before deciding the action + is not reversible. Wrapping each call in ``begin_nested`` and raising + this from the failure branch ensures we always discard partial + writes for failed rows. 
+ """ + + def __init__(self, outcome: RevertOutcome) -> None: + self.outcome = outcome + super().__init__(outcome.message) + + +__all__ = ["router"] diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index b5560d90d..26c72bd45 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -11,6 +11,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui: """ import asyncio +import json import logging from datetime import UTC, datetime @@ -136,6 +137,260 @@ def _resolve_filesystem_selection( ) +def _find_pre_turn_checkpoint_id( + checkpoint_tuples: list, + *, + turn_id: str, +) -> str | None: + """Locate the LangGraph checkpoint immediately before ``turn_id`` started. + + ``checkpoint_tuples`` arrives newest-first from + ``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``) + and remember the most recent checkpoint that does NOT belong to the + edited turn. As soon as we cross into the edited turn (a checkpoint + whose ``turn_id`` matches), we return the previously-tracked + checkpoint — that's the state immediately before ``turn_id`` began. + + The naive "newest-first, return first non-matching" approach is + INCORRECT when later turns exist after ``turn_id``: their + checkpoints also satisfy ``cp_turn_id != turn_id`` and would be + returned before the real pre-turn boundary is reached. + + Reads from ``cp_tuple.metadata`` (the durable surface promoted from + ``configurable`` at write time) rather than ``config["configurable"]`` + so the lookup is portable across checkpointer implementations. + + Returns ``None`` when no eligible pre-turn checkpoint exists (e.g. + the edited turn is the very first turn of the thread). Callers fall + back to the oldest available checkpoint in that case. 
+ """ + + last_pre_turn_target: str | None = None + for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest + metadata = getattr(cp_tuple, "metadata", None) or {} + cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None + if cp_turn_id == turn_id: + # Crossed into the edited turn; the previous tracked + # checkpoint is the rewind target. May be ``None`` if we hit + # the edited turn on the very first iteration. + return last_pre_turn_target + try: + last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"] + except (KeyError, TypeError): + continue + return last_pre_turn_target + + +async def _revert_turns_for_regenerate( + *, + thread_id: int, + chat_turn_ids: list[str], + requester_user_id: str, +) -> dict: + """Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``. + + Runs BEFORE the regenerate stream so the frontend can surface + partial-rollback feedback alongside the new assistant turn. Each + turn's actions are reverted in their own SAVEPOINTs (handled + inside :mod:`app.routes.agent_revert_route`'s helpers) so a single + failure never poisons the batch. + + Sequencing inside the request: revert THEN regenerate. The + operation is NOT atomic and partial state IS surfaced — see the + plan's "Sequencing inside the request" note. + """ + + from app.routes.agent_revert_route import ( + RevertTurnActionResult, + _classify_outcome, + _OutcomeRollbackError, + _was_already_reverted, + _was_already_reverted_batch, + ) + from app.services.revert_service import ( + can_revert, + revert_action, + ) + + aggregated_results: list[dict] = [] + # Exhaustive counters keep the response invariant + # ``total == sum(counters)`` true for ``data-revert-results``. 
+ counts = { + "reverted": 0, + "already_reverted": 0, + "not_reversible": 0, + "permission_denied": 0, + "failed": 0, + "skipped": 0, + } + + # Local import keeps the route module's existing imports tidy and + # avoids a circular dependency at module-load time. + from app.db import AgentActionLog as _AgentActionLog + + async with shielded_async_session() as session: + for chat_turn_id in chat_turn_ids: + rows_stmt = ( + select(_AgentActionLog) + .where( + _AgentActionLog.thread_id == thread_id, + _AgentActionLog.chat_turn_id == chat_turn_id, + ) + .order_by( + _AgentActionLog.created_at.desc(), + _AgentActionLog.id.desc(), + ) + ) + rows = (await session.execute(rows_stmt)).scalars().all() + + # Batch idempotency probe across the turn (single SELECT + # instead of one per row). + eligible_ids = [r.id for r in rows if r.reverse_of is None] + already_reverted_map = await _was_already_reverted_batch( + session, action_ids=eligible_ids + ) + + for action in rows: + if action.reverse_of is not None: + counts["skipped"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="skipped", + message="Row is itself a revert action; skipped.", + ).model_dump() + ) + continue + + existing_revert_id = already_reverted_map.get(action.id) + if existing_revert_id is not None: + counts["already_reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ).model_dump() + ) + continue + + if not can_revert( + requester_user_id=requester_user_id, + action=action, + is_admin=False, + ): + counts["permission_denied"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="permission_denied", + message="You are not allowed to revert this action.", + ).model_dump() + ) + continue + + try: + async with session.begin_nested(): + 
outcome = await revert_action( + session, + action=action, + requester_user_id=requester_user_id, + ) + if outcome.status != "ok": + raise _OutcomeRollbackError(outcome) + except _OutcomeRollbackError as rollback: + outcome = rollback.outcome + classified = _classify_outcome(outcome) + if classified == "permission_denied": + counts["permission_denied"] += 1 + else: + counts["not_reversible"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status=classified, + message=outcome.message, + ).model_dump() + ) + continue + except IntegrityError: + # Concurrent revert won the race against the + # pre-flight ``_was_already_reverted`` SELECT. + # Surface the winning revert id so the client can + # treat this as a successful idempotent op. + existing_revert_id = await _was_already_reverted( + session, action_id=action.id + ) + counts["already_reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ).model_dump() + ) + continue + except Exception as err: # pragma: no cover — defensive + _logger.exception( + "Unexpected revert failure during regenerate batch " + "for action_id=%s", + action.id, + ) + counts["failed"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="failed", + error=str(err) or err.__class__.__name__, + ).model_dump() + ) + continue + + counts["reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="reverted", + message=outcome.message, + new_action_id=outcome.new_action_id, + ).model_dump() + ) + + try: + await session.commit() + except Exception: + _logger.exception( + "[regenerate-revert] Final commit failed; rolling back batch." 
+ ) + await session.rollback() + + has_partial = ( + counts["failed"] > 0 + or counts["not_reversible"] > 0 + or counts["permission_denied"] > 0 + ) + + return { + "status": "partial" if has_partial else "ok", + "chat_turn_ids": chat_turn_ids, + "total": len(aggregated_results), + "reverted": counts["reverted"], + "already_reverted": counts["already_reverted"], + "not_reversible": counts["not_reversible"], + "permission_denied": counts["permission_denied"], + "failed": counts["failed"], + "skipped": counts["skipped"], + "results": aggregated_results, + } + + def _try_delete_sandbox(thread_id: int) -> None: """Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked.""" from app.agents.new_chat.sandbox import ( @@ -574,6 +829,7 @@ async def get_thread_messages( token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None, + turn_id=msg.turn_id, ) for msg in db_messages ] @@ -1006,12 +1262,24 @@ async def append_message( # Check thread-level access based on visibility await check_thread_access(session, thread, user) - # Create message + # Create message. ``turn_id`` is the per-turn correlation id from + # ``configurable.turn_id`` (added in migration 136) — when the + # client streams it back to ``appendMessage``, we persist it so + # C1's edit-from-arbitrary-position can later map this message + # back to the LangGraph checkpoint that produced its turn. 
+ raw_turn_id = raw_body.get("turn_id") + turn_id_value = ( + str(raw_turn_id).strip() + if isinstance(raw_turn_id, str) and raw_turn_id.strip() + else None + ) + db_message = NewChatMessage( thread_id=thread_id, role=message_role, content=content, author_id=user.id, + turn_id=turn_id_value, ) session.add(db_message) @@ -1050,6 +1318,7 @@ async def append_message( created_at=db_message.created_at, author_id=db_message.author_id, token_usage=None, + turn_id=db_message.turn_id, ) except HTTPException: @@ -1373,43 +1642,123 @@ async def regenerate_response( user_query_to_use = request.user_query regenerate_image_urls: list[str] = [] - # Look through checkpoints to find the right one - # We want to find the checkpoint just before the last HumanMessage - for i, cp_tuple in enumerate(checkpoint_tuples): - # Access the checkpoint's channel_values which contains "messages" - checkpoint_data = cp_tuple.checkpoint - channel_values = checkpoint_data.get("channel_values", {}) - state_messages = channel_values.get("messages", []) + # --------------------------------------------------------------- + # Edit-from-arbitrary-position. When the client passes + # ``from_message_id`` we look up its persisted ``turn_id`` (added + # in migration 136) and pick the checkpoint immediately before + # that turn started. + # + # Legacy graceful-degradation contract: + # * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``. + # Returning 400 in that case is the wrong UX — the user is + # editing an old message in an existing thread and just wants + # it to work. We instead skip the checkpoint rewind (the + # stream falls back to the latest state) and skip the revert + # pass (no chat_turn_id available to walk). Deletion still + # uses ``created_at``, so the messages-after-cursor slice is + # correct on both legacy and post-136 rows. 
+ # --------------------------------------------------------------- + from_message_turn_id: str | None = None + from_message_created_at: datetime | None = None + legacy_from_message: bool = False + if request.from_message_id is not None: + from_msg_row = await session.execute( + select(NewChatMessage).filter( + NewChatMessage.id == request.from_message_id, + NewChatMessage.thread_id == thread_id, + ) + ) + from_msg = from_msg_row.scalars().first() + if from_msg is None: + raise HTTPException( + status_code=404, + detail="from_message_id not found in this thread.", + ) + from_message_created_at = from_msg.created_at + if not from_msg.turn_id: + # Legacy row — surface the degradation in logs but let + # the request proceed with the slice-based delete and a + # cold-start checkpoint. + legacy_from_message = True + _logger.warning( + "[regenerate] from_message_id=%s on thread=%s has no " + "turn_id (legacy row pre-migration-136). Falling back " + "to slice-based delete without checkpoint rewind. 
" + "revert_actions=%s will be ignored.", + request.from_message_id, + thread_id, + request.revert_actions, + ) + else: + from_message_turn_id = from_msg.turn_id - if state_messages: - last_msg = state_messages[-1] - # Find a checkpoint where the last message is NOT a HumanMessage - # This means we're at a state before the user's last message - if not isinstance(last_msg, HumanMessage): - # If no new user_query provided (reload), extract from a later checkpoint - if user_query_to_use is None and i > 0: - # Get the user query from a more recent checkpoint - for prev_cp_tuple in checkpoint_tuples[:i]: - prev_checkpoint_data = prev_cp_tuple.checkpoint - prev_channel_values = prev_checkpoint_data.get( - "channel_values", {} - ) - prev_messages = prev_channel_values.get("messages", []) - for msg in reversed(prev_messages): - if isinstance(msg, HumanMessage): - q, imgs = split_langchain_human_content(msg.content) - user_query_to_use = q - regenerate_image_urls = imgs - break - if user_query_to_use is not None and ( - str(user_query_to_use).strip() or regenerate_image_urls - ): - break - - target_checkpoint_id = cp_tuple.config["configurable"][ + # Walk oldest-to-newest and pick the LAST checkpoint whose + # ``turn_id`` differs from the edited turn — that's the state + # immediately before this turn started running. We read from + # ``metadata`` (the durable surface) rather than + # ``config["configurable"]`` so the lookup works across + # checkpointer implementations. + target_checkpoint_id = _find_pre_turn_checkpoint_id( + checkpoint_tuples, + turn_id=from_message_turn_id, + ) + if target_checkpoint_id is None and len(checkpoint_tuples) > 0: + # Fall back to the oldest checkpoint — better than + # 400ing when the agent didn't checkpoint pre-turn + # (e.g. very first turn of the thread). 
+ target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][ "checkpoint_id" ] - break + + # Look through checkpoints to find the right one + # We want to find the checkpoint just before the last HumanMessage. + # We enter this branch when: + # * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR + # * the client pinned ``from_message_id`` but the row is a + # legacy pre-migration-136 row with no ``turn_id`` (we + # downgraded to the same heuristic as a regular reload). + # We DO skip it when a real turn_id pinned ``target_checkpoint_id`` + # — that's the C1 happy path and the heuristic below would just + # re-derive a worse target. + if request.from_message_id is None or legacy_from_message: + for i, cp_tuple in enumerate(checkpoint_tuples): + # Access the checkpoint's channel_values which contains "messages" + checkpoint_data = cp_tuple.checkpoint + channel_values = checkpoint_data.get("channel_values", {}) + state_messages = channel_values.get("messages", []) + + if state_messages: + last_msg = state_messages[-1] + # Find a checkpoint where the last message is NOT a HumanMessage + # This means we're at a state before the user's last message + if not isinstance(last_msg, HumanMessage): + # If no new user_query provided (reload), extract from a later checkpoint + if user_query_to_use is None and i > 0: + # Get the user query from a more recent checkpoint + for prev_cp_tuple in checkpoint_tuples[:i]: + prev_checkpoint_data = prev_cp_tuple.checkpoint + prev_channel_values = prev_checkpoint_data.get( + "channel_values", {} + ) + prev_messages = prev_channel_values.get("messages", []) + for msg in reversed(prev_messages): + if isinstance(msg, HumanMessage): + q, imgs = split_langchain_human_content( + msg.content + ) + user_query_to_use = q + regenerate_image_urls = imgs + break + if user_query_to_use is not None and ( + str(user_query_to_use).strip() + or regenerate_image_urls + ): + break + + target_checkpoint_id = 
cp_tuple.config["configurable"][ + "checkpoint_id" + ] + break # If we couldn't find a good checkpoint, try alternative approaches if target_checkpoint_id is None and checkpoint_tuples: @@ -1472,18 +1821,51 @@ async def regenerate_response( detail="Could not determine user query for regeneration. Please provide a user_query.", ) - # Get the last two messages to delete AFTER streaming succeeds - # This prevents data loss if streaming fails - last_messages_result = await session.execute( - select(NewChatMessage) - .filter(NewChatMessage.thread_id == thread_id) - .order_by(NewChatMessage.created_at.desc()) - .limit(2) - ) + # Get the messages to delete AFTER streaming succeeds. + # This prevents data loss if streaming fails. + # + # When ``from_message_id`` is set we slice from that message + # forward (using ``created_at`` so we also catch any tool/system + # messages persisted into the same turn). Otherwise + # we keep the legacy "last 2 messages" rewind. + if request.from_message_id is not None and from_message_created_at is not None: + last_messages_result = await session.execute( + select(NewChatMessage) + .filter( + NewChatMessage.thread_id == thread_id, + NewChatMessage.created_at >= from_message_created_at, + ) + .order_by(NewChatMessage.created_at.desc()) + ) + else: + last_messages_result = await session.execute( + select(NewChatMessage) + .filter(NewChatMessage.thread_id == thread_id) + .order_by(NewChatMessage.created_at.desc()) + .limit(2) + ) messages_to_delete = list(last_messages_result.scalars().all()) message_ids_to_delete = [msg.id for msg in messages_to_delete] + # When revert_actions is requested, collect the set of + # ``chat_turn_id``s present in the slice we're about to delete. + # Each one will be reverted (best-effort) BEFORE the regenerate + # stream begins. Legacy rows have ``turn_id=None`` and silently + # contribute nothing — we already logged the degradation above. 
+ revert_turn_ids: list[str] = [] + if ( + request.revert_actions + and request.from_message_id is not None + and not legacy_from_message + ): + seen_turns: set[str] = set() + for msg in messages_to_delete: + tid = msg.turn_id + if tid and tid not in seen_turns: + seen_turns.add(tid) + revert_turn_ids.append(tid) + # Get search space for LLM config search_space_result = await session.execute( select(SearchSpace).filter(SearchSpace.id == request.search_space_id) @@ -1507,6 +1889,24 @@ async def regenerate_response( # This prevents data loss if streaming fails (network error, LLM error, etc.) async def stream_with_cleanup(): streaming_completed = False + # Best-effort revert pass BEFORE the regenerate stream begins. + # Each turn is reverted independently (per-row SAVEPOINTs + # inside the route helper) and the per-action results are surfaced + # on a single ``data-revert-results`` SSE event so the frontend + # can render any failed rows alongside the new turn. Failures here + # do NOT abort the regeneration — partial rollback is documented + # behaviour. 
+ if revert_turn_ids: + revert_results = await _revert_turns_for_regenerate( + thread_id=thread_id, + chat_turn_ids=revert_turn_ids, + requester_user_id=str(user.id), + ) + envelope = { + "type": "data-revert-results", + "data": revert_results, + } + yield f"data: {json.dumps(envelope, default=str)}\n\n".encode() try: async for chunk in stream_new_chat( user_query=str(user_query_to_use), diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 477fdf2ca..c7284e901 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel): author_display_name: str | None = None author_avatar_url: str | None = None token_usage: TokenUsageSummary | None = None + # Per-turn correlation id (``f"{chat_id}:{ms}"``) from + # ``configurable.turn_id`` at streaming time. Nullable because + # legacy rows predate the column; clients should treat NULL as + # "edit-from-this-message is unavailable". + turn_id: str | None = None model_config = ConfigDict(from_attributes=True) @@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel): For edit, optional user_images (when not None) replaces image URLs resolved from checkpoint/DB so the client can send the full user turn (text and/or images). + + Edit-from-arbitrary-position. When ``from_message_id`` is provided + the route slices conversation history starting at that message (instead of + the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by + matching ``configurable.turn_id`` stored on the message (added in migration 136), and + optionally reverts every reversible action emitted in turns at or after + ``from_message_id``. The revert step is best-effort and runs BEFORE the + regenerate stream — partial failures are surfaced via SSE + ``data-revert-results`` and do not abort the regeneration. 
""" search_space_id: int @@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel): default=None, description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB", ) + from_message_id: int | None = Field( + default=None, + description=( + "Message id to rewind to. When set, history is sliced " + "from this message forward and the LangGraph checkpoint is " + "rewound to the state immediately preceding this turn. Legacy " + "rows that predate migration 136 have ``turn_id=None`` and " + "still process — the route logs a warning, skips the " + "checkpoint rewind, and ignores ``revert_actions`` (no " + "chat_turn_id available to walk)." + ), + ) + revert_actions: bool = Field( + default=False, + description=( + "When true, every reversible action emitted at or " + "after ``from_message_id`` is reverted before the regenerate " + "stream begins. Per-action results are surfaced via the " + "``data-revert-results`` SSE event. Partial failures DO NOT " + "abort the regeneration." 
+ ), + ) @model_validator(mode="after") def _validate_regenerate_user_images(self) -> Self: @@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel): raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed") return self + @model_validator(mode="after") + def _validate_revert_actions_requires_from_message(self) -> Self: + if self.revert_actions and self.from_message_id is None: + raise ValueError( + "revert_actions requires from_message_id; specify which message to rewind to" + ) + return self + # ============================================================================= # Agent Tools Schemas diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 52a215997..5dbae91c5 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -584,13 +584,24 @@ class VercelStreamingService: # Tool Parts # ========================================================================= - def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str: + def format_tool_input_start( + self, + tool_call_id: str, + tool_name: str, + *, + langchain_tool_call_id: str | None = None, + ) -> str: """ Format the start of tool input streaming. Args: - tool_call_id: The unique tool call identifier - tool_name: The name of the tool being called + tool_call_id: The unique tool call identifier (synthetic, derived + from LangGraph ``run_id`` so the frontend has a stable card id). + tool_name: The name of the tool being called. + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id``. When set, surfaces as + ``langchainToolCallId`` so the frontend can join this card + to the action-log row written by ``ActionLogMiddleware``. 
Returns: str: SSE formatted tool input start part @@ -598,13 +609,14 @@ class VercelStreamingService: Example output: data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"} """ - return self._format_sse( - { - "type": "tool-input-start", - "toolCallId": tool_call_id, - "toolName": tool_name, - } - ) + payload: dict[str, Any] = { + "type": "tool-input-start", + "toolCallId": tool_call_id, + "toolName": tool_name, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str: """ @@ -629,7 +641,12 @@ class VercelStreamingService: ) def format_tool_input_available( - self, tool_call_id: str, tool_name: str, input_data: dict[str, Any] + self, + tool_call_id: str, + tool_name: str, + input_data: dict[str, Any], + *, + langchain_tool_call_id: str | None = None, ) -> str: """ Format the completion of tool input. @@ -638,6 +655,8 @@ class VercelStreamingService: tool_call_id: The tool call identifier tool_name: The name of the tool input_data: The complete tool input parameters + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id`` (see ``format_tool_input_start``). 
Returns: str: SSE formatted tool input available part @@ -645,22 +664,34 @@ class VercelStreamingService: Example output: data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}} """ - return self._format_sse( - { - "type": "tool-input-available", - "toolCallId": tool_call_id, - "toolName": tool_name, - "input": input_data, - } - ) + payload: dict[str, Any] = { + "type": "tool-input-available", + "toolCallId": tool_call_id, + "toolName": tool_name, + "input": input_data, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) - def format_tool_output_available(self, tool_call_id: str, output: Any) -> str: + def format_tool_output_available( + self, + tool_call_id: str, + output: Any, + *, + langchain_tool_call_id: str | None = None, + ) -> str: """ Format tool execution output. Args: tool_call_id: The tool call identifier output: The tool execution result + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id`` extracted from ``ToolMessage.tool_call_id``. + When set, the frontend can backfill any card whose + ``langchainToolCallId`` was not yet known at + ``tool-input-start`` time. 
Returns: str: SSE formatted tool output available part @@ -668,13 +699,14 @@ class VercelStreamingService: Example output: data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}} """ - return self._format_sse( - { - "type": "tool-output-available", - "toolCallId": tool_call_id, - "output": output, - } - ) + payload: dict[str, Any] = { + "type": "tool-output-available", + "toolCallId": tool_call_id, + "output": output, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) # ========================================================================= # Step Parts diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py index f3630e0b4..d02a31345 100644 --- a/surfsense_backend/app/services/revert_service.py +++ b/surfsense_backend/app/services/revert_service.py @@ -8,7 +8,9 @@ Operation outcomes mirror the plan: * **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from :class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows - written before the original mutation. + written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh + row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row + that was created; everything else is an in-place restore. * **Connector-owned actions with a declared ``reverse_descriptor``**: invoke the inverse tool through the agent's normal permission stack (NOT bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``. @@ -18,6 +20,11 @@ Operation outcomes mirror the plan: A successful revert appends a NEW row to ``agent_action_log`` with ``reverse_of=`` and the requesting user's ``user_id``, preserving an auditable chain. + +Dispatch must be exact-match (``tool_name == name``), NOT prefix matching. 
+``"rmdir".startswith("rm")`` would otherwise mis-route directory revert +to the document branch (and ``delete_note`` vs ``delete_folder`` is the +same trap waiting to happen). """ from __future__ import annotations @@ -25,17 +32,31 @@ from __future__ import annotations import logging from dataclasses import dataclass from datetime import UTC, datetime -from typing import Literal +from typing import Any, Literal -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + safe_filename, + safe_folder_segment, +) from app.db import ( AgentActionLog, + Chunk, + Document, DocumentRevision, + DocumentType, + Folder, FolderRevision, NewChatThread, ) +from app.utils.document_converters import ( + embed_texts, + generate_content_hash, + generate_unique_identifier_hash, +) logger = logging.getLogger(__name__) @@ -110,14 +131,244 @@ def can_revert( # --------------------------------------------------------------------------- -# Revert paths +# Helper: reconstruct virtual path from a snapshot # --------------------------------------------------------------------------- +async def _virtual_path_from_snapshot( + session: AsyncSession, + revision: DocumentRevision, +) -> str | None: + """Reconstruct the virtual_path the document was at before mutation. + + Preference order: + 1. ``metadata_before["virtual_path"]`` — written by every snapshot + helper since this PR. + 2. Compose ``"/"`` from + ``folder_id_before`` + ``title_before``. Walks the folder chain via + ``parent_id``. 
+ """ + metadata = revision.metadata_before or {} + candidate = metadata.get("virtual_path") if isinstance(metadata, dict) else None + if isinstance(candidate, str) and candidate.startswith(DOCUMENTS_ROOT): + return candidate + + title = revision.title_before + if not isinstance(title, str) or not title: + return None + + parts: list[str] = [] + cursor: int | None = revision.folder_id_before + visited: set[int] = set() + while cursor is not None and cursor not in visited: + visited.add(cursor) + folder = await session.get(Folder, cursor) + if folder is None: + return None + parts.append(safe_folder_segment(str(folder.name or ""))) + cursor = folder.parent_id + parts.reverse() + + base = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT + filename = safe_filename(title) + return f"{base}/{filename}" + + +# --------------------------------------------------------------------------- +# Document revision restore (write/edit/move/rm) +# --------------------------------------------------------------------------- + + +def _set_field(target: Any, field: str, value: Any) -> None: + if value is not None: + setattr(target, field, value) + + +async def _restore_in_place_document( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Apply an in-place restore to an existing :class:`Document`.""" + if revision.document_id is None: + return RevertOutcome( + status="tool_unavailable", + message=( + "Original document was hard-deleted; in-place restore is not possible." 
+ ), + ) + doc = await session.get(Document, revision.document_id) + if doc is None: + return RevertOutcome( + status="tool_unavailable", + message="Original document has been deleted; revert cannot proceed.", + ) + + _set_field(doc, "content", revision.content_before) + _set_field(doc, "source_markdown", revision.content_before) + _set_field(doc, "title", revision.title_before) + _set_field(doc, "folder_id", revision.folder_id_before) + metadata_before = revision.metadata_before or {} + if isinstance(metadata_before, dict) and metadata_before: + doc.document_metadata = dict(metadata_before) + + if isinstance(revision.content_before, str): + doc.content_hash = generate_content_hash( + revision.content_before, doc.search_space_id + ) + + virtual_path = await _virtual_path_from_snapshot(session, revision) + if virtual_path: + doc.unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + doc.search_space_id, + ) + + chunks_before = revision.chunks_before + if isinstance(chunks_before, list): + await session.execute(delete(Chunk).where(Chunk.document_id == doc.id)) + chunk_texts = [ + str(c.get("content")) + for c in chunks_before + if isinstance(c, dict) and isinstance(c.get("content"), str) + ] + if chunk_texts: + chunk_embeddings = embed_texts(chunk_texts) + session.add_all( + [ + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip( + chunk_texts, chunk_embeddings, strict=True + ) + ] + ) + if isinstance(revision.content_before, str): + doc.embedding = embed_texts([revision.content_before])[0] + + doc.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Document restored from snapshot.") + + +async def _reinsert_document_from_revision( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Re-INSERT a deleted :class:`Document` from a snapshot row (``rm`` revert).""" + if not isinstance(revision.title_before, str) or not 
revision.title_before: + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks title_before; cannot recreate document.", + ) + if not isinstance(revision.content_before, str): + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks content_before; cannot recreate document.", + ) + + virtual_path = await _virtual_path_from_snapshot(session, revision) + if not virtual_path: + return RevertOutcome( + status="not_reversible", + message=( + "Snapshot is missing both metadata_before['virtual_path'] AND " + "a resolvable (folder_id_before, title_before) pair." + ), + ) + + search_space_id = revision.search_space_id + unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + collision = await session.execute( + select(Document.id).where( + Document.search_space_id == search_space_id, + Document.unique_identifier_hash == unique_identifier_hash, + ) + ) + if collision.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( + f"A document already exists at '{virtual_path}'; revert would " + "collide. Move the live doc out of the way first." 
+ ), + ) + + metadata = revision.metadata_before or {} + if not isinstance(metadata, dict): + metadata = {} + metadata = dict(metadata) + metadata["virtual_path"] = virtual_path + + content = revision.content_before + new_doc = Document( + title=revision.title_before, + document_type=DocumentType.NOTE, + document_metadata=metadata, + content=content, + content_hash=generate_content_hash(content, search_space_id), + unique_identifier_hash=unique_identifier_hash, + source_markdown=content, + search_space_id=search_space_id, + folder_id=revision.folder_id_before, + updated_at=datetime.now(UTC), + ) + session.add(new_doc) + await session.flush() + + new_doc.embedding = embed_texts([content])[0] + chunk_texts = [] + chunks_before = revision.chunks_before + if isinstance(chunks_before, list): + chunk_texts = [ + str(c.get("content")) + for c in chunks_before + if isinstance(c, dict) and isinstance(c.get("content"), str) + ] + if chunk_texts: + chunk_embeddings = embed_texts(chunk_texts) + session.add_all( + [ + Chunk(document_id=new_doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True) + ] + ) + + # Repoint the snapshot at the recreated row so a follow-up revert of + # the same row works as expected. 
+ revision.document_id = new_doc.id + return RevertOutcome( + status="ok", + message=f"Re-inserted document '{revision.title_before}' from snapshot.", + ) + + +async def _delete_created_document( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Delete the document that ``write_file`` created (``content_before IS NULL``).""" + if revision.document_id is None: + return RevertOutcome( + status="ok", + message="No live row to delete (already removed elsewhere).", + ) + await session.execute(delete(Document).where(Document.id == revision.document_id)) + return RevertOutcome( + status="ok", + message="Deleted the document that was created by this action.", + ) + + async def _restore_document_revision( session: AsyncSession, *, action: AgentActionLog ) -> RevertOutcome: - """Restore the most recent :class:`DocumentRevision` for ``action``.""" + """Dispatch document-level revert based on ``action.tool_name``.""" stmt = ( select(DocumentRevision) .where(DocumentRevision.agent_action_id == action.id) @@ -132,23 +383,111 @@ async def _restore_document_revision( message="No document_revisions row tied to this action.", ) - from app.db import Document # late import to avoid cycles at module load + tool_name = (action.tool_name or "").lower() - doc = await session.get(Document, revision.document_id) - if doc is None: + if tool_name == "rm": + return await _reinsert_document_from_revision(session, revision=revision) + + if tool_name == "write_file" and revision.content_before is None: + return await _delete_created_document(session, revision=revision) + + return await _restore_in_place_document(session, revision=revision) + + +# --------------------------------------------------------------------------- +# Folder revision restore (mkdir/rmdir/rename/move) +# --------------------------------------------------------------------------- + + +async def _restore_in_place_folder( + session: AsyncSession, + *, + revision: FolderRevision, +) -> 
RevertOutcome: + if revision.folder_id is None: return RevertOutcome( status="tool_unavailable", - message="Original document has been deleted; revert cannot proceed.", + message="Original folder was hard-deleted; in-place restore is impossible.", + ) + folder = await session.get(Folder, revision.folder_id) + if folder is None: + return RevertOutcome( + status="tool_unavailable", + message="Original folder has been deleted; revert cannot proceed.", + ) + _set_field(folder, "name", revision.name_before) + _set_field(folder, "parent_id", revision.parent_id_before) + _set_field(folder, "position", revision.position_before) + folder.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Folder restored from snapshot.") + + +async def _reinsert_folder_from_revision( + session: AsyncSession, + *, + revision: FolderRevision, +) -> RevertOutcome: + if not isinstance(revision.name_before, str) or not revision.name_before: + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks name_before; cannot recreate folder.", + ) + new_folder = Folder( + name=revision.name_before, + parent_id=revision.parent_id_before, + position=revision.position_before, + search_space_id=revision.search_space_id, + updated_at=datetime.now(UTC), + ) + session.add(new_folder) + await session.flush() + revision.folder_id = new_folder.id + return RevertOutcome( + status="ok", + message=f"Re-inserted folder '{revision.name_before}' from snapshot.", + ) + + +async def _delete_created_folder( + session: AsyncSession, + *, + revision: FolderRevision, +) -> RevertOutcome: + if revision.folder_id is None: + return RevertOutcome( + status="ok", + message="No live folder row to delete (already removed elsewhere).", + ) + folder_id = revision.folder_id + + has_doc = await session.execute( + select(Document.id).where(Document.folder_id == folder_id).limit(1) + ) + if has_doc.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( 
+ "Folder is no longer empty (documents have been added since " + "mkdir); cannot revert." + ), + ) + has_child = await session.execute( + select(Folder.id).where(Folder.parent_id == folder_id).limit(1) + ) + if has_child.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( + "Folder is no longer empty (sub-folders have been added " + "since mkdir); cannot revert." + ), ) - if revision.content_before is not None: - doc.content = revision.content_before - if revision.title_before is not None: - doc.title = revision.title_before - if revision.folder_id_before is not None: - doc.folder_id = revision.folder_id_before - doc.updated_at = datetime.now(UTC) - return RevertOutcome(status="ok", message="Document restored from snapshot.") + await session.execute(delete(Folder).where(Folder.id == folder_id)) + return RevertOutcome( + status="ok", + message="Deleted the folder that was created by this action.", + ) async def _restore_folder_revision( @@ -168,41 +507,44 @@ async def _restore_folder_revision( message="No folder_revisions row tied to this action.", ) - from app.db import Folder + tool_name = (action.tool_name or "").lower() - folder = await session.get(Folder, revision.folder_id) - if folder is None: - return RevertOutcome( - status="tool_unavailable", - message="Original folder has been deleted; revert cannot proceed.", - ) + if tool_name == "rmdir": + return await _reinsert_folder_from_revision(session, revision=revision) - if revision.name_before is not None: - folder.name = revision.name_before - if revision.parent_id_before is not None: - folder.parent_id = revision.parent_id_before - if revision.position_before is not None: - folder.position = revision.position_before - folder.updated_at = datetime.now(UTC) - return RevertOutcome(status="ok", message="Folder restored from snapshot.") + if tool_name == "mkdir": + return await _delete_created_folder(session, revision=revision) + + return await 
_restore_in_place_folder(session, revision=revision) -# Tool-name prefixes that route to KB document / folder revert paths. Kept -# as data so a future PR adding new KB-owned tools doesn't have to touch -# this module's control flow. -_DOC_TOOL_PREFIXES: tuple[str, ...] = ( - "edit_file", - "write_file", - "update_memory", - "create_note", - "update_note", - "delete_note", +# --------------------------------------------------------------------------- +# Dispatch +# --------------------------------------------------------------------------- +# +# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``. +# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and +# ``delete_note``/``delete_folder``. + +_DOC_TOOLS: frozenset[str] = frozenset( + { + "edit_file", + "write_file", + "move_file", + "rm", + "update_memory", + "create_note", + "update_note", + "delete_note", + } ) -_FOLDER_TOOL_PREFIXES: tuple[str, ...] = ( - "mkdir", - "move_file", - "rename_folder", - "delete_folder", +_FOLDER_TOOLS: frozenset[str] = frozenset( + { + "mkdir", + "rmdir", + "rename_folder", + "delete_folder", + } ) @@ -220,9 +562,9 @@ async def revert_action( """ tool_name = (action.tool_name or "").lower() - if tool_name.startswith(_DOC_TOOL_PREFIXES): + if tool_name in _DOC_TOOLS: outcome = await _restore_document_revision(session, action=action) - elif tool_name.startswith(_FOLDER_TOOL_PREFIXES): + elif tool_name in _FOLDER_TOOLS: outcome = await _restore_folder_revision(session, action=action) elif action.reverse_descriptor: # Connector-owned reversibles run through the normal permission diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index c254e66e2..2f8e33ba9 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -30,6 +30,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import 
create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, @@ -70,6 +71,91 @@ _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() +def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: + """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. + + Returns a dict with three keys: + + * ``text`` — concatenated string content (empty string if the chunk + contributes none). + * ``reasoning`` — concatenated reasoning content (empty string if the + chunk contributes none). + * ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk`` + dicts surfaced from either the typed-block list or the + ``tool_call_chunks`` attribute. + + Background + ---------- + ``AIMessageChunk.content`` can be: + + * a ``str`` (most providers), or + * a ``list`` of typed blocks ``{type: 'text' | 'reasoning' | + 'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for + Anthropic, Bedrock, and several reasoning configurations. + + Reasoning may also live under + ``chunk.additional_kwargs['reasoning_content']`` (some providers + surface it that way instead of as a typed block). Tool-call chunks + may live under ``chunk.tool_call_chunks`` even when ``content`` is a + plain string. + + Earlier versions only handled the ``isinstance(content, str)`` branch + and silently dropped reasoning blocks + tool-call chunks emitted by + LangChain ``AIMessageChunk``s. 
+ """ + out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []} + if chunk is None: + return out + + content = getattr(chunk, "content", None) + if isinstance(content, str): + if content: + out["text"] = content + elif isinstance(content, list): + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type == "text": + value = block.get("text") or block.get("content") or "" + if isinstance(value, str) and value: + text_parts.append(value) + elif block_type == "reasoning": + value = ( + block.get("reasoning") + or block.get("text") + or block.get("content") + or "" + ) + if isinstance(value, str) and value: + reasoning_parts.append(value) + elif block_type in ("tool_call_chunk", "tool_use"): + out["tool_call_chunks"].append(block) + if text_parts: + out["text"] = "".join(text_parts) + if reasoning_parts: + out["reasoning"] = "".join(reasoning_parts) + + additional = getattr(chunk, "additional_kwargs", None) or {} + if isinstance(additional, dict): + extra_reasoning = additional.get("reasoning_content") + if isinstance(extra_reasoning, str) and extra_reasoning: + existing = out["reasoning"] + out["reasoning"] = ( + (existing + extra_reasoning) if existing else extra_reasoning + ) + + extra_tool_chunks = getattr(chunk, "tool_call_chunks", None) + if isinstance(extra_tool_chunks, list): + for tcc in extra_tool_chunks: + if isinstance(tcc, dict): + out["tool_call_chunks"].append(tcc) + + return out + + def format_mentioned_surfsense_docs_as_context( documents: list[SurfsenseDocsDocument], ) -> str: @@ -266,6 +352,7 @@ async def _stream_agent_events( fallback_commit_search_space_id: int | None = None, fallback_commit_created_by_id: str | None = None, fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, + fallback_commit_thread_id: int | None = None, ) -> AsyncGenerator[str, None]: """Shared async generator that 
streams and formats astream_events from the agent. @@ -298,6 +385,41 @@ async def _stream_agent_events( active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool called_update_memory: bool = False + # Reasoning-block streaming. We open a reasoning block on the + # first reasoning delta of a step, append deltas as they arrive, and + # close it when text starts (the model has switched to writing its + # answer) or ``on_chat_model_end`` fires for the model node. Reuses + # the same Vercel format-helpers as text-start/delta/end. + current_reasoning_id: str | None = None + + # Streaming-parity v2 feature flag. When OFF we keep the legacy + # shape: str-only content, no reasoning blocks, no + # ``langchainToolCallId`` propagation. The schema migrations + # (135 / 136) ship unconditionally because they're forward-compatible. + parity_v2 = bool(get_flags().enable_stream_parity_v2) + + # Best-effort attach of LangChain ``tool_call_id`` to the synthetic + # ``call_`` card id we already emit. We accumulate + # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by + # name, and pop the next unconsumed entry at ``on_tool_start``. The + # authoritative id is later filled in at ``on_tool_end`` from + # ``ToolMessage.tool_call_id``. + pending_tool_call_chunks: list[dict[str, Any]] = [] + lc_tool_call_id_by_run: dict[str, str] = {} + + # Per-tool-end mutable cache for the LangChain tool_call_id resolved + # at ``on_tool_end``. ``_emit_tool_output`` reads this so every + # ``format_tool_output_available`` call automatically carries the + # authoritative id without duplicating the kwarg at every call site. 
+ current_lc_tool_call_id: dict[str, str | None] = {"value": None} + + def _emit_tool_output(call_id: str, output: Any) -> str: + return streaming_service.format_tool_output_available( + call_id, + output, + langchain_tool_call_id=current_lc_tool_call_id["value"], + ) + def next_thinking_step_id() -> str: nonlocal thinking_step_counter thinking_step_counter += 1 @@ -326,22 +448,61 @@ async def _stream_agent_events( if "surfsense:internal" in event.get("tags", []): continue # Suppress middleware-internal LLM tokens (e.g. KB search classification) chunk = event.get("data", {}).get("chunk") - if chunk and hasattr(chunk, "content"): - content = chunk.content - if content and isinstance(content, str): - if current_text_id is None: - completion_event = complete_current_step() - if completion_event: - yield completion_event - if just_finished_tool: - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - just_finished_tool = False - current_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(current_text_id) - yield streaming_service.format_text_delta(current_text_id, content) - accumulated_text += content + if not chunk: + continue + parts = _extract_chunk_parts(chunk) + + # Accumulate any tool_call_chunks for best-effort + # correlation with ``on_tool_start`` below. We don't emit + # anything here; the matching is done at tool-start time. + if parity_v2 and parts["tool_call_chunks"]: + for tcc in parts["tool_call_chunks"]: + pending_tool_call_chunks.append(tcc) + + reasoning_delta = parts["reasoning"] + text_delta = parts["text"] + + # Reasoning streaming. Open a reasoning block on first + # delta; append every subsequent delta until text begins. + # When text starts we close the reasoning block first so the + # frontend sees the natural hand-off. Gated behind the + # parity-v2 flag so legacy deployments keep today's shape. 
+ if parity_v2 and reasoning_delta: + if current_text_id is not None: + yield streaming_service.format_text_end(current_text_id) + current_text_id = None + if current_reasoning_id is None: + completion_event = complete_current_step() + if completion_event: + yield completion_event + if just_finished_tool: + last_active_step_id = None + last_active_step_title = "" + last_active_step_items = [] + just_finished_tool = False + current_reasoning_id = streaming_service.generate_reasoning_id() + yield streaming_service.format_reasoning_start(current_reasoning_id) + yield streaming_service.format_reasoning_delta( + current_reasoning_id, reasoning_delta + ) + + if text_delta: + if current_reasoning_id is not None: + yield streaming_service.format_reasoning_end(current_reasoning_id) + current_reasoning_id = None + if current_text_id is None: + completion_event = complete_current_step() + if completion_event: + yield completion_event + if just_finished_tool: + last_active_step_id = None + last_active_step_title = "" + last_active_step_items = [] + just_finished_tool = False + current_text_id = streaming_service.generate_text_id() + yield streaming_service.format_text_start(current_text_id) + yield streaming_service.format_text_delta(current_text_id, text_delta) + accumulated_text += text_delta elif event_type == "on_tool_start": active_tool_depth += 1 @@ -581,7 +742,39 @@ async def _stream_agent_events( if run_id else streaming_service.generate_tool_call_id() ) - yield streaming_service.format_tool_input_start(tool_call_id, tool_name) + + # Best-effort attach the LangChain ``tool_call_id``. We + # pop the first chunk in ``pending_tool_call_chunks`` whose + # name matches; if none match (the chunked args may not yet + # carry a ``name`` field, or the model skipped the chunked + # form) we leave ``langchainToolCallId`` unset for now and + # fill it in authoritatively at ``on_tool_end`` from + # ``ToolMessage.tool_call_id``. 
+ langchain_tool_call_id: str | None = None + if parity_v2 and pending_tool_call_chunks: + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is not None: + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + langchain_tool_call_id = candidate + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + + yield streaming_service.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + ) # Sanitize tool_input: strip runtime-injected non-serializable # values (e.g. LangChain ToolRuntime) before sending over SSE. if isinstance(tool_input, dict): @@ -598,6 +791,7 @@ async def _stream_agent_events( tool_call_id, tool_name, _safe_input, + langchain_tool_call_id=langchain_tool_call_id, ) elif event_type == "on_tool_end": @@ -639,6 +833,23 @@ async def _stream_agent_events( ) completed_step_ids.add(original_step_id) + # Authoritative LangChain tool_call_id from the returned + # ``ToolMessage``. Falls back to whatever we matched + # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) + # if the output isn't a ToolMessage. The value is stored in + # ``current_lc_tool_call_id`` so ``_emit_tool_output`` + # picks it up for every output emit below. Stays None when + # parity_v2 is off so legacy emit paths are untouched. 
+ current_lc_tool_call_id["value"] = None + if parity_v2: + authoritative = getattr(raw_output, "tool_call_id", None) + if isinstance(authoritative, str) and authoritative: + current_lc_tool_call_id["value"] = authoritative + if run_id: + lc_tool_call_id_by_run[run_id] = authoritative + elif run_id and run_id in lc_tool_call_id_by_run: + current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] + if tool_name == "read_file": yield streaming_service.format_thinking_step( step_id=original_step_id, @@ -938,7 +1149,7 @@ async def _stream_agent_events( last_active_step_items = [] if tool_name == "generate_podcast": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -963,7 +1174,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_video_presentation": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -991,7 +1202,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_image": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1018,12 +1229,12 @@ async def _stream_agent_events( display_output["content_preview"] = ( content[:500] + "..." 
if len(content) > 500 else content ) - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, display_output, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, {"result": tool_output}, ) @@ -1051,7 +1262,7 @@ async def _stream_agent_events( ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "status": "error", @@ -1060,7 +1271,7 @@ async def _stream_agent_events( }, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "status": "completed", @@ -1070,7 +1281,7 @@ async def _stream_agent_events( ) elif tool_name == "generate_report": # Stream the full report result so frontend can render the ReportCard - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1097,7 +1308,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_resume": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1148,7 +1359,7 @@ async def _stream_agent_events( "update_confluence_page", "delete_confluence_page", ): - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1176,7 +1387,7 @@ async def _stream_agent_events( if fpath and fpath not in result.sandbox_files: result.sandbox_files.append(fpath) - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "exit_code": exit_code, @@ -1211,12 +1422,12 @@ async def _stream_agent_events( citations[chunk_url]["snippet"] = ( content[:200] + "…" if len(content) > 200 else content ) - yield streaming_service.format_tool_output_available( + 
yield _emit_tool_output( tool_call_id, {"status": "completed", "citations": citations}, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, {"status": "completed", "result_length": len(str(tool_output))}, ) @@ -1274,6 +1485,25 @@ async def _stream_agent_events( }, ) + elif event_type == "on_custom_event" and event.get("name") == "action_log": + # Surface a freshly committed AgentActionLog row so the chat + # tool card can render its Revert button immediately. + data = event.get("data", {}) + if data.get("id") is not None: + yield streaming_service.format_data("action-log", data) + + elif ( + event_type == "on_custom_event" + and event.get("name") == "action_log_updated" + ): + # Reversibility flipped in kb_persistence after the SAVEPOINT + # for a destructive op (rm/rmdir/move/edit/write) committed. + # Frontend uses this to flip the card's Revert + # button on without re-fetching the actions list. + data = event.get("data", {}) + if data.get("id") is not None: + yield streaming_service.format_data("action-log-updated", data) + elif event_type in ("on_chain_end", "on_agent_end"): if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) @@ -1291,11 +1521,12 @@ async def _stream_agent_events( # Safety net: if astream_events was cancelled before # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work - # (dirty_paths / staged_dirs / pending_moves) will still be in the - # checkpointed state. Run the SAME shared commit helper here so the - # turn's writes don't get lost on client disconnect, then push the - # delta back into the graph using `as_node=...` so reducers fire as if - # the after_agent hook produced it. + # (dirty_paths / staged_dirs / pending_moves / pending_deletes / + # pending_dir_deletes) will still be in the checkpointed state. 
Run + # the SAME shared commit helper here so the turn's writes don't get + # lost on client disconnect, then push the delta back into the graph + # using `as_node=...` so reducers fire as if the after_agent hook + # produced it. if ( fallback_commit_filesystem_mode == FilesystemMode.CLOUD and fallback_commit_search_space_id is not None @@ -1303,6 +1534,8 @@ async def _stream_agent_events( (state_values.get("dirty_paths") or []) or (state_values.get("staged_dirs") or []) or (state_values.get("pending_moves") or []) + or (state_values.get("pending_deletes") or []) + or (state_values.get("pending_dir_deletes") or []) ) ): try: @@ -1311,6 +1544,7 @@ async def _stream_agent_events( search_space_id=fallback_commit_search_space_id, created_by_id=fallback_commit_created_by_id, filesystem_mode=fallback_commit_filesystem_mode, + thread_id=fallback_commit_thread_id, dispatch_events=False, ) if delta: @@ -1726,6 +1960,17 @@ async def stream_new_chat( yield streaming_service.format_message_start() yield streaming_service.format_start_step() + # Surface the per-turn correlation id at the very start of the + # stream so the frontend can stamp it onto the in-flight + # assistant message and replay it via ``appendMessage`` + # for durable storage. Tool/action-log events DO carry it later, + # but pure-text turns never produce action-log events; this + # event guarantees the frontend learns the turn id regardless. 
+ yield streaming_service.format_data( + "turn-info", + {"chat_turn_id": stream_result.turn_id}, + ) + # Initial thinking step - analyzing the request if mentioned_surfsense_docs: initial_title = "Analyzing referenced content" @@ -1876,6 +2121,7 @@ async def stream_new_chat( if filesystem_selection else FilesystemMode.CLOUD ), + fallback_commit_thread_id=chat_id, ): if not _first_event_logged: _perf_log.info( @@ -2308,6 +2554,13 @@ async def stream_resume_chat( yield streaming_service.format_message_start() yield streaming_service.format_start_step() + # Same rationale as ``stream_new_chat``: emit the turn id so + # resumed streams can be persisted with their correlation id + # intact. + yield streaming_service.format_data( + "turn-info", + {"chat_turn_id": stream_result.turn_id}, + ) _t_stream_start = time.perf_counter() _first_event_logged = False @@ -2325,6 +2578,7 @@ async def stream_resume_chat( if filesystem_selection else FilesystemMode.CLOUD ), + fallback_commit_thread_id=chat_id, ): if not _first_event_logged: _perf_log.info( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py index aad1524c9..8ef1430a9 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py @@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware from app.agents.new_chat.tools.registry import ToolDefinition +@dataclass +class _FakeRuntime: + """Minimal stand-in for ``ToolRuntime`` used in unit tests. + + ``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']`` + to populate the new ``chat_turn_id`` column (see migration 135). 
+ """ + + config: dict[str, Any] | None = None + + @dataclass class _FakeRequest: """Minimal stand-in for ToolCallRequest used in unit tests.""" @@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence: "args": {"color": "red", "size": 3}, "id": "tc-abc", }, + runtime=_FakeRuntime( + config={"configurable": {"turn_id": "42:1700000000000"}} + ), ) result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1") handler = AsyncMock(return_value=result_msg) @@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence: assert row.error is None assert row.reverse_descriptor is None assert row.reversible is False + # Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``; + # ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``. + assert row.tool_call_id == "tc-abc" + assert row.turn_id == "tc-abc" + assert row.chat_turn_id == "42:1700000000000" + + @pytest.mark.asyncio + async def test_chat_turn_id_none_when_runtime_missing( + self, patch_get_flags, fake_session_factory + ) -> None: + """``chat_turn_id`` falls back to NULL when ``runtime.config`` is absent.""" + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc-1"}, + runtime=None, + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc-1")) + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + ): + await mw.awrap_tool_call(request, handler) + row = captured["rows"][0] + assert row.tool_call_id == "tc-1" + assert row.chat_turn_id is None @pytest.mark.asyncio async def test_writes_row_on_failure_and_reraises( @@ -293,6 +333,76 @@ class TestReverseDescriptor: assert row.reversible is False +class TestActionLogDispatch: + """Verify ``adispatch_custom_event`` fires after commit.""" + + @pytest.mark.asyncio + async def 
test_dispatches_action_log_event_on_success( + self, patch_get_flags, fake_session_factory + ) -> None: + _captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1") + request = _FakeRequest( + tool_call={ + "name": "make_widget", + "args": {"color": "red"}, + "id": "tc-evt", + }, + runtime=_FakeRuntime( + config={"configurable": {"turn_id": "42:1700000000000"}} + ), + ) + result_msg = ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42") + handler = AsyncMock(return_value=result_msg) + + dispatch_mock = AsyncMock() + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + patch( + "app.agents.new_chat.middleware.action_log.adispatch_custom_event", + dispatch_mock, + ), + ): + await mw.awrap_tool_call(request, handler) + + dispatch_mock.assert_awaited_once() + call_args = dispatch_mock.await_args + assert call_args is not None + assert call_args.args[0] == "action_log" + payload = call_args.args[1] + assert payload["lc_tool_call_id"] == "tc-evt" + assert payload["chat_turn_id"] == "42:1700000000000" + assert payload["tool_name"] == "make_widget" + assert payload["reversible"] is False + assert payload["reverse_descriptor_present"] is False + assert payload["error"] is False + + @pytest.mark.asyncio + async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None: + """If commit fails the dispatch is suppressed (no row to surface).""" + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + dispatch_mock = AsyncMock() + + def _exploding_session(): + raise RuntimeError("DB is down") + + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=_exploding_session), + patch( + 
"app.agents.new_chat.middleware.action_log.adispatch_custom_event", + dispatch_mock, + ), + ): + await mw.awrap_tool_call(request, handler) + dispatch_mock.assert_not_awaited() + + class TestArgsTruncation: @pytest.mark.asyncio async def test_huge_args_payload_is_truncated( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py new file mode 100644 index 000000000..653175eab --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py @@ -0,0 +1,122 @@ +"""Tests for the desktop-mode safety ruleset. + +In desktop mode the agent operates against the user's real disk with no +revision history, so destructive filesystem operations must require +explicit approval. These tests pin the set of tools that get the ``ask`` +gate so it cannot silently regress. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate_many, +) + +pytestmark = pytest.mark.unit + + +# Mirror the ruleset built inside ``chat_deepagent._build_compiled_agent_blocking`` +# when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. Keeping a +# copy here means the rule contract has a focused regression test even when +# the larger graph-build helper is hard to instantiate in unit tests. 
+DESKTOP_SAFETY_RULESET = Ruleset( + rules=[ + Rule(permission="rm", pattern="*", action="ask"), + Rule(permission="rmdir", pattern="*", action="ask"), + Rule(permission="move_file", pattern="*", action="ask"), + Rule(permission="edit_file", pattern="*", action="ask"), + Rule(permission="write_file", pattern="*", action="ask"), + ], + origin="desktop_safety", +) + +SURFSENSE_DEFAULTS = Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", +) + + +def _action_for(tool_name: str, *rulesets: Ruleset) -> str: + rules = evaluate_many(tool_name, [tool_name], *rulesets) + return aggregate_action(rules) + + +class TestDesktopSafetyRulesGateDestructiveOps: + @pytest.mark.parametrize( + "tool_name", + ["rm", "rmdir", "move_file", "edit_file", "write_file"], + ) + def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None: + # surfsense_defaults says "allow */*"; desktop_safety must override + # because it's layered later (last-match-wins). + action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "ask", ( + f"{tool_name} must require approval in desktop mode " + f"(no revert path on real disk); got {action!r}" + ) + + @pytest.mark.parametrize( + "tool_name", + ["read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"], + ) + def test_safe_ops_remain_allowed(self, tool_name: str) -> None: + # Read-only and trivially-reversible tools must NOT get gated — + # otherwise every navigation in desktop mode pops an interrupt. + action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "allow", ( + f"{tool_name} should not be gated in desktop mode; got {action!r}" + ) + + +class TestDesktopSafetyOverridesAllowDefault: + def test_layer_order_last_match_wins(self) -> None: + # If desktop_safety is layered BEFORE surfsense_defaults, the allow + # default would win and the safety net would be inert. 
This test + # protects against accidentally swapping the rulesets in + # ``_build_compiled_agent_blocking``. + action = _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS) + # Layered "wrong way" — the broad allow now wins. + assert action == "allow" + + # Correct order: defaults < desktop_safety -> ask wins. + action = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "ask" + + +class TestPermissionMiddlewareIntegration: + def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None: + from langchain_core.messages import AIMessage + + from app.agents.new_chat.errors import RejectedError + + mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]) + # Stub the interrupt to a "reject" decision so we can assert the + # ask path was taken without spinning up the LangGraph runtime. + mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] + + state = { + "messages": [ + AIMessage( + content="", + tool_calls=[ + { + "name": "rm", + "args": {"path": "/Users/me/Documents/important.docx"}, + "id": "tc-rm", + } + ], + ) + ] + } + + class _FakeRuntime: + config: dict = {"configurable": {"thread_id": "test"}} + + with pytest.raises(RejectedError): + mw.after_model(state, _FakeRuntime()) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py new file mode 100644 index 000000000..0bbdf37bf --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py @@ -0,0 +1,111 @@ +"""Tests for the default auto-approval list in ``hitl.request_approval``. + +These pin the policy that low-stakes connector creation tools (drafts, +new-file creates) skip the HITL interrupt by default. Without this set, +every "draft my newsletter" turn used to fire ~3 interrupts before any +useful work happened. 
+""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.tools.hitl import ( + DEFAULT_AUTO_APPROVED_TOOLS, + HITLResult, + request_approval, +) + +pytestmark = pytest.mark.unit + + +class TestDefaultAutoApprovedToolsList: + def test_set_contains_expected_creation_tools(self) -> None: + # If anyone changes the policy list, we want a single test to + # update so the contract is explicit. Keep this in sync with + # ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``. + expected = { + "create_gmail_draft", + "update_gmail_draft", + "create_notion_page", + "create_confluence_page", + "create_google_drive_file", + "create_dropbox_file", + "create_onedrive_file", + } + assert expected == DEFAULT_AUTO_APPROVED_TOOLS + + def test_set_is_immutable(self) -> None: + # frozenset prevents accidental at-runtime mutation that would + # silently widen the auto-approval surface. + assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset) + + def test_send_tools_are_not_auto_approved(self) -> None: + # External-broadcast tools must always prompt. + for tool_name in ( + "send_gmail_email", + "send_discord_message", + "send_teams_message", + "delete_notion_page", + "create_calendar_event", + "delete_calendar_event", + ): + assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, ( + f"{tool_name} must remain HITL-gated" + ) + + +class TestRequestApprovalAutoBypass: + def test_auto_approved_tool_skips_interrupt(self) -> None: + # No interrupt mock set up — if the function attempted to call + # ``langgraph.types.interrupt`` it would raise GraphInterrupt. + # The fact that we get a clean HITLResult proves the bypass. 
+ result = request_approval( + action_type="gmail_draft_creation", + tool_name="create_gmail_draft", + params={"to": "alice@example.com", "subject": "hi", "body": "hey"}, + ) + assert isinstance(result, HITLResult) + assert result.rejected is False + assert result.decision_type == "auto_approved" + # Original params are preserved untouched (no user edits possible). + assert result.params == { + "to": "alice@example.com", + "subject": "hi", + "body": "hey", + } + + def test_non_listed_tool_still_attempts_interrupt(self) -> None: + # A tool NOT in the default list must reach ``langgraph.interrupt``. + # Outside a runnable context that call raises a RuntimeError — + # which is exactly the signal we want: the bypass did NOT fire. + with pytest.raises(RuntimeError, match="runnable context"): + request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={"to": "alice@example.com", "subject": "hi", "body": "hey"}, + ) + + def test_user_trusted_tools_still_take_precedence(self) -> None: + # ``trusted_tools`` (per-connector "always allow" from MCP/UI) + # was checked BEFORE the default list and must keep working + # for tools outside the default list. + result = request_approval( + action_type="mcp_tool_call", + tool_name="my_custom_mcp_tool", + params={"x": 1}, + trusted_tools=["my_custom_mcp_tool"], + ) + assert result.decision_type == "trusted" + assert result.rejected is False + + def test_auto_approved_overrides_no_trusted_tools(self) -> None: + # When trusted_tools is empty and tool is in the default list, + # we should still bypass — proves the order in request_approval. 
+ result = request_approval( + action_type="notion_page_creation", + tool_name="create_notion_page", + params={"title": "Plan"}, + trusted_tools=[], + ) + assert result.decision_type == "auto_approved" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py b/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py new file mode 100644 index 000000000..7cabb6524 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py @@ -0,0 +1,333 @@ +"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools. + +The tools build ``Command(update=...)`` payloads that the persistence +middleware applies at end of turn. These tests stub out the backend and +runtime to assert the staging payload shape: + +* ``rm`` queues into ``pending_deletes`` and tombstones state files. +* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc. +* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs. +* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete. +* ``rmdir`` refuses to drop the cwd or any of its ancestors. +* ``KBPostgresBackend`` view-helpers honor staged deletes. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware +from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend + +pytestmark = pytest.mark.unit + + +def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD): + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + middleware._filesystem_mode = mode + middleware._custom_tool_descriptions = {} + return middleware + + +def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"): + state = state or {} + state.setdefault("cwd", "/documents") + return SimpleNamespace(state=state, tool_call_id=tool_call_id) + + +class _KBBackendStub(KBPostgresBackend): + """Construct-able subclass of :class:`KBPostgresBackend` for tests. + + We bypass the real ``__init__`` (which expects a runtime + DB session) + and inject just the methods the rm/rmdir tools touch. The class + inheritance keeps ``isinstance(backend, KBPostgresBackend)`` checks + inside the tools happy, which is what gates them from the desktop + code path. 
+ """ + + def __init__(self, *, children=None, file_data=None) -> None: + self.als_info = AsyncMock(return_value=children or []) + self._load_file_data = AsyncMock( + return_value=(file_data, 17) if file_data is not None else None + ) + + +def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend: + return _KBBackendStub(children=children, file_data=file_data) + + +def _bind_backend(middleware, backend): + """Inject a backend resolver onto the middleware test instance.""" + middleware._get_backend = lambda runtime: backend + return backend + + +# --------------------------------------------------------------------------- +# rm +# --------------------------------------------------------------------------- + + +class TestRmStaging: + @pytest.mark.asyncio + async def test_stages_delete_and_tombstones_state(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime( + { + "cwd": "/documents", + "files": {"/documents/notes.md": {"content": ["hello"]}}, + "doc_id_by_path": {"/documents/notes.md": 17}, + }, + tool_call_id="tc-1", + ) + + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + + assert hasattr(result, "update"), f"expected Command, got {result!r}" + update = result.update + assert update["pending_deletes"] == [ + {"path": "/documents/notes.md", "tool_call_id": "tc-1"} + ] + assert update["files"] == {"/documents/notes.md": None} + assert update["doc_id_by_path"] == {"/documents/notes.md": None} + + @pytest.mark.asyncio + async def test_rejects_documents_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rm_tool() + result = await tool.coroutine("/documents", runtime=runtime) + assert isinstance(result, str) + assert "refusing to rm" in result + + @pytest.mark.asyncio + async def test_rejects_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rm_tool() + result = 
await tool.coroutine("/", runtime=runtime) + assert isinstance(result, str) + assert "refusing to rm" in result + + @pytest.mark.asyncio + async def test_rejects_directory_via_staged_dirs(self): + m = _make_middleware() + runtime = _runtime( + { + "staged_dirs": ["/documents/team-x"], + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/team-x", runtime=runtime) + assert isinstance(result, str) + assert "directory" in result.lower() + assert "rmdir" in result + + @pytest.mark.asyncio + async def test_rejects_directory_via_listing(self): + m = _make_middleware() + _bind_backend( + m, + _make_backend_stub( + children=[{"path": "/documents/foo/x.md", "is_dir": False}] + ), + ) + runtime = _runtime() + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/foo", runtime=runtime) + assert isinstance(result, str) + assert "directory" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_anonymous_doc(self): + m = _make_middleware() + runtime = _runtime( + { + "kb_anon_doc": { + "path": "/documents/uploaded.xml", + "title": "uploaded", + "content": "", + "chunks": [], + } + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/uploaded.xml", runtime=runtime) + assert isinstance(result, str) + assert "read-only" in result + + @pytest.mark.asyncio + async def test_drops_path_from_dirty_paths(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime( + { + "files": {"/documents/notes.md": {"content": ["x"]}}, + "doc_id_by_path": {"/documents/notes.md": 17}, + "dirty_paths": ["/documents/notes.md"], + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + update = result.update + # First element is _CLEAR sentinel; the rest must NOT contain the + # rm'd path. 
+ dirty = update.get("dirty_paths") or [] + assert "/documents/notes.md" not in dirty[1:] + + +# --------------------------------------------------------------------------- +# rmdir +# --------------------------------------------------------------------------- + + +class TestRmdirStaging: + @pytest.mark.asyncio + async def test_stages_dir_delete_when_empty_and_db_backed(self): + m = _make_middleware() + backend = _bind_backend(m, _make_backend_stub(children=[])) + # Override _load_file_data to return None (folder, not a file) and + # parent listing to claim the folder exists. + backend._load_file_data = AsyncMock(return_value=None) + backend.als_info = AsyncMock( + side_effect=[ + [], # children of /documents/proj + [ + {"path": "/documents/proj", "is_dir": True}, + ], # parent listing + ] + ) + runtime = _runtime( + { + "cwd": "/documents", + }, + tool_call_id="tc-rd", + ) + + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + + assert hasattr(result, "update") + update = result.update + assert update["pending_dir_deletes"] == [ + {"path": "/documents/proj", "tool_call_id": "tc-rd"} + ] + + @pytest.mark.asyncio + async def test_rejects_non_empty(self): + m = _make_middleware() + _bind_backend( + m, + _make_backend_stub( + children=[{"path": "/documents/proj/x.md", "is_dir": False}] + ), + ) + runtime = _runtime() + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "not empty" in result + + @pytest.mark.asyncio + async def test_unstages_same_turn_mkdir(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[])) + runtime = _runtime( + { + "cwd": "/documents", + "staged_dirs": ["/documents/scratch"], + }, + tool_call_id="tc-rd", + ) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/scratch", runtime=runtime) + + assert hasattr(result, "update") + update = result.update + assert 
"pending_dir_deletes" not in update + # _CLEAR sentinel + remaining items (in this case, none). + staged_after = update["staged_dirs"] + assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00" + assert "/documents/scratch" not in staged_after[1:] + + @pytest.mark.asyncio + async def test_rejects_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rmdir_tool() + for victim in ("/", "/documents"): + result = await tool.coroutine(victim, runtime=runtime) + assert isinstance(result, str) + assert "refusing to rmdir" in result + + @pytest.mark.asyncio + async def test_rejects_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/proj"}) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "cwd" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_ancestor_of_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/proj/sub"}) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "cwd" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_files(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime() + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + assert isinstance(result, str) + assert "is a file" in result + + +# --------------------------------------------------------------------------- +# KBPostgresBackend view filter +# --------------------------------------------------------------------------- + + +class TestKBPostgresBackendDeleteFilter: + """als_info / glob / grep should suppress paths queued for delete.""" + + def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend: + runtime = SimpleNamespace(state=state) + backend = 
KBPostgresBackend(search_space_id=1, runtime=runtime) + return backend + + def test_pending_filesystem_view_returns_deleted_paths(self): + backend = self._make_backend( + { + "pending_deletes": [ + {"path": "/documents/x.md", "tool_call_id": "t1"}, + ], + "pending_dir_deletes": [ + {"path": "/documents/d1", "tool_call_id": "t2"}, + ], + } + ) + removed, alias, deleted_dirs = backend._pending_filesystem_view({}) + assert "/documents/x.md" in removed + assert "/documents/d1" in deleted_dirs + assert alias == {} + + def test_dir_suppressed_covers_descendants(self): + backend = self._make_backend({}) + deleted_dirs = {"/documents/d"} + assert backend._is_dir_suppressed("/documents/d", deleted_dirs) + assert backend._is_dir_suppressed("/documents/d/x.md", deleted_dirs) + assert backend._is_dir_suppressed("/documents/d/sub/y.md", deleted_dirs) + assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py index 3caeb9a34..185753990 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py @@ -98,10 +98,54 @@ class TestInitialFilesystemState: state = _initial_filesystem_state() assert state["cwd"] == "/documents" assert state["staged_dirs"] == [] + assert state["staged_dir_tool_calls"] == {} assert state["pending_moves"] == [] + assert state["pending_deletes"] == [] + assert state["pending_dir_deletes"] == [] assert state["doc_id_by_path"] == {} assert state["dirty_paths"] == [] + assert state["dirty_path_tool_calls"] == {} assert state["kb_priority"] == [] assert state["kb_matched_chunk_ids"] == {} assert state["kb_anon_doc"] is None assert state["tree_version"] == 0 + + +class TestMultiEditSamePathCoalescing: + """Multi-edit-same-path turns must coalesce into ONE binding record. 
+ + The persistence body uses ``dirty_path_tool_calls[path]`` to find the + tool_call_id that produced the current state on disk. Because + ``dirty_paths`` dedupes via :func:`_add_unique_reducer` the second + edit doesn't append a new path entry — and because + ``_dict_merge_with_tombstones_reducer`` lets the right-hand side + overwrite, the LATEST tool_call_id wins. That's the correct behavior + for snapshotting: revert restores to the pre-mutation state, and + multiple back-to-back edits in one turn coalesce into a single + reversible op (the user sees ONE Revert button per turn-per-path, + not N). + """ + + def test_dirty_paths_dedupes_repeated_writes(self): + # ``_add_unique_reducer`` is applied to ``dirty_paths``. Two writes + # to the same path produce one entry, not two. + first = _add_unique_reducer([], ["/documents/a.md"]) + second = _add_unique_reducer(first, ["/documents/a.md"]) + assert second == ["/documents/a.md"] + + def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self): + # First write tags the path with tcid-1. + merged = _dict_merge_with_tombstones_reducer({}, {"/documents/a.md": "tcid-1"}) + # Second write to the same path tags it with tcid-2 (latest wins). + merged = _dict_merge_with_tombstones_reducer( + merged, {"/documents/a.md": "tcid-2"} + ) + assert merged == {"/documents/a.md": "tcid-2"} + + def test_rm_tombstones_dirty_path_tool_call(self): + # ``rm`` writes ``{path: None}`` into dirty_path_tool_calls to + # prevent a stale binding from leaking past the delete. 
+ merged = _dict_merge_with_tombstones_reducer( + {"/documents/a.md": "tcid-1"}, {"/documents/a.md": None} + ) + assert merged == {} diff --git a/surfsense_backend/tests/unit/db/__init__.py b/surfsense_backend/tests/unit/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py b/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py new file mode 100644 index 000000000..82c299488 --- /dev/null +++ b/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py @@ -0,0 +1,83 @@ +"""Smoke test for the ``134_relax_revision_fks`` Alembic migration. + +A full apply/rollback test would require a live Postgres; here we verify +the migration module's static contract: + +* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``. +* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'`` + (one for ``document_revisions.document_id``, one for + ``folder_revisions.folder_id``). +* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining + orphaned revisions. + +If any of these invariants regress the snapshot/revert pipeline silently +loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the +migration "down" or never ran it at all. 
+""" + +from __future__ import annotations + +import importlib.util +import inspect +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.unit + + +_MIGRATION_PATH = ( + Path(__file__).resolve().parents[3] + / "alembic" + / "versions" + / "134_relax_revision_fks.py" +) + + +def _load_migration(): + """Load the migration module by file path (no package import needed).""" + spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH) + assert spec and spec.loader, "could not load migration spec" + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_migration_chain_revision_ids() -> None: + module = _load_migration() + # The migration file uses short numeric revision IDs to match the + # in-tree convention (cf. ``133`` -> ``134``); the ``134_relax_revision_fks.py`` + # filename is documentation, not the canonical revision string. + assert getattr(module, "revision", None) == "134" + assert getattr(module, "down_revision", None) == "133" + + +def test_migration_exposes_upgrade_and_downgrade() -> None: + module = _load_migration() + upgrade = getattr(module, "upgrade", None) + downgrade = getattr(module, "downgrade", None) + assert callable(upgrade), "upgrade() is required" + assert callable(downgrade), "downgrade() is required" + + +def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None: + module = _load_migration() + src = inspect.getsource(module.upgrade) + assert "document_revisions" in src + assert "folder_revisions" in src + # Both new FKs MUST be ON DELETE SET NULL — that's the entire point + # of the migration: snapshots must outlive their parent row. + assert src.count('ondelete="SET NULL"') >= 2 + # And the ``document_id`` / ``folder_id`` columns become nullable. 
+ assert "nullable=True" in src + + +def test_downgrade_drains_orphans_then_restores_cascade() -> None: + module = _load_migration() + src = inspect.getsource(module.downgrade) + # Drain orphaned rows BEFORE we can re-impose NOT NULL. + assert "DELETE FROM document_revisions WHERE document_id IS NULL" in src + assert "DELETE FROM folder_revisions WHERE folder_id IS NULL" in src + # Then restore the original CASCADE/NOT NULL contract. + assert src.count('ondelete="CASCADE"') >= 2 + assert "nullable=False" in src diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py index c2e304399..70430f4ca 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py @@ -168,6 +168,8 @@ class TestModeSpecificPrompts: "edit_file", "move_file", "mkdir", + "rm", + "rmdir", "list_tree", "grep", ): @@ -182,6 +184,8 @@ class TestModeSpecificPrompts: "edit_file", "move_file", "mkdir", + "rm", + "rmdir", "list_tree", "grep", ): @@ -190,6 +194,18 @@ class TestModeSpecificPrompts: assert "/documents/" not in text, f"{name} mentions cloud namespace" assert "temp_" not in text, f"{name} mentions cloud temp_ semantics" + def test_cloud_descs_include_rm_and_rmdir(self): + descs = _build_tool_descriptions(FilesystemMode.CLOUD) + assert "rm" in descs and "rmdir" in descs + assert "Deletes a single file" in descs["rm"] + assert "Deletes an empty directory" in descs["rmdir"] + assert "rmdir" in descs["rmdir"] and "POSIX" in descs["rmdir"] + + def test_desktop_descs_warn_about_irreversibility(self): + descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER) + assert "NOT reversible" in descs["rm"] + assert "NOT reversible" in descs["rmdir"] + def test_sandbox_addendum_appended_when_available(self): prompt = _build_filesystem_system_prompt( FilesystemMode.CLOUD, sandbox_available=True diff 
--git a/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py b/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py new file mode 100644 index 000000000..feca23d27 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py @@ -0,0 +1,309 @@ +"""Unit tests for the kb_persistence snapshot helpers. + +The full ``commit_staged_filesystem_state`` body exercises a real session +in integration tests; here we verify the building blocks used by the +snapshot/revert pipeline: + +* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids + (regression guard against the N+1 lookup pattern). +* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``. +* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the + shape the snapshot helpers consume. + +These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so +the assertions run in milliseconds and don't require Postgres. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.agents.new_chat.middleware import kb_persistence + +pytestmark = pytest.mark.unit + + +class _FakeResult: + def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None: + self._rows = rows or [] + self._scalar = scalar + + def all(self) -> list[Any]: + return list(self._rows) + + def scalar_one_or_none(self) -> Any: + return self._scalar + + +class _FakeSession: + def __init__(self) -> None: + self.execute = AsyncMock() + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_issues_single_query() -> None: + """The lookup MUST be a single ``IN (...)`` SELECT, not N selects.""" + session = _FakeSession() + session.execute.return_value = _FakeResult( + rows=[ + MagicMock(id=11, tool_call_id="tc-a"), + MagicMock(id=22, tool_call_id="tc-b"), + MagicMock(id=33, tool_call_id="tc-c"), + ] + ) + + mapping = await 
kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=1, + tool_call_ids={"tc-a", "tc-b", "tc-c"}, + ) + + assert mapping == {"tc-a": 11, "tc-b": 22, "tc-c": 33} + assert session.execute.await_count == 1, ( + "Snapshot binding must batch into ONE query; got " + f"{session.execute.await_count} (regression: N+1 lookup pattern)." + ) + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None: + session = _FakeSession() + mapping = await kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=None, + tool_call_ids={"tc-a"}, + ) + assert mapping == {} + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None: + session = _FakeSession() + mapping = await kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=42, + tool_call_ids=set(), + ) + assert mapping == {} + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_mark_action_reversible_is_noop_for_null_id() -> None: + session = _FakeSession() + await kb_persistence._mark_action_reversible(session, action_id=None) # type: ignore[arg-type] + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_mark_action_reversible_runs_update_for_real_id() -> None: + session = _FakeSession() + await kb_persistence._mark_action_reversible(session, action_id=99) # type: ignore[arg-type] + assert session.execute.await_count == 1 + + +def test_doc_revision_payload_captures_metadata_virtual_path() -> None: + """Snapshot helpers must capture ``metadata_before`` for revert reuse.""" + doc = MagicMock() + doc.content = "body" + doc.title = "notes.md" + doc.folder_id = 7 + doc.document_metadata = {"virtual_path": "/documents/team/notes.md"} + + payload = kb_persistence._doc_revision_payload( + doc, chunks_before=[{"content": "x"}] + ) + + assert 
payload["title_before"] == "notes.md" + assert payload["folder_id_before"] == 7 + assert payload["content_before"] == "body" + assert payload["chunks_before"] == [{"content": "x"}] + assert payload["metadata_before"] == {"virtual_path": "/documents/team/notes.md"} + + +def test_doc_revision_payload_handles_missing_metadata() -> None: + doc = MagicMock() + doc.content = "" + doc.title = "" + doc.folder_id = None + doc.document_metadata = None + payload = kb_persistence._doc_revision_payload(doc) + assert payload["metadata_before"] is None + + +@pytest.mark.asyncio +async def test_load_chunks_for_snapshot_returns_content_only() -> None: + """Snapshot chunks intentionally omit embeddings (regenerated on revert).""" + session = _FakeSession() + session.execute.return_value = _FakeResult( + rows=[ + MagicMock(content="alpha"), + MagicMock(content="beta"), + ] + ) + chunks = await kb_persistence._load_chunks_for_snapshot( + session, + doc_id=42, # type: ignore[arg-type] + ) + assert chunks == [{"content": "alpha"}, {"content": "beta"}] + + +# --------------------------------------------------------------------------- +# Deferred reversibility-flip dispatches. +# +# The snapshot helpers used to dispatch ``action_log_updated`` directly +# from inside the SAVEPOINT block. That meant the SSE side-channel +# could tell the UI a row was reversible while the OUTER transaction +# was still pending — and if the outer commit failed, every SAVEPOINT +# rolled back too, leaving the UI in a state inconsistent with +# durable storage. The deferred-dispatch contract fixes that: +# +# • when a ``deferred_dispatches`` list is provided, the helper +# APPENDS the action_id and does NOT dispatch; +# • the caller (``commit_staged_filesystem_state``) flushes the list +# only AFTER ``await session.commit()`` succeeds; on rollback it +# clears the list so nothing is emitted. 
+# --------------------------------------------------------------------------- + + +class _NestedCtx: + """Async context manager mimicking ``session.begin_nested()``.""" + + async def __aenter__(self) -> _NestedCtx: + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + +@pytest.mark.asyncio +async def test_pre_write_snapshot_defers_dispatch_when_list_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Helpers MUST queue dispatches when ``deferred_dispatches`` is set.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock(return_value=_FakeResult(rows=[])) + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 17 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + deferred: list[int] = [] + doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"}) + doc.title = "x.md" + doc.folder_id = None + doc.content = "body" + + rev_id = await kb_persistence._snapshot_document_pre_write( + session, # type: ignore[arg-type] + doc=doc, + action_id=42, + search_space_id=1, + turn_id="t-1", + deferred_dispatches=deferred, + ) + + assert rev_id == 17 + # Inline dispatch must NOT have fired; the action_id is queued. 
+ assert dispatched == [] + assert deferred == [42] + + +@pytest.mark.asyncio +async def test_pre_write_snapshot_dispatches_inline_when_list_omitted( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Direct callers (no outer transaction) keep the legacy inline dispatch.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock(return_value=_FakeResult(rows=[])) + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 7 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + doc = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"}) + doc.title = "y.md" + doc.folder_id = None + doc.content = "body" + + await kb_persistence._snapshot_document_pre_write( + session, # type: ignore[arg-type] + doc=doc, + action_id=88, + search_space_id=1, + turn_id="t-1", + # No deferred_dispatches arg — fall back to inline dispatch. 
+ ) + + assert dispatched == [88] + + +@pytest.mark.asyncio +async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Folder mkdir snapshots honour the same deferred-dispatch contract.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock() # _mark_action_reversible calls execute + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 3 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + deferred: list[int] = [] + folder = MagicMock(id=2, name="f", parent_id=None, position="a0") + + await kb_persistence._snapshot_folder_pre_mkdir( + session, # type: ignore[arg-type] + folder=folder, + action_id=55, + search_space_id=1, + turn_id="t-1", + deferred_dispatches=deferred, + ) + + assert dispatched == [] + assert deferred == [55] diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py b/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py new file mode 100644 index 000000000..caaec3114 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py @@ -0,0 +1,139 @@ +"""Unit tests for ``KnowledgeTreeMiddleware`` rendering. + +The empty-folder marker is critical UX: without it, the LLM cannot +distinguish a leaf folder containing one document from a leaf folder +that has no descendants at all, and ends up firing ``rmdir`` on +non-empty folders. These tests pin the rendering contract so that +contract cannot silently regress. 
+""" + +from __future__ import annotations + +from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware +from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT + + +def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]: + return KnowledgeTreeMiddleware._compute_non_empty_folders(folder_paths, doc_paths) + + +class TestComputeNonEmptyFolders: + def test_folder_with_direct_document_is_non_empty(self): + folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"] + doc_paths = [ + f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml", + ] + non_empty = _compute(folder_paths, doc_paths) + assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" in non_empty + + def test_truly_empty_leaf_folder_is_not_non_empty(self): + folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"] + doc_paths: list[str] = [] + assert _compute(folder_paths, doc_paths) == set() + + def test_documents_propagate_up_to_all_ancestors(self): + folder_paths = [ + f"{DOCUMENTS_ROOT}/A", + f"{DOCUMENTS_ROOT}/A/B", + f"{DOCUMENTS_ROOT}/A/B/C", + ] + doc_paths = [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"] + non_empty = _compute(folder_paths, doc_paths) + assert non_empty == { + f"{DOCUMENTS_ROOT}/A", + f"{DOCUMENTS_ROOT}/A/B", + f"{DOCUMENTS_ROOT}/A/B/C", + } + + def test_chain_with_subfolders_marks_only_leaf_empty(self): + # POSIX-like semantic: a folder is "empty" only if it has no + # immediate children (docs OR sub-folders). The model needs this + # because parallel ``rmdir`` calls all see the same starting state, + # so trying to rmdir a parent before its children is never safe. + folder_paths = [ + f"{DOCUMENTS_ROOT}/X", + f"{DOCUMENTS_ROOT}/X/Y", + f"{DOCUMENTS_ROOT}/X/Y/Z", + ] + non_empty = _compute(folder_paths, []) + # Only ``X/Y/Z`` (the leaf) is empty. ``X`` and ``X/Y`` each have a + # sub-folder child, so they are non-empty and should NOT carry the + # ``(empty)`` marker. 
+ assert non_empty == {f"{DOCUMENTS_ROOT}/X", f"{DOCUMENTS_ROOT}/X/Y"} + + def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self): + # Mirrors a real DB layout where every intermediate folder is + # materialized in the ``folders`` table. + folder_paths = [ + f"{DOCUMENTS_ROOT}/Travel", + f"{DOCUMENTS_ROOT}/Travel/Boarding Pass", + f"{DOCUMENTS_ROOT}/Travel/Notes", + ] + doc_paths = [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"] + non_empty = _compute(folder_paths, doc_paths) + # ``Travel`` is non-empty because it has children, ``Notes`` is non-empty + # because of the doc, but ``Boarding Pass`` (sibling leaf) is empty. + assert f"{DOCUMENTS_ROOT}/Travel" in non_empty + assert f"{DOCUMENTS_ROOT}/Travel/Notes" in non_empty + assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" not in non_empty + + +class TestFormatTreeRendering: + """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't.""" + + def _render( + self, + folder_paths: list[str], + doc_specs: list[dict], + ) -> str: + from app.agents.new_chat.path_resolver import PathIndex + + index = PathIndex( + folder_paths={i + 1: p for i, p in enumerate(folder_paths)}, + ) + + class _Row: + def __init__(self, **kw): + self.__dict__.update(kw) + + docs = [_Row(**spec) for spec in doc_specs] + + mw = KnowledgeTreeMiddleware( + search_space_id=1, + filesystem_mode=None, # type: ignore[arg-type] + ) + return mw._format_tree(index, docs) + + def test_renders_empty_marker_only_for_truly_empty_folders(self): + # Reproduces the failure scenario from the bug report: + # ``Boarding Pass`` is empty (its only doc was just deleted), while + # ``Tax Returns`` still has ``federal.pdf``. All intermediate + # folders are present in the index, mirroring the real DB layout. 
+ folder_paths = [ + "/documents/File Upload", + "/documents/File Upload/2026-04-08", + "/documents/File Upload/2026-04-08/Travel", + "/documents/File Upload/2026-04-08/Travel/Boarding Pass", + "/documents/File Upload/2026-04-15", + "/documents/File Upload/2026-04-15/Finance", + "/documents/File Upload/2026-04-15/Finance/Tax Returns", + ] + tax_returns_folder_id = ( + folder_paths.index("/documents/File Upload/2026-04-15/Finance/Tax Returns") + + 1 + ) + rendered = self._render( + folder_paths=folder_paths, + doc_specs=[ + { + "id": 100, + "title": "federal.pdf", + "folder_id": tax_returns_folder_id, + }, + ], + ) + assert "Boarding Pass/ (empty)" in rendered + assert "Tax Returns/ (empty)" not in rendered + # Intermediate ancestors of the doc must NOT be marked empty. + assert "Finance/ (empty)" not in rendered + assert "2026-04-15/ (empty)" not in rendered diff --git a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py index 7dfc68402..6e81ecf8e 100644 --- a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py +++ b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py @@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path): assert write.error is not None assert "parent directory" in write.error assert not (tmp_path / "tempoo").exists() + + +def test_local_backend_delete_file_success(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "delete-me.md").write_text("bye") + + res = backend.delete_file("/delete-me.md") + assert res.error is None + assert res.path == "/delete-me.md" + assert not (tmp_path / "delete-me.md").exists() + + +def test_local_backend_delete_file_rejects_directory(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "subdir").mkdir() + + res = backend.delete_file("/subdir") + assert res.error is not None + assert "directory" in res.error 
+ assert (tmp_path / "subdir").exists() + + +def test_local_backend_delete_file_missing_returns_error(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + + res = backend.delete_file("/nope.md") + assert res.error is not None + assert "not found" in res.error + + +def test_local_backend_rmdir_success(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "empty").mkdir() + + res = backend.rmdir("/empty") + assert res.error is None + assert res.path == "/empty" + assert not (tmp_path / "empty").exists() + + +def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "withkid").mkdir() + (tmp_path / "withkid" / "child.md").write_text("x") + + res = backend.rmdir("/withkid") + assert res.error is not None + assert "not empty" in res.error + assert (tmp_path / "withkid" / "child.md").exists() + + +def test_local_backend_rmdir_rejects_file(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "f.md").write_text("x") + + res = backend.rmdir("/f.md") + assert res.error is not None + assert "not a directory" in res.error + + +def test_local_backend_rmdir_rejects_root(tmp_path: Path): + """``rmdir /`` MUST fail. 
The exact error wording comes from + ``_resolve_virtual`` (root resolves to outside the sandbox); what + matters is that the call returns an error and does NOT delete the + sandbox root on disk.""" + backend = LocalFolderBackend(str(tmp_path)) + + res = backend.rmdir("/") + assert res.error is not None + assert "Invalid path" in res.error or "root" in res.error + assert tmp_path.exists() diff --git a/surfsense_backend/tests/unit/routes/__init__.py b/surfsense_backend/tests/unit/routes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py b/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py new file mode 100644 index 000000000..709014d55 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py @@ -0,0 +1,143 @@ +"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``. + +The regenerate route's edit-from-position path introduces: +* ``_find_pre_turn_checkpoint_id`` — walks LangGraph checkpoint tuples + newest-first and picks the first one whose ``metadata["turn_id"]`` + differs from the edited turn. That checkpoint is the rewind target + (state immediately before the edited turn started). +* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions`` + with a validator that prevents callers from requesting a revert pass + without specifying which turn to roll back. + +These are pure-Python helpers that don't need a live DB, so we exercise +them with a small ``CheckpointTuple``-shaped namespace and direct +schema instantiation. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id +from app.schemas.new_chat import RegenerateRequest + + +def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace: + """Build a fake ``CheckpointTuple`` with the metadata shape we read.""" + return SimpleNamespace( + config={"configurable": {"checkpoint_id": checkpoint_id}}, + metadata={"turn_id": turn_id} if turn_id is not None else {}, + ) + + +class TestFindPreTurnCheckpointId: + def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None: + # Newest-first: T2 is the most-recent turn. The latest non-T2 + # checkpoint (cp2) is the rewind target — state immediately + # before T2 began. + tuples = [ + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None: + # Regression for the bug where walking newest-first returned the + # FIRST cp with ``turn_id != target`` — which is one of the + # later-turn checkpoints, NOT the pre-turn boundary. Editing + # T2 must rewind to the latest T1 checkpoint (cp2), not to the + # latest T3 checkpoint (cp6). + tuples = [ + _cp("cp6", "T3"), + _cp("cp5", "T3"), + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_returns_none_when_editing_first_turn(self) -> None: + # No pre-turn boundary exists; caller is expected to fall back + # to the oldest checkpoint or special-case "first turn of the + # thread". 
+ tuples = [ + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T1") is None + + def test_returns_none_when_only_edited_turn_present(self) -> None: + tuples = [_cp("cp2", "T2"), _cp("cp1", "T2")] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") is None + + def test_returns_none_for_empty_history(self) -> None: + assert _find_pre_turn_checkpoint_id([], turn_id="T1") is None + + def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None: + # Checkpoints written before migration 136 have no + # ``metadata.turn_id``. They should be eligible rewind targets + # — they came before the + # edited turn began. + tuples = [ + _cp("cp3", "T2"), + SimpleNamespace( + config={"configurable": {"checkpoint_id": "cp2"}}, + metadata=None, + ), + _cp("cp1", "T1"), + ] + # Walking oldest-first: cp1(T1) tracked, cp2(legacy/None) tracked, + # then cp3(T2) crosses the boundary -> return cp2. + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None: + # If a checkpoint tuple's ``config["configurable"]`` is missing + # the ``checkpoint_id`` key (corrupt / partial), we keep the + # last known good target instead of crashing. + broken = SimpleNamespace( + config={"configurable": {}}, metadata={"turn_id": "T1"} + ) + tuples = [ + _cp("cp3", "T2"), + broken, + _cp("cp1", "T1"), + ] + # cp1(T1) tracked, broken skipped, cp3(T2) -> return cp1. 
+ assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp1" + + +class TestRegenerateRequestValidation: + def test_revert_actions_requires_from_message_id(self) -> None: + with pytest.raises(Exception) as exc: + RegenerateRequest( + search_space_id=1, + user_query="hi", + revert_actions=True, + ) + msg = str(exc.value).lower() + assert "from_message_id" in msg + + def test_from_message_id_without_revert_is_allowed(self) -> None: + req = RegenerateRequest( + search_space_id=1, + user_query="hi", + from_message_id=42, + ) + assert req.from_message_id == 42 + assert req.revert_actions is False + + def test_revert_actions_with_from_message_id_passes(self) -> None: + req = RegenerateRequest( + search_space_id=1, + user_query="hi", + from_message_id=42, + revert_actions=True, + ) + assert req.revert_actions is True diff --git a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py new file mode 100644 index 000000000..1e1cbffb3 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py @@ -0,0 +1,530 @@ +"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + +The per-turn batch revert route walks rows in reverse ``created_at`` +order, reverts each independently, and returns a per-action result +list. Partial success is normal — the response status +is ``"partial"`` whenever any row could not be reverted, but we never +collapse the whole batch into a 4xx. + +These tests stub ``load_thread`` / ``revert_action`` and feed a fake +session, so they exercise the route's dispatch logic without a real DB. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.routes import agent_revert_route +from app.services.revert_service import RevertOutcome + + +@dataclass +class _FakeAction: + id: int + tool_name: str + user_id: str | None = "u1" + reverse_of: int | None = None + error: dict | None = None + + +@dataclass +class _FakeUser: + id: str = "u1" + + +@dataclass +class _ScalarResult: + rows: list[Any] + + def first(self) -> Any: + return self.rows[0] if self.rows else None + + def all(self) -> list[Any]: + return list(self.rows) + + +@dataclass +class _Result: + rows: list[Any] = field(default_factory=list) + + def scalars(self) -> _ScalarResult: + return _ScalarResult(self.rows) + + def all(self) -> list[Any]: + # ``_was_already_reverted_batch`` calls ``.all()`` directly on + # the row-tuple result (no ``.scalars()`` indirection). The + # rows queued for that helper are list[(revert_id, original_id)]. + return list(self.rows) + + +class _FakeNestedCtx: + """Async context manager that mimics ``session.begin_nested()``. + + The route raises a sentinel exception inside this block to roll back + bad rows. We just pass the exception through. + """ + + async def __aenter__(self) -> _FakeNestedCtx: + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + # Returning False (or None) propagates the exception; the route + # catches its own sentinel above this layer. + return False + + +class _FakeSession: + """Minimal AsyncSession stand-in for the revert-turn route. + + Holds a queue of result objects; each ``execute(...)`` pops the next + one. The route calls ``execute`` exactly once per query so this maps + cleanly onto the assertion order of the test. 
+ """ + + def __init__(self) -> None: + self._results: list[_Result] = [] + self.committed = False + self.rolled_back = False + # Count execute() calls to assert "no N+1 reverts". + self.execute_call_count = 0 + + def queue(self, *results: _Result) -> None: + self._results.extend(results) + + async def execute(self, _stmt: Any) -> _Result: + self.execute_call_count += 1 + if not self._results: + return _Result(rows=[]) + return self._results.pop(0) + + def begin_nested(self) -> _FakeNestedCtx: + return _FakeNestedCtx() + + async def commit(self) -> None: + self.committed = True + + async def rollback(self) -> None: + self.rolled_back = True + + +def _enabled_flags() -> AgentFeatureFlags: + return AgentFeatureFlags( + disable_new_agent_stack=False, + enable_action_log=True, + enable_revert_route=True, + ) + + +@pytest.fixture +def patch_get_flags(): + def _patch(flags: AgentFeatureFlags): + return patch( + "app.routes.agent_revert_route.get_flags", + return_value=flags, + ) + + return _patch + + +class TestFlagGuard: + @pytest.mark.asyncio + async def test_returns_503_when_revert_route_disabled( + self, patch_get_flags + ) -> None: + flags = AgentFeatureFlags( + disable_new_agent_stack=False, + enable_action_log=True, + enable_revert_route=False, + ) + session = _FakeSession() + with patch_get_flags(flags), pytest.raises(Exception) as exc: + await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="42:1700000000000", + session=session, + user=_FakeUser(), + ) + assert getattr(exc.value, "status_code", None) == 503 + + +class TestRevertTurnDispatch: + @pytest.mark.asyncio + async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None: + session = _FakeSession() + session.queue(_Result(rows=[])) # rows query returns nothing + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + ): + response = await agent_revert_route.revert_agent_turn( + 
thread_id=1, + chat_turn_id="ct-empty", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.total == 0 + assert response.results == [] + assert session.committed is True + + @pytest.mark.asyncio + async def test_walks_rows_in_reverse_and_reverts_each( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=10, tool_name="rm"), + _FakeAction(id=9, tool_name="write_file"), + _FakeAction(id=8, tool_name="mkdir"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched ``_was_already_reverted_batch`` probe replaces + # the previous N per-row SELECTs. + session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + return RevertOutcome( + status="ok", + message=f"reverted-{action.id}", + new_action_id=100 + action.id, + ) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-3", + session=session, + user=_FakeUser(), + ) + + assert response.status == "ok" + assert response.total == 3 + assert response.reverted == 3 + assert [r.action_id for r in response.results] == [10, 9, 8] + assert all(r.status == "reverted" for r in response.results) + assert response.results[0].new_action_id == 110 + # Only TWO ``execute`` calls regardless of the row count: one + # for the rows query, one for the batched + # ``_was_already_reverted_batch`` probe. Regression guard + # against re-introducing the per-row N+1 lookup. + assert session.execute_call_count == 2, ( + "revert-turn loop must batch idempotency probes; got " + f"{session.execute_call_count} execute() calls (expected 2)." 
+ ) + + @pytest.mark.asyncio + async def test_already_reverted_rows_are_marked_idempotent( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=5, tool_name="edit_file")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch probe returns ``[(revert_id, original_id)]``. + session.queue(_Result(rows=[(42, 5)])) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-i", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.already_reverted == 1 + assert response.results[0].status == "already_reverted" + assert response.results[0].new_action_id == 42 + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_revert_action_skips_existing_revert_rows( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)] + session = _FakeSession() + session.queue(_Result(rows=rows)) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-rev", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.results[0].status == "skipped" + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_partial_success_when_some_rows_not_reversible( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=2, tool_name="send_email"), + _FakeAction(id=1, tool_name="edit_file"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched idempotency probe. 
+ session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.tool_name == "send_email": + return RevertOutcome( + status="not_reversible", + message="connector revert not yet implemented", + ) + return RevertOutcome(status="ok", message="ok", new_action_id=500) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-mix", + session=session, + user=_FakeUser(), + ) + assert response.status == "partial" + assert response.reverted == 1 + assert response.not_reversible == 1 + statuses = sorted(r.status for r in response.results) + assert statuses == ["not_reversible", "reverted"] + + @pytest.mark.asyncio + async def test_unexpected_exception_marks_row_failed_not_batch( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=20, tool_name="edit_file"), + _FakeAction(id=21, tool_name="edit_file"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched idempotency probe. 
+ session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.id == 20: + raise RuntimeError("disk on fire") + return RevertOutcome(status="ok", message="ok", new_action_id=999) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-fail", + session=session, + user=_FakeUser(), + ) + assert response.status == "partial" + assert response.failed == 1 + assert response.reverted == 1 + bad = next(r for r in response.results if r.action_id == 20) + assert bad.status == "failed" + assert "disk on fire" in (bad.error or "") + good = next(r for r in response.results if r.action_id == 21) + assert good.status == "reverted" + + @pytest.mark.asyncio + async def test_permission_denied_when_other_user_owns_action( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch idempotency probe (no prior reverts). 
+ session.queue(_Result(rows=[])) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-perm", + session=session, + user=_FakeUser(id="not-owner"), + ) + assert response.status == "partial" + assert response.results[0].status == "permission_denied" + # ``permission_denied`` has its own dedicated counter so the + # response invariant ``total == sum(counters)`` always holds + # without overloading ``not_reversible`` (which historically + # absorbed this case and confused frontend toasts). + assert response.permission_denied == 1 + assert response.not_reversible == 0 + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_counter_invariant_holds_across_mixed_outcomes( + self, patch_get_flags + ) -> None: + """Every row is accounted for in EXACTLY ONE counter. + + Mixes one of every supported outcome (reverted, already_reverted, + not_reversible, permission_denied, failed, skipped) and asserts + that the sum of counters equals ``response.total``. + """ + rows = [ + _FakeAction(id=10, tool_name="edit_file"), # ok + _FakeAction(id=9, tool_name="edit_file"), # already_reverted + _FakeAction(id=8, tool_name="send_email"), # not_reversible + _FakeAction(id=7, tool_name="rm", user_id="other"), # permission_denied + _FakeAction(id=6, tool_name="edit_file"), # failed + _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99), # skipped + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched probe; only id=9 has a prior revert. + # Schema: list[(revert_id, original_id)]. 
+ session.queue(_Result(rows=[(42, 9)])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.id == 10: + return RevertOutcome(status="ok", message="ok", new_action_id=500) + if action.id == 8: + return RevertOutcome( + status="not_reversible", + message="connector revert not yet implemented", + ) + if action.id == 6: + raise RuntimeError("boom") + raise AssertionError(f"unexpected revert call for {action.id}") + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, + "revert_action", + AsyncMock(side_effect=_fake_revert), + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-mixed-all", + session=session, + user=_FakeUser(), # only id=7 has a different user_id + ) + + assert response.total == len(rows) == 6 + bucket_sum = ( + response.reverted + + response.already_reverted + + response.not_reversible + + response.permission_denied + + response.failed + + response.skipped + ) + assert bucket_sum == response.total, ( + "Counter invariant broken: total " + f"({response.total}) != sum of counters ({bucket_sum}). " + f"Counters: reverted={response.reverted}, " + f"already_reverted={response.already_reverted}, " + f"not_reversible={response.not_reversible}, " + f"permission_denied={response.permission_denied}, " + f"failed={response.failed}, skipped={response.skipped}" + ) + assert response.reverted == 1 + assert response.already_reverted == 1 + assert response.not_reversible == 1 + assert response.permission_denied == 1 + assert response.failed == 1 + assert response.skipped == 1 + + @pytest.mark.asyncio + async def test_integrity_error_translates_to_already_reverted( + self, patch_get_flags + ) -> None: + """The partial unique index on ``reverse_of`` raises + ``IntegrityError`` when a concurrent revert wins the race against + the pre-flight ``_was_already_reverted`` SELECT. 
The route MUST + recover by re-querying for the winning revert id and returning + ``status="already_reverted"`` (not ``"failed"``) so racing + clients see consistent idempotent semantics. + """ + from sqlalchemy.exc import IntegrityError + + rows = [_FakeAction(id=33, tool_name="edit_file")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch pre-flight probe: nothing yet (we'll race). + session.queue(_Result(rows=[])) + # Post-IntegrityError fallback uses the SCALAR + # ``_was_already_reverted`` (single-id lookup) so it pulls + # ``[777]`` via ``.scalars().first()``. + session.queue(_Result(rows=[777])) + + async def _racing_revert(_session, *, action, requester_user_id): + raise IntegrityError("INSERT", {}, Exception("dup reverse_of")) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, + "revert_action", + AsyncMock(side_effect=_racing_revert), + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-race", + session=session, + user=_FakeUser(), + ) + + assert response.failed == 0, ( + "IntegrityError must NOT surface as a failed row; the unique " + "index is the durable expression of idempotency." + ) + assert response.already_reverted == 1 + assert response.results[0].status == "already_reverted" + assert response.results[0].new_action_id == 777 diff --git a/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py b/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py new file mode 100644 index 000000000..95314741a --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py @@ -0,0 +1,370 @@ +"""Unit tests for the filesystem-tool branches of ``revert_service``. 
+ +Covers: + +* Exact-name dispatch — ``rmdir`` does NOT mis-route to the document + branch (``"rmdir".startswith("rm")`` would mis-route under the legacy + prefix-based dispatch). +* ``rm`` revert re-INSERTs a fresh document from the snapshot, including + re-creating chunks. Falls back to ``(folder_id_before, title_before)`` + when ``metadata_before["virtual_path"]`` is missing. +* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the + document. +* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot. +* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable`` + when the folder gained children. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest + +from app.services import revert_service + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + revert_service, + "embed_texts", + lambda texts: [np.zeros(8, dtype=np.float32) for _ in texts], + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _FakeResult: + def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None: + self._rows = rows or [] + self._scalar = scalar + + def all(self) -> list[Any]: + return list(self._rows) + + def scalar_one_or_none(self) -> Any: + return self._scalar + + def scalars(self) -> Any: + return _FakeScalarsProxy(self._rows) + + +class _FakeScalarsProxy: + def __init__(self, rows: list[Any]) -> None: + self._rows = rows + + def first(self) -> Any: + return self._rows[0] if self._rows else None + + +class _FakeSession: + def __init__(self) -> None: + self.execute = AsyncMock() + self.added: list[Any] = [] + self.deleted: list[Any] = [] + self.flush = AsyncMock() + # session.get(Model, pk) lookup + 
self.get = AsyncMock(return_value=None) + + async def _flush_assigning_ids() -> None: + for obj in self.added: + if getattr(obj, "id", None) is None: + obj.id = 999 + + self.flush.side_effect = _flush_assigning_ids + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: list[Any]) -> None: + self.added.extend(objs) + + +def _action(*, tool_name: str, action_id: int = 7): + return MagicMock( + id=action_id, + tool_name=tool_name, + thread_id=1, + search_space_id=2, + user_id="user-1", + reverse_descriptor=None, + ) + + +def _doc_revision( + *, + document_id: int | None = None, + content_before: str | None = "old content", + title_before: str | None = "notes.md", + folder_id_before: int | None = 5, + chunks_before: list[dict[str, str]] | None = None, + metadata_before: dict[str, str] | None = None, +): + revision = MagicMock() + revision.id = 100 + revision.document_id = document_id + revision.search_space_id = 2 + revision.content_before = content_before + revision.title_before = title_before + revision.folder_id_before = folder_id_before + revision.chunks_before = chunks_before or [] + revision.metadata_before = metadata_before + return revision + + +def _folder_revision( + *, + folder_id: int | None = None, + name_before: str | None = "team", + parent_id_before: int | None = None, + position_before: str | None = "a0", +): + revision = MagicMock() + revision.id = 200 + revision.folder_id = folder_id + revision.search_space_id = 2 + revision.name_before = name_before + revision.parent_id_before = parent_id_before + revision.position_before = position_before + return revision + + +# --------------------------------------------------------------------------- +# Exact-name dispatch regression guards +# --------------------------------------------------------------------------- + + +class TestExactDispatch: + """Regression: ``rmdir`` MUST NOT route to the document branch.""" + + @pytest.mark.asyncio + async def 
test_rmdir_does_not_misroute_to_document(self) -> None: + # If dispatch used `startswith("rm")` we'd hit the document branch + # here. With exact-name lookup `rmdir` lands in `_FOLDER_TOOLS`. + session = _FakeSession() + action = _action(tool_name="rmdir") + # No folder revisions exist for this action. + session.execute.return_value = _FakeResult(rows=[]) + outcome = await revert_service.revert_action( + session, # type: ignore[arg-type] + action=action, + requester_user_id="user-1", + ) + assert outcome.status == "not_reversible" + assert "folder_revisions" in outcome.message + + def test_dispatch_sets_split_doc_and_folder(self) -> None: + # Static guards on the dispatch tables themselves so a future + # refactor doesn't accidentally reintroduce the prefix bug. + assert "rm" in revert_service._DOC_TOOLS + assert "rmdir" in revert_service._FOLDER_TOOLS + assert "rmdir" not in revert_service._DOC_TOOLS + assert "rm" not in revert_service._FOLDER_TOOLS + # ``move_file`` lives only in document tools (it's a doc rename). + assert "move_file" in revert_service._DOC_TOOLS + assert "move_file" not in revert_service._FOLDER_TOOLS + + +# --------------------------------------------------------------------------- +# rm revert (re-INSERT) +# --------------------------------------------------------------------------- + + +class TestRmRevert: + @pytest.mark.asyncio + async def test_re_inserts_document_with_chunks(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=None, # row was hard-deleted + content_before="hello world", + title_before="x.md", + folder_id_before=None, + chunks_before=[{"content": "alpha"}, {"content": "beta"}], + metadata_before={"virtual_path": "/documents/x.md"}, + ) + # No collision check hit and the resulting query returns nothing. 
+ session.execute.return_value = _FakeResult(scalar=None) + + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + + assert outcome.status == "ok" + # New Document + 2 chunks must have been added. + from app.db import Chunk, Document + + added_docs = [obj for obj in session.added if isinstance(obj, Document)] + added_chunks = [obj for obj in session.added if isinstance(obj, Chunk)] + assert len(added_docs) == 1 + assert added_docs[0].title == "x.md" + assert len(added_chunks) == 2 + # Snapshot was repointed at the new doc id so a follow-up revert works. + assert revision.document_id == added_docs[0].id + + @pytest.mark.asyncio + async def test_falls_back_to_folder_id_and_title_for_virtual_path( + self, + ) -> None: + session = _FakeSession() + # Snapshot with NO metadata_before — the fallback path must kick in. + revision = _doc_revision( + document_id=None, + content_before="hello", + title_before="cap.md", + folder_id_before=42, + chunks_before=[], + metadata_before=None, + ) + # session.get(Folder, 42) returns a folder with a name. + folder = MagicMock() + folder.name = "team" + folder.parent_id = None + # First .get is for the folder lookup in the path-derivation. 
+ session.get = AsyncMock(return_value=folder) + session.execute.return_value = _FakeResult(scalar=None) + + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + + @pytest.mark.asyncio + async def test_falls_back_to_root_path_when_no_folder( + self, + ) -> None: + """metadata_before is None and folder_id_before is None still + resolves: title fallback yields ``/documents/<title_before>`` so revert + proceeds at the root of the documents tree.""" + session = _FakeSession() + revision = _doc_revision( + document_id=None, + content_before="hello", + title_before="x.md", + folder_id_before=None, + metadata_before=None, + ) + # No collision in the documents tree at /documents/x.md. + session.execute.return_value = _FakeResult(scalar=None) + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + + @pytest.mark.asyncio + async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=None, + content_before="hi", + title_before="x.md", + folder_id_before=None, + metadata_before={"virtual_path": "/documents/x.md"}, + ) + # SELECT for unique_identifier_hash collision hits an existing row. 
+ session.execute.return_value = _FakeResult(scalar=42) + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "tool_unavailable" + assert "collide" in outcome.message + + +# --------------------------------------------------------------------------- +# write_file create revert (DELETE) +# --------------------------------------------------------------------------- + + +class TestWriteFileCreateRevert: + @pytest.mark.asyncio + async def test_deletes_created_doc(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=99, + content_before=None, # marker for "created in this action" + title_before=None, + ) + outcome = await revert_service._delete_created_document( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + # Exactly one DELETE was issued. + assert session.execute.await_count == 1 + + +# --------------------------------------------------------------------------- +# rmdir revert (re-INSERT folder) +# --------------------------------------------------------------------------- + + +class TestRmdirRevert: + @pytest.mark.asyncio + async def test_re_inserts_folder_from_snapshot(self) -> None: + session = _FakeSession() + revision = _folder_revision( + folder_id=None, + name_before="team", + parent_id_before=None, + position_before="a0", + ) + outcome = await revert_service._reinsert_folder_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + from app.db import Folder + + assert outcome.status == "ok" + added_folders = [obj for obj in session.added if isinstance(obj, Folder)] + assert len(added_folders) == 1 + assert added_folders[0].name == "team" + assert revision.folder_id == added_folders[0].id + + +# --------------------------------------------------------------------------- +# mkdir revert (DELETE folder) +# 
--------------------------------------------------------------------------- + + +class TestMkdirRevert: + @pytest.mark.asyncio + async def test_deletes_empty_folder(self) -> None: + session = _FakeSession() + revision = _folder_revision(folder_id=42) + # Both the doc-existence check and the child-folder check return None. + session.execute.side_effect = [ + _FakeResult(scalar=None), # docs + _FakeResult(scalar=None), # children + _FakeResult(scalar=None), # delete (no return value) + ] + outcome = await revert_service._delete_created_folder( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + # 3 executes: docs check, children check, delete. + assert session.execute.await_count == 3 + + @pytest.mark.asyncio + async def test_reports_tool_unavailable_when_folder_has_children(self) -> None: + session = _FakeSession() + revision = _folder_revision(folder_id=42) + # First check (docs) returns "row found". + session.execute.return_value = _FakeResult(scalar=1) + outcome = await revert_service._delete_created_folder( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "tool_unavailable" + assert "no longer empty" in outcome.message diff --git a/surfsense_backend/tests/unit/tasks/__init__.py b/surfsense_backend/tests/unit/tasks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/__init__.py b/surfsense_backend/tests/unit/tasks/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py new file mode 100644 index 000000000..7f32bf456 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py @@ -0,0 +1,185 @@ +"""Unit tests for ``stream_new_chat._extract_chunk_parts``. 
+ +Earlier versions only handled ``isinstance(chunk.content, str)`` and +silently dropped every other shape (Anthropic typed-block lists, +Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from +a few providers). These regression tests pin those four shapes plus the +defensive cases (``None`` chunk, mixed types, missing fields). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.tasks.chat.stream_new_chat import _extract_chunk_parts + + +@dataclass +class _FakeChunk: + """Minimal stand-in for ``AIMessageChunk`` used in unit tests.""" + + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +class TestStringContent: + def test_plain_string_content_extracts_as_text(self) -> None: + chunk = _FakeChunk(content="hello world") + out = _extract_chunk_parts(chunk) + assert out["text"] == "hello world" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] + + def test_empty_string_content_yields_empty_text(self) -> None: + chunk = _FakeChunk(content="") + out = _extract_chunk_parts(chunk) + assert out["text"] == "" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] + + +class TestListContent: + def test_list_of_text_blocks_concatenates(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text", "text": "Hello "}, + {"type": "text", "text": "world"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "Hello world" + assert out["reasoning"] == "" + + def test_mixed_text_and_reasoning_blocks(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "reasoning", "reasoning": "Let me think... "}, + {"type": "reasoning", "text": "still thinking."}, + {"type": "text", "text": "The answer is 42."}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "The answer is 42." 
+ assert out["reasoning"] == "Let me think... still thinking." + + def test_tool_call_chunks_in_content_list_extracted(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text", "text": "Calling tool..."}, + { + "type": "tool_call_chunk", + "id": "call_123", + "name": "make_widget", + "args": '{"color":"red"}', + }, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "Calling tool..." + assert out["reasoning"] == "" + assert len(out["tool_call_chunks"]) == 1 + assert out["tool_call_chunks"][0]["id"] == "call_123" + assert out["tool_call_chunks"][0]["name"] == "make_widget" + + def test_tool_use_blocks_also_extracted(self) -> None: + """Some providers (Anthropic) emit ``type='tool_use'`` instead.""" + chunk = _FakeChunk( + content=[ + { + "type": "tool_use", + "id": "call_xyz", + "name": "search", + }, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["tool_call_chunks"] == [ + {"type": "tool_use", "id": "call_xyz", "name": "search"} + ] + + def test_unknown_block_types_are_ignored(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "image_url", "url": "https://example.com/x.png"}, + {"type": "text", "text": "ok"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "ok" + + def test_blocks_without_text_field_are_ignored(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text"}, # no text/content key + {"type": "text", "text": "kept"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "kept" + + +class TestAdditionalKwargsReasoning: + def test_reasoning_content_in_additional_kwargs(self) -> None: + """Some providers stash reasoning in ``additional_kwargs.reasoning_content``.""" + chunk = _FakeChunk( + content="visible answer", + additional_kwargs={"reasoning_content": "internal monologue"}, + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "visible answer" + assert out["reasoning"] == "internal monologue" + + def test_reasoning_appended_to_typed_block_reasoning(self) -> 
None: + chunk = _FakeChunk( + content=[{"type": "reasoning", "text": "from blocks. "}], + additional_kwargs={"reasoning_content": "from kwargs."}, + ) + out = _extract_chunk_parts(chunk) + assert out["reasoning"] == "from blocks. from kwargs." + + +class TestToolCallChunksAttribute: + def test_tool_call_chunks_attribute_extracted_alongside_string_content( + self, + ) -> None: + chunk = _FakeChunk( + content="streaming text", + tool_call_chunks=[ + {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"} + ], + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "streaming text" + assert len(out["tool_call_chunks"]) == 1 + assert out["tool_call_chunks"][0]["id"] == "tc-9" + + def test_attribute_and_typed_block_chunks_both_collected(self) -> None: + chunk = _FakeChunk( + content=[ + { + "type": "tool_call_chunk", + "id": "from-block", + "name": "x", + } + ], + tool_call_chunks=[{"id": "from-attr", "name": "y"}], + ) + out = _extract_chunk_parts(chunk) + ids = [tcc.get("id") for tcc in out["tool_call_chunks"]] + assert ids == ["from-block", "from-attr"] + + +class TestDefensive: + @pytest.mark.parametrize( + "chunk_value", + [None, _FakeChunk(content=None), _FakeChunk(content=42)], + ) + def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None: + out = _extract_chunk_parts(chunk_value) + assert out["text"] == "" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 7773a438a..c2086e80a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -14,6 +14,13 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; import { disabledToolsAtom } from 
"@/atoms/agent-tools/agent-tools.atoms"; +import { + agentActionsByChatTurnIdAtom, + markAgentActionRevertedAtom, + resetAgentActionMapAtom, + updateAgentActionReversibleAtom, + upsertAgentActionAtom, +} from "@/atoms/chat/agent-actions.atom"; import { clearTargetCommentIdAtom, currentThreadAtom, @@ -36,6 +43,11 @@ import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { membersAtom } from "@/atoms/members/members-query.atoms"; import { removeChatTabAtom, updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; +import { + EditMessageDialog, + type EditMessageDialogChoice, +} from "@/components/assistant-ui/edit-message-dialog"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { Thread } from "@/components/assistant-ui/thread"; import { @@ -55,14 +67,19 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { + addStepSeparator, addToolCall, + appendReasoning, appendText, buildContentForPersistence, buildContentForUI, type ContentPartsState, + endReasoning, FrameBatchedUpdater, + findToolCallIdByLcId, readSSEStream, type ThinkingStepData, + type ToolUIGate, updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; @@ -161,44 +178,38 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { } /** - * Tools that should render custom UI in the chat. + * Every tool call renders a card. The legacy + * ``BASE_TOOLS_WITH_UI`` allowlist used to drop unknown tool calls on the + * floor; we now route everything through ``ToolFallback``. Persisted + * payload size stays bounded because the backend's + * ``format_thinking_step`` summarisation and the + * ``result_length``-only default for unknown tools (see + * ``stream_new_chat.py``) keep the JSON from ballooning. 
*/ -const BASE_TOOLS_WITH_UI = new Set([ - "web_search", - "generate_podcast", - "generate_report", - "generate_resume", - "generate_video_presentation", - "display_image", - "generate_image", - "delete_notion_page", - "create_notion_page", - "update_notion_page", - "create_linear_issue", - "update_linear_issue", - "delete_linear_issue", - "create_google_drive_file", - "delete_google_drive_file", - "create_onedrive_file", - "delete_onedrive_file", - "create_dropbox_file", - "delete_dropbox_file", - "create_calendar_event", - "update_calendar_event", - "delete_calendar_event", - "create_gmail_draft", - "update_gmail_draft", - "send_gmail_email", - "trash_gmail_email", - "create_jira_issue", - "update_jira_issue", - "delete_jira_issue", - "create_confluence_page", - "update_confluence_page", - "delete_confluence_page", - "execute", - // "write_todos", // Disabled for now -]); +const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; + +/** + * When a streamed message is persisted, the backend returns the durable + * ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it + * into the assistant-ui message metadata so the per-turn "Revert turn" + * button can scope to this turn's actions even after a full chat reload. + */ +function mergeChatTurnIdIntoMessage( + msg: ThreadMessageLike, + turnId: string | null | undefined +): ThreadMessageLike { + if (!turnId) return msg; + const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> }; + const existingCustom = existingMeta.custom ?? 
{}; + if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; + return { + ...msg, + metadata: { + ...existingMeta, + custom: { ...existingCustom, chatTurnId: turnId }, + }, + }; +} export default function NewChatPage() { const params = useParams(); @@ -215,7 +226,7 @@ export default function NewChatPage() { assistantMsgId: string; interruptData: Record<string, unknown>; } | null>(null); - const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []); + const toolsWithUI = TOOLS_WITH_UI_ALL; // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); @@ -235,6 +246,25 @@ export default function NewChatPage() { const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); + // Agent action log SSE side-channel. + const upsertAgentAction = useSetAtom(upsertAgentActionAtom); + const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom); + const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom); + const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom); + // Chat-turn-keyed action map for the edit-from-position pre-flight + // that decides whether to show the confirmation dialog. + const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom); + // Edit dialog state. Holds the message id being edited and + // the (already extracted) regenerate args so we can resume the edit + // after the user picks "revert all" / "continue" / "cancel". 
+ const [editDialogState, setEditDialogState] = useState<{ + fromMessageId: number; + userQuery: string | null; + userMessageContent: ThreadMessageLike["content"]; + userImages: NewChatUserImagePayload[]; + downstreamReversibleCount: number; + downstreamTotalCount: number; + } | null>(null); // Get current user for author info in shared chats const { data: currentUser } = useAtomValue(currentUserAtom); @@ -327,6 +357,7 @@ export default function NewChatPage() { clearPlanOwnerRegistry(); closeReportPanel(); closeEditorPanel(); + resetAgentActionMap(); try { if (urlChatId > 0) { @@ -395,6 +426,7 @@ export default function NewChatPage() { removeChatTab, searchSpaceId, tokenUsageStore, + resetAgentActionMap, ]); // Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same) @@ -655,11 +687,14 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; let tokenUsageData: Record<string, unknown> | null = null; + // Captured from ``data-turn-info`` at stream start. 
+ let streamedChatTurnId: string | null = null; // Add placeholder assistant message setMessages((prev) => [ @@ -752,21 +787,52 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": { if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); @@ -774,7 +840,10 @@ export default function NewChatPage() { } case "tool-output-available": { - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); markInterruptsCompleted(contentParts); if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { const idx = toolCallIndices.get(parsed.toolCallId); @@ -880,6 +949,50 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? 
findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId: currentThreadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -900,13 +1013,18 @@ export default function NewChatPage() { role: "assistant", content: finalContent, token_usage: tokenUsageData ?? undefined, + turn_id: streamedChatTurnId, }); // Update message ID from temporary to database ID so comments work immediately const newMsgId = `msg-${savedMessage.id}`; tokenUsageStore.rename(assistantMsgId, newMsgId); setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? 
mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + : m + ) ); // Update pending interrupt with the new persisted message ID @@ -929,7 +1047,9 @@ export default function NewChatPage() { const hasContent = contentParts.some( (part) => (part.type === "text" && part.text.length > 0) || - (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && + (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) ); if (hasContent && currentThreadId) { const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); @@ -937,12 +1057,17 @@ export default function NewChatPage() { const savedMessage = await appendMessage(currentThreadId, { role: "assistant", content: partialContent, + turn_id: streamedChatTurnId, }); // Update message ID from temporary to database ID const newMsgId = `msg-${savedMessage.id}`; setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + : m + ) ); } catch (err) { console.error("Failed to persist partial assistant message:", err); @@ -1030,10 +1155,13 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; let tokenUsageData: Record<string, unknown> | null = null; + // Captured from ``data-turn-info`` at stream start. 
+ let streamedChatTurnId: string | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1136,8 +1264,34 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; @@ -1145,6 +1299,7 @@ export default function NewChatPage() { if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, }); } else { addToolCall( @@ -1152,7 +1307,9 @@ export default function NewChatPage() { toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); @@ -1161,6 +1318,7 @@ export default function NewChatPage() { case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, }); markInterruptsCompleted(contentParts); batcher.flush(); @@ -1222,6 +1380,50 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? 
findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId: resumeThreadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -1241,11 +1443,16 @@ export default function NewChatPage() { role: "assistant", content: finalContent, token_usage: tokenUsageData ?? undefined, + turn_id: streamedChatTurnId, }); const newMsgId = `msg-${savedMessage.id}`; tokenUsageStore.rename(assistantMsgId, newMsgId); setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + : m + ) ); } catch (err) { console.error("Failed to persist resumed assistant message:", err); @@ -1340,6 +1547,12 @@ export default function NewChatPage() { editExtras?: { userMessageContent: ThreadMessageLike["content"]; userImages: NewChatUserImagePayload[]; + }, + editFromPosition?: { + /** Message id (numeric, parsed from ``msg-<n>``) to rewind to. */ + fromMessageId?: number | null; + /** When true, revert reversible downstream actions before stream. 
*/ + revertActions?: boolean; } ) => { if (!threadId) { @@ -1384,9 +1597,20 @@ export default function NewChatPage() { userQueryToDisplay = newUserQuery; } - // Remove the last two messages (user + assistant) from the UI immediately - // The backend will also delete them from the database + // Remove downstream messages from the UI immediately. The + // backend will also delete them from the database. + // + // When an explicit ``fromMessageId`` is passed, slice from + // that message forward; otherwise fall back to the legacy + // "drop the last 2" behaviour. setMessages((prev) => { + if (editFromPosition?.fromMessageId != null) { + const targetId = `msg-${editFromPosition.fromMessageId}`; + const sliceIndex = prev.findIndex((m) => m.id === targetId); + if (sliceIndex >= 0) { + return prev.slice(0, sliceIndex); + } + } if (prev.length >= 2) { return prev.slice(0, -2); } @@ -1406,11 +1630,16 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; const batcher = new FrameBatchedUpdater(); let tokenUsageData: Record<string, unknown> | null = null; + // Captured from ``data-turn-info`` at stream start; stamped + // onto persisted messages so future edits can locate the + // right LangGraph checkpoint. + let streamedChatTurnId: string | null = null; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -1449,6 +1678,16 @@ export default function NewChatPage() { if (isEdit) { requestBody.user_images = editExtras?.userImages ?? []; } + // Explicit edit-from-arbitrary-position. Only send + // ``from_message_id`` / ``revert_actions`` when the + // caller asked for them; otherwise the backend keeps the + // legacy "last 2 messages" behaviour for back-compat. 
+ if (editFromPosition?.fromMessageId != null) { + requestBody.from_message_id = editFromPosition.fromMessageId; + if (editFromPosition.revertActions) { + requestBody.revert_actions = true; + } + } const response = await fetch(getRegenerateUrl(threadId), { method: "POST", headers: { @@ -1481,28 +1720,62 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); break; case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); markInterruptsCompleted(contentParts); if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { const idx = toolCallIndices.get(parsed.toolCallId); @@ -1528,6 +1801,82 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = 
parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + + case "data-revert-results": { + const summary = parsed.data; + // failureCount must include every "not undone" bucket + // (not_reversible, permission_denied, failed) so the + // toast's "X could not be rolled back" math matches + // the response invariant ``total === sum(counters)``. + // ``skipped`` rows are batch revert artefacts (revert + // rows themselves) and are not user-facing failures. + const failureCount = + summary.failed + summary.not_reversible + (summary.permission_denied ?? 0); + if (failureCount > 0) { + toast.warning( + `Pre-revert: ${summary.reverted}/${summary.total} undone, ${failureCount} could not be rolled back.` + ); + } else if (summary.reverted > 0) { + toast.success( + summary.reverted === 1 + ? "Reverted 1 downstream action before regenerating." 
+ : `Reverted ${summary.reverted} downstream actions before regenerating.` + ); + } + for (const r of summary.results) { + if (r.status === "reverted" || r.status === "already_reverted") { + markAgentActionReverted({ + id: r.action_id, + newActionId: r.new_action_id ?? null, + }); + } + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -1552,12 +1901,17 @@ export default function NewChatPage() { const savedUserMessage = await appendMessage(threadId, { role: "user", content: userContentToPersist, + turn_id: streamedChatTurnId, }); // Update user message ID to database ID const newUserMsgId = `msg-${savedUserMessage.id}`; setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) + prev.map((m) => + m.id === userMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id) + : m + ) ); // Persist assistant message @@ -1565,12 +1919,17 @@ export default function NewChatPage() { role: "assistant", content: finalContent, token_usage: tokenUsageData ?? undefined, + turn_id: streamedChatTurnId, }); const newMsgId = `msg-${savedMessage.id}`; tokenUsageStore.rename(assistantMsgId, newMsgId); setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) + : m + ) ); trackChatResponseReceived(searchSpaceId, threadId); @@ -1608,7 +1967,14 @@ export default function NewChatPage() { [threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI] ); - // Handle editing a message - truncates history and regenerates with new query + // Handle editing a message - truncates history and regenerates with new query. 
+ // + // When ``message.sourceId`` is set (the assistant-ui way to say + // "this edit replaces an older message"), we pin + // ``from_message_id`` so the backend rewinds to the right LangGraph + // checkpoint instead of relying on the legacy "last 2 messages" + // rewind. We also count downstream reversible actions and prompt the + // user to revert / continue / cancel before regenerating. const onEdit = useCallback( async (message: AppendMessage) => { const { userQuery, userImages } = extractUserTurnForNewChatApi(message, []); @@ -1619,9 +1985,95 @@ export default function NewChatPage() { } const userMessageContent = message.content as unknown as ThreadMessageLike["content"]; - await handleRegenerate(queryForApi, { userMessageContent, userImages }); + + // ``sourceId`` per @assistant-ui/core's ``AppendMessage`` is + // "the ID of the message that was edited". Parse the numeric + // suffix so we can map it back to a DB row. + const sourceId = (message as { sourceId?: string }).sourceId; + const fromMessageId = + sourceId && /^msg-\d+$/.test(sourceId) + ? Number.parseInt(sourceId.replace(/^msg-/, ""), 10) + : null; + + if (fromMessageId == null) { + // No source id (or non-DB id) — fall back to today's + // last-2 behaviour. The user gets the legacy edit flow. + await handleRegenerate(queryForApi, { userMessageContent, userImages }); + return; + } + + // Pre-flight: count reversible downstream actions so we can + // auto-skip the dialog for harmless edits. + // + // "Downstream" means messages AFTER the edited one. The + // previous slice ``messages.slice(editedIndex)`` included + // the edited message itself in both the total + // count and the reversibility scan (any actions on the + // edited turn would be double-counted). Slice from + // ``editedIndex + 1`` so the dialog text matches reality: + // "N downstream messages will be dropped". 
+ const editedIndex = messages.findIndex((m) => m.id === `msg-${fromMessageId}`); + let downstreamReversibleCount = 0; + let downstreamTotalCount = 0; + if (editedIndex >= 0) { + const downstream = messages.slice(editedIndex + 1); + downstreamTotalCount = downstream.length; + const seenTurns = new Set<string>(); + for (const m of downstream) { + const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } }; + const tid = meta.custom?.chatTurnId; + if (!tid || seenTurns.has(tid)) continue; + seenTurns.add(tid); + const turnActions = agentActionsByChatTurnId.get(tid) ?? []; + for (const a of turnActions) { + if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) { + downstreamReversibleCount += 1; + } + } + } + } + + if (downstreamReversibleCount === 0) { + // Nothing to revert — submit silently. + await handleRegenerate( + queryForApi, + { userMessageContent, userImages }, + { fromMessageId, revertActions: false } + ); + return; + } + + setEditDialogState({ + fromMessageId, + userQuery: queryForApi, + userMessageContent, + userImages, + downstreamReversibleCount, + downstreamTotalCount, + }); }, - [handleRegenerate] + [handleRegenerate, messages, agentActionsByChatTurnId] + ); + + const handleEditDialogChoice = useCallback( + async (choice: EditMessageDialogChoice) => { + const pending = editDialogState; + if (!pending) return; + setEditDialogState(null); + if (choice === "cancel") return; + await handleRegenerate( + pending.userQuery, + { + userMessageContent: pending.userMessageContent, + userImages: pending.userImages, + }, + { + fromMessageId: pending.fromMessageId, + revertActions: choice === "revert", + } + ); + }, + [editDialogState, handleRegenerate] ); // Handle reloading/refreshing the last AI response @@ -1671,6 +2123,7 @@ export default function NewChatPage() { <TokenUsageProvider store={tokenUsageStore}> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> <div 
key={searchSpaceId} className="flex h-full overflow-hidden"> <div className="flex-1 flex flex-col min-w-0 overflow-hidden"> <Thread /> @@ -1679,6 +2132,15 @@ export default function NewChatPage() { <MobileEditorPanel /> <MobileHitlEditPanel /> </div> + <EditMessageDialog + open={editDialogState !== null} + onOpenChange={(open) => { + if (!open) setEditDialogState(null); + }} + downstreamReversibleCount={editDialogState?.downstreamReversibleCount ?? 0} + downstreamTotalCount={editDialogState?.downstreamTotalCount ?? 0} + onChoose={handleEditDialogChoice} + /> </AssistantRuntimeProvider> </TokenUsageProvider> ); diff --git a/surfsense_web/atoms/chat/agent-actions.atom.ts b/surfsense_web/atoms/chat/agent-actions.atom.ts new file mode 100644 index 000000000..7830c8751 --- /dev/null +++ b/surfsense_web/atoms/chat/agent-actions.atom.ts @@ -0,0 +1,194 @@ +"use client"; + +import { atom } from "jotai"; + +/** + * Minimal per-row projection of ``AgentActionLog`` that the tool card + * needs to decide whether to render a Revert button. + * + * Fields are deliberately a subset of the full ``AgentAction`` so the + * SSE side-channel (``data-action-log`` / ``data-action-log-updated``) + * can populate them without depending on the REST endpoint + * ``GET /threads/.../actions`` (which 503s when + * ``SURFSENSE_ENABLE_ACTION_LOG`` is off). + */ +export interface AgentActionLite { + id: number; + threadId: number | null; + lcToolCallId: string | null; + chatTurnId: string | null; + toolName: string; + reversible: boolean; + reverseDescriptorPresent: boolean; + error: boolean; + revertedByActionId: number | null; + isRevertAction: boolean; + createdAt: string | null; +} + +/** + * Map keyed off the LangChain ``tool_call.id`` (mirrors ``ContentPart + * tool-call.langchainToolCallId``). 
+ */ +export const agentActionByLcIdAtom = atom<Map<string, AgentActionLite>>(new Map()); + +/** + * Parallel map keyed off the synthetic chat-card ``toolCallId`` + * (``call_<run-id>``) so ``ToolFallback`` (which only receives the + * synthetic id from assistant-ui) can join its card to the action log. + * + * Both maps are kept in sync by ``upsertAgentActionAtom``. + */ +export const agentActionByToolCallIdAtom = atom<Map<string, AgentActionLite>>(new Map()); + +/** + * Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer + * "how many reversible actions does this assistant turn contain?" with one + * O(1) lookup plus a scan of that turn's list. Arrays are in upsert order + * (a re-upserted row moves to the end); with one upsert per action that is + * insertion order, which per turn matches ``created_at`` (synchronous writes). + */ +export const agentActionsByChatTurnIdAtom = atom<Map<string, AgentActionLite[]>>(new Map()); + +/** + * Action to upsert one ``AgentActionLite`` row. + * + * ``toolCallId`` is the synthetic card id (``call_<run-id>`` from + * ``stream_new_chat.py``). When provided alongside ``lcToolCallId``, the + * action is indexed under BOTH ids so the tool card can perform the + * lookup without going via the streaming state. + */ +export const upsertAgentActionAtom = atom( + null, + (_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => { + const { action, toolCallId } = payload; + const upsertInto = ( + prev: Map<string, AgentActionLite>, + key: string + ): Map<string, AgentActionLite> => { + const next = new Map(prev); + const existing = next.get(key); + next.set(key, { + ...action, + // Preserve the local "reverted" bookkeeping if a reversibility + // flip arrives AFTER the user already reverted via the REST + // route. We never want a stale ``reversible=true`` event to + // resurrect a Reverted card. + revertedByActionId: existing?.revertedByActionId ?? action.revertedByActionId, + isRevertAction: existing?.isRevertAction ??
action.isRevertAction, + }); + return next; + }; + if (action.lcToolCallId) { + set(agentActionByLcIdAtom, (prev) => upsertInto(prev, action.lcToolCallId as string)); + } + if (toolCallId) { + set(agentActionByToolCallIdAtom, (prev) => upsertInto(prev, toolCallId)); + } + if (action.chatTurnId) { + set(agentActionsByChatTurnIdAtom, (prev) => { + const next = new Map(prev); + const turnId = action.chatTurnId as string; + const existing = next.get(turnId) ?? []; + const priorEntry = existing.find((row) => row.id === action.id); + const merged: AgentActionLite = { + ...action, + revertedByActionId: priorEntry?.revertedByActionId ?? action.revertedByActionId, + isRevertAction: priorEntry?.isRevertAction ?? action.isRevertAction, + }; + const others = existing.filter((row) => row.id !== action.id); + next.set(turnId, [...others, merged]); + return next; + }); + } + } +); + +function mutateById( + prev: Map<string, AgentActionLite>, + id: number, + mutator: (entry: AgentActionLite) => AgentActionLite +): Map<string, AgentActionLite> { + let mutated = false; + const next = new Map(prev); + for (const [key, value] of next) { + if (value.id === id) { + next.set(key, mutator(value)); + mutated = true; + } + } + return mutated ? next : prev; +} + +function mutateByIdInTurnIndex( + prev: Map<string, AgentActionLite[]>, + id: number, + mutator: (entry: AgentActionLite) => AgentActionLite +): Map<string, AgentActionLite[]> { + let mutated = false; + const next = new Map(prev); + for (const [key, list] of next) { + let listMutated = false; + const updated = list.map((row) => { + if (row.id === id) { + listMutated = true; + return mutator(row); + } + return row; + }); + if (listMutated) { + next.set(key, updated); + mutated = true; + } + } + return mutated ? next : prev; +} + +/** + * Action to flip an existing entry's ``reversible`` flag, keyed by the + * AgentActionLog row id (the SSE ``data-action-log-updated`` payload + * does NOT carry ``lcToolCallId``). 
+ */ +export const updateAgentActionReversibleAtom = atom( + null, + (_get, set, payload: { id: number; reversible: boolean }) => { + const apply = (entry: AgentActionLite): AgentActionLite => ({ + ...entry, + reversible: payload.reversible, + }); + set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); + } +); + +/** Action to mark an existing entry as reverted (post-revert call). */ +export const markAgentActionRevertedAtom = atom( + null, + (_get, set, payload: { id: number; newActionId: number | null }) => { + const apply = (entry: AgentActionLite): AgentActionLite => ({ + ...entry, + revertedByActionId: payload.newActionId ?? -1, + }); + set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); + } +); + +/** Mark every action in a turn as reverted, given a list of (id, newActionId) pairs. */ +export const markAgentActionsRevertedBatchAtom = atom( + null, + (_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => { + for (const entry of payload.entries) { + set(markAgentActionRevertedAtom, entry); + } + } +); + +/** Reset all maps (e.g. when the active thread changes). 
*/ +export const resetAgentActionMapAtom = atom(null, (_get, set) => { + set(agentActionByLcIdAtom, new Map()); + set(agentActionByToolCallIdAtom, new Map()); + set(agentActionsByChatTurnIdAtom, new Map()); +}); diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 6b9c2c87e..bfe0434b4 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -33,6 +33,8 @@ import { useAllCitationMetadata, } from "@/components/assistant-ui/citation-metadata-context"; import { MarkdownText } from "@/components/assistant-ui/markdown-text"; +import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; +import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button"; import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; @@ -491,6 +493,7 @@ const AssistantMessageInner: FC = () => { <MessagePrimitive.Parts components={{ Text: MarkdownText, + Reasoning: ReasoningMessagePart, tools: { by_name: { generate_report: GenerateReportToolUI, @@ -699,6 +702,13 @@ const AssistantActionBar: FC = () => { const isLast = useAuiState((s) => s.message.isLast); const aui = useAui(); const api = useElectronAPI(); + // Surface the persisted ``chat_turn_id`` so the per-turn revert + // affordance can scope to just this message's actions. Streamed + // turns get their id once the assistant message is hydrated/finalised. + const chatTurnId = useAuiState(({ message }) => { + const meta = message?.metadata as { custom?: { chatTurnId?: string | null } } | undefined; + return meta?.custom?.chatTurnId ?? 
null; + }); const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW; @@ -743,6 +753,9 @@ const AssistantActionBar: FC = () => { </TooltipIconButton> )} <MessageInfoDropdown /> + <div className="ml-auto"> + <RevertTurnButton chatTurnId={chatTurnId} /> + </div> </ActionBarPrimitive.Root> ); }; diff --git a/surfsense_web/components/assistant-ui/edit-message-dialog.tsx b/surfsense_web/components/assistant-ui/edit-message-dialog.tsx new file mode 100644 index 000000000..807f16fe7 --- /dev/null +++ b/surfsense_web/components/assistant-ui/edit-message-dialog.tsx @@ -0,0 +1,106 @@ +"use client"; + +/** + * Confirmation dialog shown when the user edits a message that has + * reversible downstream actions. Three buttons: + * + * • "Revert all & resubmit" — POST regenerate with revert_actions=true + * • "Continue without revert" — POST regenerate with revert_actions=false + * • "Cancel" — abort the edit entirely + * + * The dialog is auto-skipped when zero reversible downstream actions + * exist (the caller checks first via ``downstreamReversibleCount``). 
+ */ + +import { useEffect, useRef, useState } from "react"; +import { + AlertDialog, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; + +export type EditMessageDialogChoice = "revert" | "continue" | "cancel"; + +export interface EditMessageDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + downstreamReversibleCount: number; + downstreamTotalCount: number; + onChoose: (choice: EditMessageDialogChoice) => void | Promise<void>; +} + +export function EditMessageDialog({ + open, + onOpenChange, + downstreamReversibleCount, + downstreamTotalCount, + onChoose, +}: EditMessageDialogProps) { + const [busy, setBusy] = useState<EditMessageDialogChoice | null>(null); + + // The parent's ``handleEditDialogChoice`` calls + // ``setEditDialogState(null)`` BEFORE awaiting ``handleRegenerate``. + // That collapses the dialog (Radix unmounts it) while ``onChoose`` + // is still awaiting the long-running stream. Without this guard, + // the ``finally { setBusy(null) }`` below would run after unmount + // and produce a "state update on unmounted component" dev warning. + const mountedRef = useRef(true); + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); + + const handle = async (choice: EditMessageDialogChoice) => { + setBusy(choice); + try { + await onChoose(choice); + } finally { + if (mountedRef.current) { + setBusy(null); + } + } + }; + + return ( + <AlertDialog open={open} onOpenChange={onOpenChange}> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Edit this message?</AlertDialogTitle> + <AlertDialogDescription> + This edit drops {downstreamTotalCount} downstream message + {downstreamTotalCount === 1 ? "" : "s"} from the thread. {downstreamReversibleCount}{" "} + action + {downstreamReversibleCount === 1 ?
"" : "s"} (e.g. file writes, connector changes) can + be rolled back. Pick how to handle them before regenerating. + </AlertDialogDescription> + </AlertDialogHeader> + + <div className="grid gap-2"> + <Button variant="default" disabled={busy !== null} onClick={() => handle("revert")}> + {busy === "revert" + ? "Reverting & resubmitting…" + : `Revert ${downstreamReversibleCount} action${ + downstreamReversibleCount === 1 ? "" : "s" + } & resubmit`} + </Button> + <Button variant="outline" disabled={busy !== null} onClick={() => handle("continue")}> + {busy === "continue" ? "Resubmitting…" : "Continue without reverting"} + </Button> + </div> + + <AlertDialogFooter className="sm:justify-start"> + <AlertDialogCancel disabled={busy !== null} onClick={() => handle("cancel")}> + Cancel + </AlertDialogCancel> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + ); +} diff --git a/surfsense_web/components/assistant-ui/reasoning-message-part.tsx b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx new file mode 100644 index 000000000..70636eab8 --- /dev/null +++ b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx @@ -0,0 +1,81 @@ +"use client"; + +import type { ReasoningMessagePartComponent } from "@assistant-ui/react"; +import { ChevronRightIcon } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { TextShimmerLoader } from "@/components/prompt-kit/loader"; +import { cn } from "@/lib/utils"; + +/** + * Renders the structured `reasoning` part emitted by the backend's + * stream-parity v2 path (A1). + * + * Behaviour mirrors the existing `ThinkingStepsDisplay`: + * - collapsed by default; + * - auto-expanded while the part is still `running`; + * - auto-collapsed once status flips to `complete`. 
+ * + * The component is registered via the `Reasoning` slot on + * `MessagePrimitive.Parts` in `assistant-message.tsx` so it lives at the + * exact ordinal position of the reasoning block in the message content + * array (i.e. above the assistant text that follows it). + */ +export const ReasoningMessagePart: ReasoningMessagePartComponent = ({ text, status }) => { + const isRunning = status?.type === "running"; + const [isOpen, setIsOpen] = useState(() => isRunning); + + useEffect(() => { + if (isRunning) { + setIsOpen(true); + } else if (status?.type === "complete") { + setIsOpen(false); + } + }, [isRunning, status?.type]); + + const headerLabel = useMemo(() => { + if (isRunning) return "Thinking"; + if (status?.type === "incomplete") return "Thinking interrupted"; + return "Thought"; + }, [isRunning, status?.type]); + + if (!text || text.length === 0) { + if (!isRunning) return null; + } + + return ( + <div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2"> + <div className="rounded-lg"> + <button + type="button" + onClick={() => setIsOpen((prev) => !prev)} + className={cn( + "flex w-full items-center gap-1.5 text-left text-sm transition-colors", + "text-muted-foreground hover:text-foreground" + )} + > + {isRunning ? ( + <TextShimmerLoader text={headerLabel} size="sm" /> + ) : ( + <span>{headerLabel}</span> + )} + <ChevronRightIcon + className={cn("size-4 transition-transform duration-200", isOpen && "rotate-90")} + /> + </button> + + <div + className={cn( + "grid transition-[grid-template-rows] duration-300 ease-out", + isOpen ? 
"grid-rows-[1fr]" : "grid-rows-[0fr]" + )} + > + <div className="overflow-hidden"> + <div className="mt-2 border-l border-muted-foreground/30 pl-3 text-sm leading-relaxed text-muted-foreground whitespace-pre-wrap wrap-break-word"> + {text} + </div> + </div> + </div> + </div> + </div> + ); +}; diff --git a/surfsense_web/components/assistant-ui/revert-turn-button.tsx b/surfsense_web/components/assistant-ui/revert-turn-button.tsx new file mode 100644 index 000000000..9c349738f --- /dev/null +++ b/surfsense_web/components/assistant-ui/revert-turn-button.tsx @@ -0,0 +1,232 @@ +"use client"; + +/** + * "Revert turn" button rendered at the bottom of every completed + * assistant turn that has at least one reversible action. + * + * The button reads the action map keyed by ``chat_turn_id`` from the + * SSE side-channel (``data-action-log`` events). It shows a confirmation + * dialog summarising "N reversible / M total" and, on confirm, calls + * ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + * + * The route returns a per-action result list and never collapses the + * batch into a 4xx — so we render any failed/not_reversible rows inline + * with their messages. 
+ */ + +import { useAtomValue, useSetAtom } from "jotai"; +import { selectAtom } from "jotai/utils"; +import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react"; +import { useMemo, useState } from "react"; +import { toast } from "sonner"; +import { + type AgentActionLite, + agentActionsByChatTurnIdAtom, + markAgentActionsRevertedBatchAtom, +} from "@/atoms/chat/agent-actions.atom"; +import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { + agentActionsApiService, + type RevertTurnActionResult, +} from "@/lib/apis/agent-actions-api.service"; +import { AppError } from "@/lib/error"; +import { cn } from "@/lib/utils"; + +interface RevertTurnButtonProps { + chatTurnId: string | null | undefined; +} + +function formatToolName(name: string): string { + return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); +} + +// Empty-array sentinel so the per-turn ``selectAtom`` slice returns a +// stable reference when the turn has no recorded actions yet. Without +// this every render allocates a fresh ``[]`` and Jotai's +// equality check would re-render the button on unrelated turn updates. 
+const EMPTY_ACTIONS: readonly AgentActionLite[] = Object.freeze([]); + +export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { + const session = useAtomValue(chatSessionStateAtom); + const markRevertedBatch = useSetAtom(markAgentActionsRevertedBatchAtom); + const [isReverting, setIsReverting] = useState(false); + const [confirmOpen, setConfirmOpen] = useState(false); + const [resultsOpen, setResultsOpen] = useState(false); + const [results, setResults] = useState<RevertTurnActionResult[]>([]); + + // Subscribe ONLY to the slice of the global action map that belongs + // to ``chatTurnId``. Previously the button read the whole + // ``agentActionsByChatTurnIdAtom``, which meant every action + // upsert (one per tool call) re-rendered every Revert button on + // the page. With ``selectAtom`` we re-render only when our turn's + // list reference changes — and the upsert/mark atoms produce a + // fresh list reference for the affected turn only. + const sliceAtom = useMemo( + () => + selectAtom( + agentActionsByChatTurnIdAtom, + (turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? 
EMPTY_ACTIONS + ), + [chatTurnId] + ); + const actions = useAtomValue(sliceAtom); + + const reversibleCount = useMemo( + () => + actions.filter( + (a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error + ).length, + [actions] + ); + const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]); + + if (!chatTurnId) return null; + if (reversibleCount === 0) return null; + const threadId = session?.threadId; + if (!threadId) return null; + + const handleRevertTurn = async () => { + setIsReverting(true); + try { + const response = await agentActionsApiService.revertTurn(threadId, chatTurnId); + setResults(response.results); + const revertedEntries = response.results + .filter((r) => r.status === "reverted" || r.status === "already_reverted") + .map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null })); + if (revertedEntries.length > 0) { + markRevertedBatch({ entries: revertedEntries }); + } + if (response.status === "ok") { + toast.success( + response.reverted === 1 ? "Reverted 1 action." : `Reverted ${response.reverted} actions.` + ); + } else { + // Every "not undone" bucket counts as a failure for the + // user-facing summary. ``skipped`` rows are batch + // artefacts (revert rows themselves) and intentionally + // excluded from the failure tally. + const failureCount = + response.failed + response.not_reversible + (response.permission_denied ?? 0); + toast.warning( + `Reverted ${response.reverted} of ${response.total}. ${failureCount} could not be undone.` + ); + setResultsOpen(true); + } + } catch (err) { + if (err instanceof AppError && err.status === 503) { + return; + } + const message = + err instanceof AppError + ? err.message + : err instanceof Error + ? 
err.message + : "Failed to revert turn."; + toast.error(message); + } finally { + setIsReverting(false); + setConfirmOpen(false); + } + }; + + return ( + <> + <AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}> + <AlertDialogTrigger asChild> + <Button + size="sm" + variant="ghost" + className="text-muted-foreground hover:text-foreground gap-1.5" + onClick={(e) => { + e.stopPropagation(); + setConfirmOpen(true); + }} + > + <RotateCcw className="size-3.5" /> + <span>Revert turn</span> + <span className="text-xs tabular-nums opacity-70"> + {reversibleCount}/{totalCount} + </span> + </Button> + </AlertDialogTrigger> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert this turn?</AlertDialogTitle> + <AlertDialogDescription> + This will undo {reversibleCount} of {totalCount} action + {totalCount === 1 ? "" : "s"} from this turn in reverse order. The chat history and + any read-only actions are preserved. Some rows may not be reversible — partial success + is normal. + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={(e) => { + e.preventDefault(); + handleRevertTurn(); + }} + disabled={isReverting} + > + {isReverting ? "Reverting…" : "Revert turn"} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + + <AlertDialog open={resultsOpen} onOpenChange={setResultsOpen}> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert results</AlertDialogTitle> + <AlertDialogDescription> + Some actions could not be reverted. Review per-row outcomes below. 
+ </AlertDialogDescription> + </AlertDialogHeader> + <ul className="max-h-72 overflow-y-auto space-y-2 text-sm"> + {results.map((r) => ( + <RevertResultRow key={r.action_id} result={r} /> + ))} + </ul> + <AlertDialogFooter> + <AlertDialogAction onClick={() => setResultsOpen(false)}>Close</AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + </> + ); +} + +function RevertResultRow({ result }: { result: RevertTurnActionResult }) { + const isOk = result.status === "reverted" || result.status === "already_reverted"; + const Icon = isOk ? CheckIcon : XCircleIcon; + return ( + <li className="flex items-start gap-2 rounded-md border bg-muted/30 px-3 py-2"> + <Icon + className={cn("size-4 mt-0.5 shrink-0", isOk ? "text-emerald-500" : "text-destructive")} + /> + <div className="min-w-0 flex-1"> + <p className="font-medium truncate"> + {formatToolName(result.tool_name)}{" "} + <span className="ml-1 text-xs text-muted-foreground"> + {result.status.replace(/_/g, " ")} + </span> + </p> + {(result.message || result.error) && ( + <p className="text-xs text-muted-foreground mt-0.5">{result.error ?? result.message}</p> + )} + </div> + </li> + ); +} diff --git a/surfsense_web/components/assistant-ui/step-separator.tsx b/surfsense_web/components/assistant-ui/step-separator.tsx new file mode 100644 index 000000000..f59130661 --- /dev/null +++ b/surfsense_web/components/assistant-ui/step-separator.tsx @@ -0,0 +1,27 @@ +"use client"; + +import { makeAssistantDataUI } from "@assistant-ui/react"; + +/** + * Renders a thin horizontal divider between model steps within a single + * assistant turn. The data part is pushed by `addStepSeparator` in + * `streaming-state.ts` whenever a `start-step` SSE event arrives after + * the message already has non-step content. + * + * Today the backend emits one `start-step` / `finish-step` pair per turn, + * so most messages won't contain a separator. 
The renderer is wired up so + * the planned per-model-step refactor (A2 follow-up) can light up without + * touching the persistence path. + */ +function StepSeparatorDataRenderer() { + return ( + <div className="mx-auto my-3 w-full max-w-(--thread-max-width) px-2"> + <div className="border-t border-border/60" /> + </div> + ); +} + +export const StepSeparatorDataUI = makeAssistantDataUI({ + name: "step-separator", + render: StepSeparatorDataRenderer, +}); diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 112f3e1d8..70eab9ffc 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,12 +1,33 @@ import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; -import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react"; +import { useAtomValue, useSetAtom } from "jotai"; +import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { useMemo, useState } from "react"; +import { toast } from "sonner"; +import { + agentActionByToolCallIdAtom, + markAgentActionRevertedAtom, +} from "@/atoms/chat/agent-actions.atom"; +import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { DoomLoopApprovalToolUI, isDoomLoopInterrupt, } from "@/components/tool-ui/doom-loop-approval"; import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { AppError } from 
"@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; import { cn } from "@/lib/utils"; @@ -14,7 +35,99 @@ function formatToolName(name: string): string { return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); } +/** + * Inline Revert button rendered on a tool card when the matching + * ``AgentActionLog`` row is reversible and hasn't been reverted yet. + * Reads from the SSE side-channel atom keyed by the synthetic + * ``toolCallId`` so it lights up even when ``GET /threads/.../actions`` + * is gated behind ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503). + */ +function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { + const session = useAtomValue(chatSessionStateAtom); + const actionMap = useAtomValue(agentActionByToolCallIdAtom); + const markReverted = useSetAtom(markAgentActionRevertedAtom); + const action = actionMap.get(toolCallId); + const [isReverting, setIsReverting] = useState(false); + const [confirmOpen, setConfirmOpen] = useState(false); + + if (!action) return null; + if (!action.reversible) return null; + if (action.revertedByActionId !== null) return null; + if (action.isRevertAction) return null; + if (action.error) return null; + const threadId = session?.threadId; + if (!threadId) return null; + + const handleRevert = async () => { + setIsReverting(true); + try { + const response = await agentActionsApiService.revert(threadId, action.id); + markReverted({ id: action.id, newActionId: response.new_action_id ?? null }); + toast.success(response.message || "Action reverted."); + } catch (err) { + // 503 means revert is gated off on this deployment — hide the + // button silently rather than nagging the user. Any other error + // is surfaced as a toast so the operator can investigate. + if (err instanceof AppError && err.status === 503) { + return; + } + const message = + err instanceof AppError + ? err.message + : err instanceof Error + ? 
err.message + : "Failed to revert action."; + toast.error(message); + } finally { + setIsReverting(false); + setConfirmOpen(false); + } + }; + + return ( + <AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}> + <AlertDialogTrigger asChild> + <Button + size="sm" + variant="outline" + className="gap-1.5" + onClick={(e) => { + e.stopPropagation(); + setConfirmOpen(true); + }} + > + <RotateCcw className="size-3.5" /> + Revert + </Button> + </AlertDialogTrigger> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert this action?</AlertDialogTitle> + <AlertDialogDescription> + This will undo <span className="font-medium">{formatToolName(action.toolName)}</span>{" "} + and append a new audit entry. Chat history is preserved — only the tool's effects on + your knowledge base or connectors will be reversed where possible. + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={(e) => { + e.preventDefault(); + handleRevert(); + }} + disabled={isReverting} + > + {isReverting ? 
"Reverting…" : "Revert"} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + ); +} + const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ + toolCallId, toolName, argsText, result, @@ -145,6 +258,9 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ </div> </> )} + <div className="flex justify-end"> + <ToolCardRevertButton toolCallId={toolCallId} /> + </div> </div> </> )} diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index deac1fd00..bfdd613e2 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -9,6 +9,7 @@ import { import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile"; import { ShieldCheck } from "lucide-react"; import { useCallback, useEffect, useRef, useState } from "react"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { createTokenUsageStore, @@ -17,10 +18,13 @@ import { } from "@/components/assistant-ui/token-usage-context"; import { useAnonymousMode } from "@/contexts/anonymous-mode"; import { + addStepSeparator, addToolCall, + appendReasoning, appendText, buildContentForUI, type ContentPartsState, + endReasoning, FrameBatchedUpdater, readSSEStream, type ThinkingStepData, @@ -32,7 +36,9 @@ import { trackAnonymousChatMessageSent } from "@/lib/posthog/events"; import { FreeModelSelector } from "./free-model-selector"; import { FreeThread } from "./free-thread"; -const TOOLS_WITH_UI = new Set(["web_search", "document_qna"]); +// Render all tool calls via ToolFallback; backend keeps persisted +// payloads bounded by summarising / truncating outputs. +const TOOLS_WITH_UI = "all" as const; const TURNSTILE_SITE_KEY = process.env.NEXT_PUBLIC_TURNSTILE_SITE_KEY ?? 
""; /** Try to parse a CAPTCHA_REQUIRED or CAPTCHA_INVALID code from a non-ok response. */ @@ -125,6 +131,7 @@ export function FreeChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { toolCallIndices } = contentPartsState; @@ -148,28 +155,62 @@ export function FreeChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + TOOLS_WITH_UI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); break; case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); batcher.flush(); break; @@ -369,6 +410,7 @@ export function FreeChatPage() { <TokenUsageProvider store={tokenUsageStore}> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> 
<div className="flex h-full flex-col overflow-hidden"> <div className="flex h-14 shrink-0 items-center justify-between border-b border-border/40 px-4"> <FreeModelSelector /> diff --git a/surfsense_web/components/public-chat/public-chat-view.tsx b/surfsense_web/components/public-chat/public-chat-view.tsx index f8dd6db5a..e47ba9bf1 100644 --- a/surfsense_web/components/public-chat/public-chat-view.tsx +++ b/surfsense_web/components/public-chat/public-chat-view.tsx @@ -1,6 +1,7 @@ "use client"; import { AssistantRuntimeProvider } from "@assistant-ui/react"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { Navbar } from "@/components/homepage/navbar"; import { ReportPanel } from "@/components/report-panel/report-panel"; @@ -41,6 +42,7 @@ export function PublicChatView({ shareToken }: PublicChatViewProps) { <Navbar scrolledBgClassName={navbarScrolledBg} /> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> <div className="flex h-screen pt-16 overflow-hidden"> <div className="flex-1 flex flex-col min-w-0 overflow-hidden"> <PublicThread footer={<PublicChatFooter shareToken={shareToken} />} /> diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index 627baf831..22e914988 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -13,6 +13,7 @@ import Image from "next/image"; import { type FC, type ReactNode, useState } from "react"; import { CitationMetadataProvider } from "@/components/assistant-ui/citation-metadata-context"; import { MarkdownText } from "@/components/assistant-ui/markdown-text"; +import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { 
TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { GenerateImageToolUI } from "@/components/tool-ui/generate-image"; @@ -157,6 +158,7 @@ const PublicAssistantMessage: FC = () => { <MessagePrimitive.Parts components={{ Text: MarkdownText, + Reasoning: ReasoningMessagePart, tools: { by_name: { generate_podcast: GeneratePodcastToolUI, diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index bc63bc1b0..1aab08096 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,27 +1,112 @@ import { BookOpen, Brain, + Calendar, + Check, + FileEdit, + FilePlus, FileText, FileUser, + FileX, Film, + FolderPlus, + FolderTree, + FolderX, Globe, ImageIcon, + ListTodo, type LucideIcon, + Mail, + MessagesSquare, + Move, + Plus, Podcast, ScanLine, + Search, + Send, + Trash2, Wrench, } from "lucide-react"; +/** + * Every tool now renders a card via ``ToolFallback``. The icon map is + * keyed on the canonical backend tool name (registered in + * ``surfsense_backend/app/agents/new_chat/tools/registry.py``); unknown + * names fall back to the generic ``Wrench`` icon so the card still + * communicates "this is a tool call". 
+ */ const TOOL_ICONS: Record<string, LucideIcon> = { + // Generators generate_podcast: Podcast, generate_video_presentation: Film, generate_report: FileText, generate_resume: FileUser, generate_image: ImageIcon, + display_image: ImageIcon, + // Web / search scrape_webpage: ScanLine, web_search: Globe, search_surfsense_docs: BookOpen, + // Memory update_memory: Brain, + // Filesystem (built-in deepagent + middleware) + read_file: FileText, + write_file: FilePlus, + edit_file: FileEdit, + move_file: Move, + rm: FileX, + rmdir: FolderX, + mkdir: FolderPlus, + ls: FolderTree, + write_todos: ListTodo, + // Calendar + search_calendar_events: Search, + create_calendar_event: Calendar, + update_calendar_event: Calendar, + delete_calendar_event: Calendar, + // Gmail + search_gmail: Search, + read_gmail_email: Mail, + create_gmail_draft: Mail, + update_gmail_draft: FileEdit, + send_gmail_email: Send, + trash_gmail_email: Trash2, + // Notion / Confluence pages + create_notion_page: FilePlus, + update_notion_page: FileEdit, + delete_notion_page: FileX, + create_confluence_page: FilePlus, + update_confluence_page: FileEdit, + delete_confluence_page: FileX, + // Linear / Jira issues + create_linear_issue: Plus, + update_linear_issue: FileEdit, + delete_linear_issue: Trash2, + create_jira_issue: Plus, + update_jira_issue: FileEdit, + delete_jira_issue: Trash2, + // Drive-like file connectors + create_google_drive_file: FilePlus, + delete_google_drive_file: FileX, + create_dropbox_file: FilePlus, + delete_dropbox_file: FileX, + create_onedrive_file: FilePlus, + delete_onedrive_file: FileX, + // Chat connectors + list_discord_channels: MessagesSquare, + read_discord_messages: MessagesSquare, + send_discord_message: Send, + list_teams_channels: MessagesSquare, + read_teams_messages: MessagesSquare, + send_teams_message: Send, + // Luma + list_luma_events: Calendar, + read_luma_event: Calendar, + create_luma_event: Calendar, + // Misc + get_connected_accounts: Check, + execute: 
Wrench, + execute_code: Wrench, }; export function getToolIcon(name: string): LucideIcon { diff --git a/surfsense_web/lib/apis/agent-actions-api.service.ts b/surfsense_web/lib/apis/agent-actions-api.service.ts index 007bb131e..6634a11f7 100644 --- a/surfsense_web/lib/apis/agent-actions-api.service.ts +++ b/surfsense_web/lib/apis/agent-actions-api.service.ts @@ -15,6 +15,12 @@ const AgentActionReadSchema = z.object({ reverse_of: z.number().nullable(), reverted_by_action_id: z.number().nullable(), is_revert_action: z.boolean(), + // Correlation ids added in migration 135. The LangChain + // ``tool_call_id`` joins this row to the chat tool card via the + // ``data-action-log.lc_tool_call_id`` SSE event, and + // ``chat_turn_id`` keys the per-turn revert endpoint. + tool_call_id: z.string().nullable().optional(), + chat_turn_id: z.string().nullable().optional(), created_at: z.string(), }); @@ -38,6 +44,48 @@ const RevertResponseSchema = z.object({ export type RevertResponse = z.infer<typeof RevertResponseSchema>; +// Per-turn batch revert. The route never returns whole-batch 4xx; +// partial success is the common case and surfaced as +// ``status === "partial"`` with a per-action result list. 
+const RevertTurnActionResultSchema = z.object({ + action_id: z.number(), + tool_name: z.string(), + status: z.enum([ + "reverted", + "already_reverted", + "not_reversible", + "permission_denied", + "failed", + "skipped", + ]), + message: z.string().nullable().optional(), + new_action_id: z.number().nullable().optional(), + error: z.string().nullable().optional(), +}); + +export type RevertTurnActionResult = z.infer<typeof RevertTurnActionResultSchema>; + +const RevertTurnResponseSchema = z.object({ + status: z.enum(["ok", "partial"]), + chat_turn_id: z.string(), + total: z.number(), + reverted: z.number(), + already_reverted: z.number(), + not_reversible: z.number(), + // ``permission_denied`` and ``skipped`` are first-class counters so + // ``total === reverted + already_reverted + + // not_reversible + permission_denied + failed + skipped`` always + // holds. ``.default(0)`` keeps the schema backwards-compatible + // with older deployments that haven't shipped the response model + // update yet. 
+ permission_denied: z.number().default(0), + failed: z.number(), + skipped: z.number().default(0), + results: z.array(RevertTurnActionResultSchema), +}); + +export type RevertTurnResponse = z.infer<typeof RevertTurnResponseSchema>; + class AgentActionsApiService { listForThread = async ( threadId: number, @@ -59,6 +107,14 @@ class AgentActionsApiService { { body: {} } ); }; + + revertTurn = async (threadId: number, chatTurnId: string): Promise<RevertTurnResponse> => { + return baseApiService.post( + `/api/v1/threads/${threadId}/revert-turn/${encodeURIComponent(chatTurnId)}`, + RevertTurnResponseSchema, + { body: {} } + ); + }; } export const agentActionsApiService = new AgentActionsApiService(); diff --git a/surfsense_web/lib/chat/message-utils.ts b/surfsense_web/lib/chat/message-utils.ts index 2d1a6976f..004542489 100644 --- a/surfsense_web/lib/chat/message-utils.ts +++ b/surfsense_web/lib/chat/message-utils.ts @@ -40,7 +40,7 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike { } const metadata = - msg.author_id || msg.token_usage + msg.author_id || msg.token_usage || msg.turn_id ? { custom: { ...(msg.author_id && { @@ -50,6 +50,10 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike { }, }), ...(msg.token_usage && { usage: msg.token_usage }), + // Surface ``chat_turn_id`` so the assistant message + // footer can scope its "Revert turn" button to just + // this turn's actions. Null on legacy rows. 
+ ...(msg.turn_id && { chatTurnId: msg.turn_id }), }, } : undefined; diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index ff8fdfbd4..26fd7b98c 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -9,21 +9,42 @@ export interface ThinkingStepData { export type ContentPart = | { type: "text"; text: string } + | { type: "reasoning"; text: string } | { type: "tool-call"; toolCallId: string; toolName: string; args: Record<string, unknown>; result?: unknown; + /** + * Authoritative LangChain ``tool_call.id`` propagated by the backend + * via ``langchainToolCallId`` on tool-input-start/available and + * tool-output-available events. Used to join a card to the + * matching ``AgentActionLog`` row exposed by + * ``GET /threads/{id}/actions`` and the streamed + * ``data-action-log`` events. + */ + langchainToolCallId?: string; } | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] }; + } + | { + /** + * Between-step separator. Pushed by `addStepSeparator` when + * a `start-step` SSE event arrives AFTER the message already + * has non-step content. Rendered by `StepSeparatorDataUI` + * (see assistant-ui/step-separator.tsx). 
+ */ + type: "data-step-separator"; + data: { stepIndex: number }; }; export interface ContentPartsState { contentParts: ContentPart[]; currentTextPartIndex: number; + currentReasoningPartIndex: number; toolCallIndices: Map<string, number>; } @@ -74,6 +95,9 @@ export function updateThinkingSteps( if (state.currentTextPartIndex >= 0) { state.currentTextPartIndex += 1; } + if (state.currentReasoningPartIndex >= 0) { + state.currentReasoningPartIndex += 1; + } for (const [id, idx] of state.toolCallIndices) { state.toolCallIndices.set(id, idx + 1); } @@ -131,6 +155,12 @@ export class FrameBatchedUpdater { } export function appendText(state: ContentPartsState, delta: string): void { + // First text delta after a reasoning block: close the reasoning so + // the assistant-ui renderer treats them as separate parts (the + // reasoning block collapses; the answer streams below). + if (state.currentReasoningPartIndex >= 0) { + state.currentReasoningPartIndex = -1; + } if ( state.currentTextPartIndex >= 0 && state.contentParts[state.currentTextPartIndex]?.type === "text" @@ -143,36 +173,129 @@ export function appendText(state: ContentPartsState, delta: string): void { } } +export function appendReasoning(state: ContentPartsState, delta: string): void { + // Symmetric to appendText: open a fresh reasoning block on first + // delta, then accumulate into it. ``endReasoning`` simply closes + // the active block; subsequent reasoning deltas would open a new + // one (matching ``text-start/end`` semantics on the wire). 
+ if (state.currentTextPartIndex >= 0) { + state.currentTextPartIndex = -1; + } + if ( + state.currentReasoningPartIndex >= 0 && + state.contentParts[state.currentReasoningPartIndex]?.type === "reasoning" + ) { + ( + state.contentParts[state.currentReasoningPartIndex] as { + type: "reasoning"; + text: string; + } + ).text += delta; + } else { + state.contentParts.push({ type: "reasoning", text: delta }); + state.currentReasoningPartIndex = state.contentParts.length - 1; + } +} + +export function endReasoning(state: ContentPartsState): void { + state.currentReasoningPartIndex = -1; +} + +export function addStepSeparator(state: ContentPartsState): void { + // Push a divider between consecutive model steps within a single + // assistant turn. We only emit it when the message already has + // non-step content (so the FIRST step of a turn doesn't + // generate a leading separator) and when the previous part isn't + // itself a separator (defensive against duplicate `start-step` + // events). + const hasContent = state.contentParts.some( + (p) => p.type === "text" || p.type === "reasoning" || p.type === "tool-call" + ); + if (!hasContent) return; + const last = state.contentParts[state.contentParts.length - 1]; + if (last && last.type === "data-step-separator") return; + + const stepIndex = state.contentParts.filter((p) => p.type === "data-step-separator").length; + state.contentParts.push({ type: "data-step-separator", data: { stepIndex } }); + state.currentTextPartIndex = -1; + state.currentReasoningPartIndex = -1; +} + +/** + * Allowlist of tool names that should produce a UI tool card. The + * sentinel ``"all"`` matches every tool — we dropped the legacy + * ``BASE_TOOLS_WITH_UI`` gate so that ALL tool calls render via the + * generic ``ToolFallback``. The backend's ``format_thinking_step`` + * summarisation and the defensive ``result_length``-only default for + * unknown tools keep persisted message JSON from ballooning. 
+ */ +export type ToolUIGate = Set<string> | "all"; + +function _toolPasses(gate: ToolUIGate, toolName: string): boolean { + return gate === "all" || gate.has(toolName); +} + export function addToolCall( state: ContentPartsState, - toolsWithUI: Set<string>, + toolsWithUI: ToolUIGate, toolCallId: string, toolName: string, args: Record<string, unknown>, - force = false + force = false, + langchainToolCallId?: string ): void { - if (force || toolsWithUI.has(toolName)) { + if (force || _toolPasses(toolsWithUI, toolName)) { state.contentParts.push({ type: "tool-call", toolCallId, toolName, args, + ...(langchainToolCallId ? { langchainToolCallId } : {}), }); state.toolCallIndices.set(toolCallId, state.contentParts.length - 1); state.currentTextPartIndex = -1; + state.currentReasoningPartIndex = -1; } } +/** + * Reverse-lookup helper used by the SSE ``data-action-log`` handler: + * given the LangChain ``tool_call.id`` (set on the content part as + * ``langchainToolCallId``), return the synthetic ``toolCallId`` that + * the chat tool card uses (``call_<run-id>``). Returns ``null`` when no + * matching tool card has been seen yet — the action is still recorded + * in the LC-id-keyed atom so the card can pick it up when it eventually + * arrives. 
+ */ +export function findToolCallIdByLcId( + state: ContentPartsState, + lcToolCallId: string +): string | null { + for (const part of state.contentParts) { + if (part.type === "tool-call" && part.langchainToolCallId === lcToolCallId) { + return part.toolCallId; + } + } + return null; +} + export function updateToolCall( state: ContentPartsState, toolCallId: string, - update: { args?: Record<string, unknown>; result?: unknown } + update: { args?: Record<string, unknown>; result?: unknown; langchainToolCallId?: string } ): void { const index = state.toolCallIndices.get(toolCallId); if (index !== undefined && state.contentParts[index]?.type === "tool-call") { const tc = state.contentParts[index] as ContentPart & { type: "tool-call" }; if (update.args) tc.args = update.args; if (update.result !== undefined) tc.result = update.result; + // Backfill only: set langchainToolCallId when the card has none + // yet. The first non-empty id wins — a later value (including + // the ``on_tool_end`` one) never replaces it, and a missing + // late-arriving value cannot blow away a known good early one. 
+ if (update.langchainToolCallId && !tc.langchainToolCallId) { + tc.langchainToolCallId = update.langchainToolCallId; + } } } @@ -184,13 +307,15 @@ function _hasInterruptResult(part: ContentPart): boolean { export function buildContentForUI( state: ContentPartsState, - toolsWithUI: Set<string> + toolsWithUI: ToolUIGate ): ThreadMessageLike["content"] { const filtered = state.contentParts.filter((part) => { if (part.type === "text") return part.text.length > 0; + if (part.type === "reasoning") return part.text.length > 0; if (part.type === "tool-call") - return toolsWithUI.has(part.toolName) || _hasInterruptResult(part); + return _toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part); if (part.type === "data-thinking-steps") return true; + if (part.type === "data-step-separator") return true; return false; }); return filtered.length > 0 @@ -200,20 +325,28 @@ export function buildContentForUI( export function buildContentForPersistence( state: ContentPartsState, - toolsWithUI: Set<string> + toolsWithUI: ToolUIGate ): unknown[] { const parts: unknown[] = []; for (const part of state.contentParts) { if (part.type === "text" && part.text.length > 0) { parts.push(part); + } else if (part.type === "reasoning" && part.text.length > 0) { + // Persist reasoning blocks so a chat reload re-renders the + // collapsed thinking section instead of + // silently dropping it (mirrors the data-thinking-steps + // branch above). 
+ parts.push(part); } else if ( part.type === "tool-call" && - (toolsWithUI.has(part.toolName) || _hasInterruptResult(part)) + (_toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part)) ) { parts.push(part); } else if (part.type === "data-thinking-steps") { parts.push(part); + } else if (part.type === "data-step-separator") { + parts.push(part); } } @@ -221,23 +354,122 @@ export function buildContentForPersistence( } export type SSEEvent = - | { type: "text-delta"; delta: string } - | { type: "tool-input-start"; toolCallId: string; toolName: string } + | { type: "start"; messageId?: string } + | { type: "finish" } + | { type: "start-step" } + | { type: "finish-step" } + | { type: "text-start"; id: string } + | { type: "text-delta"; id?: string; delta: string } + | { type: "text-end"; id: string } + | { type: "reasoning-start"; id: string } + | { type: "reasoning-delta"; id?: string; delta: string } + | { type: "reasoning-end"; id: string } + | { + type: "tool-input-start"; + toolCallId: string; + toolName: string; + /** Authoritative LangChain ``tool_call.id``. Optional. */ + langchainToolCallId?: string; + } | { type: "tool-input-available"; toolCallId: string; toolName: string; input: Record<string, unknown>; + langchainToolCallId?: string; } | { type: "tool-output-available"; toolCallId: string; output: Record<string, unknown>; + /** Authoritative LangChain ``tool_call.id`` extracted from + * ``ToolMessage.tool_call_id`` at on_tool_end. Backfills cards + * that didn't get the id at tool-input-start time. */ + langchainToolCallId?: string; } | { type: "data-thinking-step"; data: ThinkingStepData } | { type: "data-thread-title-update"; data: { threadId: number; title: string } } | { type: "data-interrupt-request"; data: Record<string, unknown> } | { type: "data-documents-updated"; data: Record<string, unknown> } + | { + /** + * A freshly committed AgentActionLog row. 
Frontend stores + * this in a Map keyed off ``lc_tool_call_id`` so the chat + * tool card can light up its Revert button. + */ + type: "data-action-log"; + data: { + id: number; + lc_tool_call_id: string | null; + chat_turn_id: string | null; + tool_name: string; + reversible: boolean; + reverse_descriptor_present: boolean; + created_at: string | null; + error: boolean; + }; + } + | { + /** + * Reversibility flipped (filesystem op SAVEPOINT committed; + * cf. ``kb_persistence._dispatch_reversibility_update``). + */ + type: "data-action-log-updated"; + data: { id: number; reversible: boolean }; + } + | { + /** + * Emitted at the start of every stream so the frontend can + * stamp the per-turn correlation id onto the in-flight + * assistant message and replay it via + * ``appendMessage``. Pure-text turns never produce + * action-log events; this event guarantees the frontend + * always learns the turn id. + */ + type: "data-turn-info"; + data: { chat_turn_id: string }; + } + | { + /** + * Best-effort revert pass that ran BEFORE this regeneration. + * Per-action results are forwarded to the UI so the user + * can see which downstream actions were rolled + * back vs which couldn't be undone. + */ + type: "data-revert-results"; + data: { + status: "ok" | "partial"; + chat_turn_ids: string[]; + total: number; + reverted: number; + already_reverted: number; + not_reversible: number; + /** + * ``permission_denied`` and ``skipped`` are first-class + * counters so the response invariant + * ``total === sum(counters)`` always holds. Optional + * for forward compatibility with older backends; the + * frontend treats missing values as ``0``. 
+ */ + permission_denied?: number; + failed: number; + skipped?: number; + results: Array<{ + action_id: number; + tool_name: string; + status: + | "reverted" + | "already_reverted" + | "not_reversible" + | "permission_denied" + | "failed" + | "skipped"; + message?: string | null; + new_action_id?: number | null; + error?: string | null; + }>; + }; + } | { type: "data-token-usage"; data: { diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index b5c5899b4..fc970c26e 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -46,6 +46,11 @@ export interface MessageRecord { author_display_name?: string | null; author_avatar_url?: string | null; token_usage?: TokenUsageSummary | null; + // Per-turn correlation id from ``configurable.turn_id`` at streaming + // time (added in migration 136). Used by the per-turn revert + // endpoint and edit-from-arbitrary-position. Nullable on legacy + // rows that predate the column. + turn_id?: string | null; } export interface ThreadListResponse { @@ -123,10 +128,20 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory /** * Append a message to a thread. + * + * ``turn_id`` is the per-turn correlation id streamed by the backend + * via ``data-turn-info``. Persisting it lets later edits locate the + * matching LangGraph checkpoint without HumanMessage scanning. Older + * callers can still omit it for back-compat. 
*/ export async function appendMessage( threadId: number, - message: { role: "user" | "assistant" | "system"; content: unknown; token_usage?: unknown } + message: { + role: "user" | "assistant" | "system"; + content: unknown; + token_usage?: unknown; + turn_id?: string | null; + } ): Promise<MessageRecord> { return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, { body: message, From 9a114a2d45f0341c0edbfeac1dafba858aa7e38e Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Wed, 29 Apr 2026 07:40:11 -0700 Subject: [PATCH 07/68] feat: enhance tool display names for better user experience in chat UI --- .../app/tasks/chat/stream_new_chat.py | 141 +++++++++++++++++- .../agent-action-log/action-log-item.tsx | 8 +- .../assistant-ui/revert-turn-button.tsx | 7 +- .../components/assistant-ui/thread.tsx | 13 +- .../components/assistant-ui/tool-fallback.tsx | 19 +-- .../tool-ui/generic-hitl-approval.tsx | 6 +- surfsense_web/contracts/enums/toolIcons.tsx | 105 +++++++++++++ 7 files changed, 267 insertions(+), 32 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 2f8e33ba9..f7bf75649 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -622,6 +622,95 @@ async def _stream_agent_events( status="in_progress", items=last_active_step_items, ) + elif tool_name == "rm": + rm_path = ( + tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:] + last_active_step_title = "Deleting file" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Deleting file", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "rmdir": + rmdir_path = ( + 
tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = ( + rmdir_path if len(rmdir_path) <= 80 else "…" + rmdir_path[-77:] + ) + last_active_step_title = "Deleting folder" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Deleting folder", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "mkdir": + mkdir_path = ( + tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = ( + mkdir_path if len(mkdir_path) <= 80 else "…" + mkdir_path[-77:] + ) + last_active_step_title = "Creating folder" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Creating folder", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "move_file": + src = ( + tool_input.get("source_path", "") + if isinstance(tool_input, dict) + else "" + ) + dst = ( + tool_input.get("destination_path", "") + if isinstance(tool_input, dict) + else "" + ) + display_src = src if len(src) <= 60 else "…" + src[-57:] + display_dst = dst if len(dst) <= 60 else "…" + dst[-57:] + last_active_step_title = "Moving file" + last_active_step_items = ( + [f"{display_src} → {display_dst}"] if src or dst else [] + ) + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Moving file", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "write_todos": + todos = ( + tool_input.get("todos", []) if isinstance(tool_input, dict) else [] + ) + todo_count = len(todos) if isinstance(todos, list) else 0 + last_active_step_title = "Planning tasks" + last_active_step_items = ( + [f"{todo_count} task{'s' if todo_count != 1 else ''}"] + if todo_count + else [] + ) + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + 
title="Planning tasks", + status="in_progress", + items=last_active_step_items, + ) elif tool_name == "save_document": doc_title = ( tool_input.get("title", "") @@ -729,7 +818,15 @@ async def _stream_agent_events( items=last_active_step_items, ) else: - last_active_step_title = f"Using {tool_name.replace('_', ' ')}" + # Fallback for tools without a curated thinking-step title + # (typically connector tools, MCP-registered tools, or + # newly added tools that haven't been wired up here yet). + # Render the snake_cased name as a sentence-cased phrase + # so non-technical users see e.g. "Send gmail email" + # rather than the raw identifier "send_gmail_email". + last_active_step_title = ( + tool_name.replace("_", " ").strip().capitalize() or tool_name + ) last_active_step_items = [] yield streaming_service.format_thinking_step( step_id=tool_step_id, @@ -885,6 +982,41 @@ async def _stream_agent_events( status="completed", items=last_active_step_items, ) + elif tool_name == "rm": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Deleting file", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "rmdir": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Deleting folder", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "mkdir": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Creating folder", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "move_file": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Moving file", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "write_todos": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Planning tasks", + status="completed", + items=last_active_step_items, + ) elif tool_name == "save_document": result_str = ( tool_output.get("result", "") @@ 
-1136,9 +1268,14 @@ async def _stream_agent_events( items=completed_items, ) else: + # Fallback completion title — see the matching in-progress + # branch above for the wording rationale. + fallback_title = ( + tool_name.replace("_", " ").strip().capitalize() or tool_name + ) yield streaming_service.format_thinking_step( step_id=original_step_id, - title=f"Using {tool_name.replace('_', ' ')}", + title=fallback_title, status="completed", items=last_active_step_items, ) diff --git a/surfsense_web/components/agent-action-log/action-log-item.tsx b/surfsense_web/components/agent-action-log/action-log-item.tsx index 425714c1f..673189709 100644 --- a/surfsense_web/components/agent-action-log/action-log-item.tsx +++ b/surfsense_web/components/agent-action-log/action-log-item.tsx @@ -17,16 +17,12 @@ import { import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Separator } from "@/components/ui/separator"; -import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; -function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - interface ActionLogItemProps { action: AgentAction; threadId: number; @@ -43,7 +39,7 @@ export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogIt const hasError = action.error !== null && action.error !== undefined; const Icon = getToolIcon(action.tool_name); - const displayName = formatToolName(action.tool_name); + const displayName = getToolDisplayName(action.tool_name); const argsPreview = action.args ? 
JSON.stringify(action.args, null, 2) : null; const truncatedArgs = diff --git a/surfsense_web/components/assistant-ui/revert-turn-button.tsx b/surfsense_web/components/assistant-ui/revert-turn-button.tsx index 9c349738f..af71299d0 100644 --- a/surfsense_web/components/assistant-ui/revert-turn-button.tsx +++ b/surfsense_web/components/assistant-ui/revert-turn-button.tsx @@ -37,6 +37,7 @@ import { AlertDialogTrigger, } from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; import { agentActionsApiService, type RevertTurnActionResult, @@ -48,10 +49,6 @@ interface RevertTurnButtonProps { chatTurnId: string | null | undefined; } -function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - // Empty-array sentinel so the per-turn ``selectAtom`` slice returns a // stable reference when the turn has no recorded actions yet. Without // this every render allocates a fresh ``[]`` and Jotai's @@ -218,7 +215,7 @@ function RevertResultRow({ result }: { result: RevertTurnActionResult }) { /> <div className="min-w-0 flex-1"> <p className="font-medium truncate"> - {formatToolName(result.tool_name)}{" "} + {getToolDisplayName(result.tool_name)}{" "} <span className="ml-1 text-xs text-muted-foreground"> {result.status.replace(/_/g, " ")} </span> diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index cf99598f1..e58783c87 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -82,6 +82,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import { CONNECTOR_ICON_TO_TYPES, CONNECTOR_TOOL_ICON_PATHS, + getToolDisplayName, getToolIcon, } from "@/contracts/enums/toolIcons"; import type { Document } from "@/contracts/types/document.types"; @@ -1317,12 +1318,14 @@ const 
ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false ); }; -/** Convert snake_case tool names to human-readable labels */ +/** + * Friendly tool name for display in the chat UI. Delegates to the + * shared map in ``contracts/enums/toolIcons`` so unix-style identifiers + * (``rm``, ``ls``, ``grep`` …) and snake_cased function names render as + * plain English (e.g. "Delete file", "List files", "Search in files"). + */ function formatToolName(name: string): string { - return name - .split("_") - .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) - .join(" "); + return getToolDisplayName(name); } interface ToolGroup { diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 70eab9ffc..cc7582695 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -25,16 +25,12 @@ import { AlertDialogTrigger, } from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; -import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; import { cn } from "@/lib/utils"; -function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - /** * Inline Revert button rendered on a tool card when the matching * ``AgentActionLog`` row is reversible and hasn't been reverted yet. 
@@ -104,9 +100,10 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { <AlertDialogHeader> <AlertDialogTitle>Revert this action?</AlertDialogTitle> <AlertDialogDescription> - This will undo <span className="font-medium">{formatToolName(action.toolName)}</span>{" "} - and append a new audit entry. Chat history is preserved — only the tool's effects on - your knowledge base or connectors will be reversed where possible. + This will undo{" "} + <span className="font-medium">{getToolDisplayName(action.toolName)}</span> and add a + new entry to the history. Your chat is preserved — only the changes the agent made to + your knowledge base or connected apps will be rolled back where possible. </AlertDialogDescription> </AlertDialogHeader> <AlertDialogFooter> @@ -164,7 +161,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ : null; const Icon = getToolIcon(toolName); - const displayName = formatToolName(toolName); + const displayName = getToolDisplayName(toolName); return ( <div @@ -215,7 +212,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ ? 
`Failed: ${displayName}` : displayName} </p> - {isRunning && <p className="text-xs text-muted-foreground mt-0.5">Running...</p>} + {isRunning && <p className="text-xs text-muted-foreground mt-0.5">Working…</p>} {cancelledReason && ( <p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p> )} @@ -241,7 +238,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ <div className="px-5 py-3 space-y-3"> {argsText && ( <div> - <p className="text-xs font-medium text-muted-foreground mb-1">Arguments</p> + <p className="text-xs font-medium text-muted-foreground mb-1">Inputs</p> <pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> {argsText} </pre> diff --git a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx index ceb1d0209..a584084ff 100644 --- a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx +++ b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx @@ -8,6 +8,7 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Textarea } from "@/components/ui/textarea"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; import { useHitlPhase } from "@/hooks/use-hitl-phase"; import { connectorsApiService } from "@/lib/apis/connectors-api.service"; import type { HitlDecision, InterruptResult } from "@/lib/hitl"; @@ -77,7 +78,7 @@ function GenericApprovalCard({ const [editedParams, setEditedParams] = useState<Record<string, unknown>>(args); const [isEditing, setIsEditing] = useState(false); - const displayName = toolName.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); + const displayName = getToolDisplayName(toolName); const mcpServer = interruptData.context?.mcp_server as string | undefined; const toolDescription = interruptData.context?.tool_description as string | undefined; @@ -186,12 
+187,11 @@ function GenericApprovalCard({ </> )} - {/* Parameters */} {Object.keys(args).length > 0 && ( <> <div className="mx-5 h-px bg-border/50" /> <div className="px-5 py-4 space-y-2"> - <p className="text-xs font-medium text-muted-foreground">Parameters</p> + <p className="text-xs font-medium text-muted-foreground">Inputs</p> {phase === "pending" && isEditing ? ( <ParamEditor params={editedParams} diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 1aab08096..bdb8222cb 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -113,6 +113,111 @@ export function getToolIcon(name: string): LucideIcon { return TOOL_ICONS[name] ?? Wrench; } +/** + * Friendly display names for tools shown in the chat UI. + * + * Most users aren't engineers; they shouldn't see raw unix-style + * identifiers like ``rm`` / ``rmdir`` / ``ls`` / ``grep`` / ``glob`` or + * snake_cased function names. The map below renders each tool with + * plain English wording (verb + object) so non-technical users + * understand what the agent is doing at a glance. + * + * Unmapped tool names fall back to a snake_case-to-Title-Case + * conversion via :func:`getToolDisplayName`. 
+ */ +const TOOL_DISPLAY_NAMES: Record<string, string> = { + // Filesystem / knowledge base + read_file: "Read file", + write_file: "Write file", + edit_file: "Edit file", + move_file: "Move file", + rm: "Delete file", + rmdir: "Delete folder", + mkdir: "Create folder", + ls: "List files", + glob: "Find files", + grep: "Search in files", + write_todos: "Plan tasks", + save_document: "Save document", + // Generators + generate_podcast: "Generate podcast", + generate_video_presentation: "Generate video presentation", + generate_report: "Generate report", + generate_resume: "Generate resume", + generate_image: "Generate image", + display_image: "Show image", + // Web / search + scrape_webpage: "Read webpage", + web_search: "Search the web", + search_surfsense_docs: "Search knowledge base", + // Memory + update_memory: "Update memory", + // Calendar + search_calendar_events: "Search calendar", + create_calendar_event: "Create event", + update_calendar_event: "Update event", + delete_calendar_event: "Delete event", + // Gmail + search_gmail: "Search Gmail", + read_gmail_email: "Read email", + create_gmail_draft: "Draft email", + update_gmail_draft: "Update draft", + send_gmail_email: "Send email", + trash_gmail_email: "Move email to trash", + // Notion + create_notion_page: "Create Notion page", + update_notion_page: "Update Notion page", + delete_notion_page: "Delete Notion page", + // Confluence + create_confluence_page: "Create Confluence page", + update_confluence_page: "Update Confluence page", + delete_confluence_page: "Delete Confluence page", + // Linear + create_linear_issue: "Create Linear issue", + update_linear_issue: "Update Linear issue", + delete_linear_issue: "Delete Linear issue", + // Jira + create_jira_issue: "Create Jira issue", + update_jira_issue: "Update Jira issue", + delete_jira_issue: "Delete Jira issue", + // Drive-like file connectors + create_google_drive_file: "Create Google Drive file", + delete_google_drive_file: "Delete Google Drive 
file", + create_dropbox_file: "Create Dropbox file", + delete_dropbox_file: "Delete Dropbox file", + create_onedrive_file: "Create OneDrive file", + delete_onedrive_file: "Delete OneDrive file", + // Discord + list_discord_channels: "List Discord channels", + read_discord_messages: "Read Discord messages", + send_discord_message: "Send Discord message", + // Teams + list_teams_channels: "List Teams channels", + read_teams_messages: "Read Teams messages", + send_teams_message: "Send Teams message", + // Luma + list_luma_events: "List Luma events", + read_luma_event: "Read Luma event", + create_luma_event: "Create Luma event", + // Misc + get_connected_accounts: "Check connected accounts", + execute: "Run command", + execute_code: "Run code", +}; + +/** + * Format a tool's canonical (snake_case) name for display in the chat UI. + * + * Looks up :data:`TOOL_DISPLAY_NAMES` first; falls back to a + * snake_case-to-Title-Case rewrite for tools that don't have a curated + * label (e.g. dynamically registered MCP tools). 
+ */ +export function getToolDisplayName(name: string): string { + const friendly = TOOL_DISPLAY_NAMES[name]; + if (friendly) return friendly; + return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); +} + export const CONNECTOR_TOOL_ICON_PATHS: Record<string, { src: string; alt: string }> = { gmail: { src: "/connectors/google-gmail.svg", alt: "Gmail" }, google_calendar: { src: "/connectors/google-calendar.svg", alt: "Google Calendar" }, From c598d7038f4f2766e37b9e9dc3e037b07fc1938b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:17:45 +0530 Subject: [PATCH 08/68] refactor(chat): update premium token error messages for clarity and consistency --- .../app/tasks/chat/stream_new_chat.py | 4 ++-- .../new-chat/[[...chat_id]]/page.tsx | 16 ++++++---------- surfsense_web/components/assistant-ui/thread.tsx | 7 +++---- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 1a56547ca..233b45396 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1542,7 +1542,7 @@ async def stream_new_chat( llm_config_id, ) yield streaming_service.format_error( - "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin." + "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." ) yield streaming_service.format_done() return @@ -2263,7 +2263,7 @@ async def stream_resume_chat( llm_config_id, ) yield streaming_service.format_error( - "Premium token quota exceeded for this pinned model. Select a free model or re-select Auto (Fastest) to repin." + "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." 
) yield streaming_service.format_done() return diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index a5461e17f..05621419d 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -201,17 +201,16 @@ const BASE_TOOLS_WITH_UI = new Set([ // "write_todos", // Disabled for now ]); -const PINNED_PREMIUM_QUOTA_MESSAGE = "Premium token quota exceeded for this pinned model."; - function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { if (!(error instanceof Error)) return null; - if (!error.message.toLowerCase().includes("premium token quota exceeded")) { + const normalized = error.message.toLowerCase(); + if ( + !normalized.includes("premium tokens exhausted") + && !normalized.includes("premium token quota exceeded") + ) { return null; } - if (!error.message.toLowerCase().includes("pinned model")) { - return null; - } - return error.message || PINNED_PREMIUM_QUOTA_MESSAGE; + return error.message; } export default function NewChatPage() { @@ -980,7 +979,6 @@ export default function NewChatPage() { threadId: currentThreadId, message: premiumQuotaAlertMessage, }); - toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); } else { toast.error("Failed to get response. Please try again."); } @@ -1290,7 +1288,6 @@ export default function NewChatPage() { threadId: resumeThreadId, message: premiumQuotaAlertMessage, }); - toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); } else { toast.error("Failed to resume. Please try again."); } @@ -1638,7 +1635,6 @@ export default function NewChatPage() { threadId, message: premiumQuotaAlertMessage, }); - toast.error(PINNED_PREMIUM_QUOTA_MESSAGE); } else { toast.error("Failed to regenerate response. 
Please try again."); } diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 06f25f5fb..cb063fac3 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -161,16 +161,15 @@ const PremiumQuotaPinnedAlert: FC = () => { if (!alert) return null; return ( - <div className="mx-2 rounded-2xl border border-amber-300/40 bg-amber-500/10 px-4 py-3 text-amber-50 shadow-lg backdrop-blur-sm"> + <div className="mx-0 bg-amber-500/10 px-3 py-2 text-amber-100"> <div className="flex items-start gap-2"> <AlertCircle className="mt-0.5 size-4 shrink-0 text-amber-300" /> <div className="min-w-0 flex-1"> - <p className="text-sm font-medium">Premium quota exhausted</p> - <p className="mt-1 text-xs text-amber-100/90">{alert.message}</p> + <p className="text-sm">{alert.message}</p> </div> <button type="button" - className="inline-flex size-6 items-center justify-center rounded-md text-amber-200 transition-colors hover:bg-amber-200/20 hover:text-amber-50" + className="inline-flex size-6 items-center justify-center text-amber-200 transition-colors hover:text-amber-50" aria-label="Dismiss premium quota alert" onClick={() => clearPremiumAlertForThread(currentThreadId)} > From d66fa1559b3913648e195c379e60b03ff1f00baf Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:29:41 +0530 Subject: [PATCH 09/68] feat(chat): implement forced repin to free tier for pinned LLM configurations --- .../app/services/auto_model_pin_service.py | 17 +- .../app/tasks/chat/stream_new_chat.py | 209 ++++++++++++------ .../services/test_auto_model_pin_service.py | 38 ++++ 3 files changed, 200 insertions(+), 64 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index ce417a26d..6bdb60f57 100644 --- 
a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -84,6 +84,7 @@ async def resolve_or_get_pinned_llm_config_id( search_space_id: int, user_id: str | UUID | None, selected_llm_config_id: int, + force_repin_free: bool = False, ) -> AutoPinResolution: """Resolve Auto (Fastest) to one concrete config id and persist pin metadata. @@ -130,9 +131,12 @@ async def resolve_or_get_pinned_llm_config_id( raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} - # Reuse existing valid pin without re-checking current quota (no silent tier switch). + # Reuse existing valid pin without re-checking current quota (no silent tier switch), + # unless the caller explicitly requests a forced repin to free. pinned_id = thread.pinned_llm_config_id if ( + not force_repin_free + and thread.pinned_auto_mode == AUTO_FASTEST_MODE and pinned_id is not None and int(pinned_id) in candidate_by_id @@ -159,7 +163,7 @@ async def resolve_or_get_pinned_llm_config_id( thread.pinned_auto_mode, ) - premium_eligible = await _is_premium_eligible(session, user_id) + premium_eligible = False if force_repin_free else await _is_premium_eligible(session, user_id) if premium_eligible: eligible = candidates else: @@ -179,6 +183,15 @@ async def resolve_or_get_pinned_llm_config_id( thread.pinned_at = datetime.now(UTC) await session.commit() + if force_repin_free: + logger.info( + "auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s", + thread_id, + search_space_id, + pinned_id, + selected_id, + ) + if pinned_id is None: logger.info( "auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s", diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 233b45396..edc5aa763 100644 --- 
a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1455,6 +1455,37 @@ async def stream_new_chat( await set_ai_responding(session, chat_id, UUID(user_id)) # Load LLM config - supports both YAML (negative IDs) and database (positive IDs) agent_config: AgentConfig | None = None + requested_llm_config_id = llm_config_id + + async def _load_llm_bundle( + config_id: int, + ) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, + ) _t0 = time.perf_counter() try: @@ -1472,35 +1503,11 @@ async def stream_new_chat( yield streaming_service.format_done() return - if llm_config_id >= 0: - # Positive ID: Load from NewLLMConfig database table - agent_config = await load_agent_config( - session=session, - config_id=llm_config_id, - search_space_id=search_space_id, - ) - if not agent_config: - yield streaming_service.format_error( - f"Failed to load NewLLMConfig with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - - # Create ChatLiteLLM from AgentConfig - llm = create_chat_litellm_from_agent_config(agent_config) - else: - # Negative ID: Load from in-memory global configs (includes dynamic OpenRouter models) - llm_config = load_global_llm_config_by_id(llm_config_id) - if not llm_config: - yield streaming_service.format_error( - f"Failed to load LLM 
config with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - - # Create ChatLiteLLM from global config dict - llm = create_chat_litellm_from_config(llm_config) - agent_config = AgentConfig.from_yaml_config(llm_config) + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield streaming_service.format_error(llm_load_error) + yield streaming_service.format_done() + return _perf_log.info( "[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)", time.perf_counter() - _t0, @@ -1541,11 +1548,43 @@ async def stream_new_chat( user_id, llm_config_id, ) - yield streaming_service.format_error( - "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." - ) - yield streaming_service.format_done() - return + if requested_llm_config_id == 0: + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + force_repin_free=True, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield streaming_service.format_error(str(pin_error)) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield streaming_service.format_error(llm_load_error) + yield streaming_service.format_done() + return + _premium_request_id = None + _premium_reserved = 0 + logging.getLogger(__name__).info( + "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s", + chat_id, + search_space_id, + user_id, + llm_config_id, + ) + else: + yield streaming_service.format_error( + "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." 
+ ) + yield streaming_service.format_done() + return if not llm: yield streaming_service.format_error("Failed to create LLM instance") @@ -2183,6 +2222,38 @@ async def stream_resume_chat( await set_ai_responding(session, chat_id, UUID(user_id)) agent_config: AgentConfig | None = None + requested_llm_config_id = llm_config_id + + async def _load_llm_bundle( + config_id: int, + ) -> tuple[Any, AgentConfig | None, str | None]: + if config_id >= 0: + loaded_agent_config = await load_agent_config( + session=session, + config_id=config_id, + search_space_id=search_space_id, + ) + if not loaded_agent_config: + return ( + None, + None, + f"Failed to load NewLLMConfig with id {config_id}", + ) + return ( + create_chat_litellm_from_agent_config(loaded_agent_config), + loaded_agent_config, + None, + ) + + loaded_llm_config = load_global_llm_config_by_id(config_id) + if not loaded_llm_config: + return None, None, f"Failed to load LLM config with id {config_id}" + return ( + create_chat_litellm_from_config(loaded_llm_config), + AgentConfig.from_yaml_config(loaded_llm_config), + None, + ) + _t0 = time.perf_counter() try: llm_config_id = ( @@ -2199,29 +2270,11 @@ async def stream_resume_chat( yield streaming_service.format_done() return - if llm_config_id >= 0: - agent_config = await load_agent_config( - session=session, - config_id=llm_config_id, - search_space_id=search_space_id, - ) - if not agent_config: - yield streaming_service.format_error( - f"Failed to load NewLLMConfig with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - llm = create_chat_litellm_from_agent_config(agent_config) - else: - llm_config = load_global_llm_config_by_id(llm_config_id) - if not llm_config: - yield streaming_service.format_error( - f"Failed to load LLM config with id {llm_config_id}" - ) - yield streaming_service.format_done() - return - llm = create_chat_litellm_from_config(llm_config) - agent_config = AgentConfig.from_yaml_config(llm_config) + llm, agent_config, 
llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield streaming_service.format_error(llm_load_error) + yield streaming_service.format_done() + return _perf_log.info( "[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0 ) @@ -2262,11 +2315,43 @@ async def stream_resume_chat( user_id, llm_config_id, ) - yield streaming_service.format_error( - "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." - ) - yield streaming_service.format_done() - return + if requested_llm_config_id == 0: + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + force_repin_free=True, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield streaming_service.format_error(str(pin_error)) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + yield streaming_service.format_error(llm_load_error) + yield streaming_service.format_done() + return + _resume_premium_request_id = None + _resume_premium_reserved = 0 + logging.getLogger(__name__).info( + "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s", + chat_id, + search_space_id, + user_id, + llm_config_id, + ) + else: + yield streaming_service.format_error( + "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." 
+ ) + yield streaming_service.format_done() + return if not llm: yield streaming_service.format_error("Failed to create LLM instance") diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index a9853c980..f08e50ba2 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -227,6 +227,44 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): assert result.from_existing_pin is True +@pytest.mark.asyncio +async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): + from app.config import config + + session = _FakeSession( + _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) + ) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, + {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + force_repin_free=True, + ) + assert result.resolved_llm_config_id == -2 + assert result.resolved_tier == "free" + assert result.from_existing_pin is False + assert session.thread.pinned_llm_config_id == -2 + + @pytest.mark.asyncio async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config From a68889511569caa045dec9420353b5e8a9a16647 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Wed, 29 Apr 2026 
08:03:39 -0700 Subject: [PATCH 10/68] feat: increase recursion limit for chat streaming to enhance tool iteration capabilities --- .../app/tasks/chat/stream_new_chat.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index f7bf75649..1493c4326 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2090,7 +2090,16 @@ async def stream_new_chat( config = { "configurable": configurable, - "recursion_limit": 80, # Increase from default 25 to allow more tool iterations + # Effectively uncapped, matching the agent-level + # ``with_config`` default in ``chat_deepagent.create_agent`` + # and the unbounded ``while(true)`` loop used by OpenCode's + # ``session/processor.ts``. Real circuit-breakers live in + # middleware: ``DoomLoopMiddleware`` (sliding-window tool + # signature check), plus ``enable_tool_call_limit`` / + # ``enable_model_call_limit`` when those flags are set. The + # original LangGraph default of 25 (and our previous 80 + # bump) hit users on legitimate multi-tool plans. + "recursion_limit": 10_000, } # Start the message stream @@ -2686,7 +2695,11 @@ async def stream_resume_chat( "request_id": request_id or "unknown", "turn_id": stream_result.turn_id, }, - "recursion_limit": 80, + # See ``stream_new_chat`` above for rationale: effectively + # uncapped to mirror the agent default and OpenCode's + # session loop. Doom-loop / call-limit middleware enforce + # the real ceiling. 
+ "recursion_limit": 10_000, } yield streaming_service.format_message_start() From fa6a09197ef51641a649d14b14dd68cc131fddbd Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:57:33 +0530 Subject: [PATCH 11/68] feat(chat): enhance error handling for premium quota exhaustion in chat messages --- .../new-chat/[[...chat_id]]/page.tsx | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 05621419d..ed0611ee9 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -201,6 +201,9 @@ const BASE_TOOLS_WITH_UI = new Set([ // "write_todos", // Disabled for now ]); +const PREMIUM_QUOTA_ASSISTANT_MESSAGE = + "I can’t continue with the current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; + function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { if (!(error instanceof Error)) return null; const normalized = error.message.toLowerCase(); @@ -992,7 +995,9 @@ export default function NewChatPage() { { type: "text", text: - premiumQuotaAlertMessage ?? + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", }, ], @@ -1288,6 +1293,16 @@ export default function NewChatPage() { threadId: resumeThreadId, message: premiumQuotaAlertMessage, }); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { + ...m, + content: [{ type: "text", text: PREMIUM_QUOTA_ASSISTANT_MESSAGE }], + } + : m + ) + ); } else { toast.error("Failed to resume. 
Please try again."); } @@ -1647,7 +1662,9 @@ export default function NewChatPage() { { type: "text", text: - premiumQuotaAlertMessage ?? + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", }, ], From 901de3368402d7545ea2572e617c063b357429a2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 21:05:21 +0530 Subject: [PATCH 12/68] feat(chat): enhance error formatting to include optional error codes for better frontend handling --- .../app/services/new_streaming_service.py | 10 +- .../app/tasks/chat/stream_new_chat.py | 6 +- .../new-chat/[[...chat_id]]/page.tsx | 137 +++++++++++------- surfsense_web/lib/chat/streaming-state.ts | 2 +- 4 files changed, 97 insertions(+), 58 deletions(-) diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 52a215997..3e24c1376 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -565,20 +565,24 @@ class VercelStreamingService: # Error Part # ========================================================================= - def format_error(self, error_text: str) -> str: + def format_error(self, error_text: str, error_code: str | None = None) -> str: """ Format an error message. 
Args: error_text: The error message text + error_code: Optional machine-readable error code for frontend branching Returns: str: SSE formatted error part Example output: - data: {"type":"error","errorText":"Something went wrong"} + data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"} """ - return self._format_sse({"type": "error", "errorText": error_text}) + payload: dict[str, str] = {"type": "error", "errorText": error_text} + if error_code: + payload["errorCode"] = error_code + return self._format_sse(payload) # ========================================================================= # Tool Parts diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index edc5aa763..060dd23c6 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1581,7 +1581,8 @@ async def stream_new_chat( ) else: yield streaming_service.format_error( - "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." + "Buy more tokens to continue with this model, or switch to a free model.", + error_code="PREMIUM_QUOTA_EXHAUSTED", ) yield streaming_service.format_done() return @@ -2348,7 +2349,8 @@ async def stream_resume_chat( ) else: yield streaming_service.format_error( - "Premium tokens exhausted. Buy more tokens to continue with this model, or switch to a free model." 
+ "Buy more tokens to continue with this model, or switch to a free model.", + error_code="PREMIUM_QUOTA_EXHAUSTED", ) yield streaming_service.format_done() return diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index ed0611ee9..f775e1f06 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -206,10 +206,15 @@ const PREMIUM_QUOTA_ASSISTANT_MESSAGE = function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { if (!(error instanceof Error)) return null; + const withCode = error as Error & { errorCode?: string }; + if (withCode.errorCode === "PREMIUM_QUOTA_EXHAUSTED") { + return error.message; + } const normalized = error.message.toLowerCase(); if ( !normalized.includes("premium tokens exhausted") && !normalized.includes("premium token quota exceeded") + && !normalized.includes("buy more tokens") ) { return null; } @@ -233,6 +238,50 @@ export default function NewChatPage() { } | null>(null); const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []); + const persistAssistantErrorMessage = useCallback( + async ({ + threadId, + assistantMsgId, + text, + }: { + threadId: number | null; + assistantMsgId: string; + text: string; + }) => { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { + ...m, + content: [{ type: "text", text }], + } + : m + ) + ); + + if (!threadId) return; + + // Persist only temporary assistant placeholders to avoid duplicate rows + // when the message already has a database-backed ID. 
+ if (!assistantMsgId.startsWith("msg-assistant-")) return; + + try { + const savedMessage = await appendMessage(threadId, { + role: "assistant", + content: [{ type: "text", text }], + }); + const newMsgId = `msg-${savedMessage.id}`; + tokenUsageStore.rename(assistantMsgId, newMsgId); + setMessages((prev) => + prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + ); + } catch (persistErr) { + console.error("Failed to persist assistant error message:", persistErr); + } + }, + [tokenUsageStore] + ); + // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); @@ -903,7 +952,9 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } @@ -985,26 +1036,14 @@ export default function NewChatPage() { } else { toast.error("Failed to get response. Please try again."); } - // Update assistant message with error - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [ - { - type: "text", - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? - "Sorry, there was an error. Please try again.", - }, - ], - } - : m - ) - ); + await persistAssistantErrorMessage({ + threadId: currentThreadId, + assistantMsgId, + text: + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. 
Please try again.", + }); } finally { setIsRunning(false); abortControllerRef.current = null; @@ -1028,6 +1067,7 @@ export default function NewChatPage() { setPendingUserImageUrls, toolsWithUI, setPremiumAlertForThread, + persistAssistantErrorMessage, ] ); @@ -1258,7 +1298,9 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } @@ -1293,19 +1335,17 @@ export default function NewChatPage() { threadId: resumeThreadId, message: premiumQuotaAlertMessage, }); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [{ type: "text", text: PREMIUM_QUOTA_ASSISTANT_MESSAGE }], - } - : m - ) - ); } else { toast.error("Failed to resume. Please try again."); } + await persistAssistantErrorMessage({ + threadId: resumeThreadId, + assistantMsgId, + text: + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", + }); } finally { setIsRunning(false); abortControllerRef.current = null; @@ -1318,6 +1358,7 @@ export default function NewChatPage() { tokenUsageStore, toolsWithUI, setPremiumAlertForThread, + persistAssistantErrorMessage, ] ); @@ -1589,7 +1630,9 @@ export default function NewChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } @@ -1653,25 +1696,14 @@ export default function NewChatPage() { } else { toast.error("Failed to regenerate response. Please try again."); } - setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId - ? { - ...m, - content: [ - { - type: "text", - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? - "Sorry, there was an error. 
Please try again.", - }, - ], - } - : m - ) - ); + await persistAssistantErrorMessage({ + threadId, + assistantMsgId, + text: + (premiumQuotaAlertMessage + ? PREMIUM_QUOTA_ASSISTANT_MESSAGE + : undefined) ?? "Sorry, there was an error. Please try again.", + }); } finally { setIsRunning(false); abortControllerRef.current = null; @@ -1685,6 +1717,7 @@ export default function NewChatPage() { tokenUsageStore, toolsWithUI, setPremiumAlertForThread, + persistAssistantErrorMessage, ] ); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index ff8fdfbd4..9f2ac87a5 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -256,7 +256,7 @@ export type SSEEvent = }>; }; } - | { type: "error"; errorText: string }; + | { type: "error"; errorText: string; errorCode?: string }; /** * Async generator that reads an SSE stream and yields parsed JSON objects. From e6db050dfd6ae9d0bdd63597d261caf7151d8720 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 21:58:17 +0530 Subject: [PATCH 13/68] feat(chat): add userId to premium alert handling and improve alert visibility in UI --- surfsense_backend/app/tasks/chat/stream_new_chat.py | 4 ++-- .../new-chat/[[...chat_id]]/page.tsx | 3 +++ surfsense_web/atoms/chat/premium-alert.atom.ts | 12 ++++++++++++ surfsense_web/components/assistant-ui/thread.tsx | 8 ++++---- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 060dd23c6..ecc727b47 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1581,7 +1581,7 @@ async def stream_new_chat( ) else: yield streaming_service.format_error( - "Buy more tokens to continue with this model, or switch to a free model.", + "Buy more tokens to continue with 
this model, or switch to a free model", error_code="PREMIUM_QUOTA_EXHAUSTED", ) yield streaming_service.format_done() @@ -2349,7 +2349,7 @@ async def stream_resume_chat( ) else: yield streaming_service.format_error( - "Buy more tokens to continue with this model, or switch to a free model.", + "Buy more tokens to continue with this model, or switch to a free model", error_code="PREMIUM_QUOTA_EXHAUSTED", ) yield streaming_service.format_done() diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index f775e1f06..6ec587f91 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -1032,6 +1032,7 @@ export default function NewChatPage() { setPremiumAlertForThread({ threadId: currentThreadId, message: premiumQuotaAlertMessage, + userId: currentUser?.id ?? null, }); } else { toast.error("Failed to get response. Please try again."); @@ -1334,6 +1335,7 @@ export default function NewChatPage() { setPremiumAlertForThread({ threadId: resumeThreadId, message: premiumQuotaAlertMessage, + userId: currentUser?.id ?? null, }); } else { toast.error("Failed to resume. Please try again."); @@ -1692,6 +1694,7 @@ export default function NewChatPage() { setPremiumAlertForThread({ threadId, message: premiumQuotaAlertMessage, + userId: currentUser?.id ?? null, }); } else { toast.error("Failed to regenerate response. 
Please try again."); diff --git a/surfsense_web/atoms/chat/premium-alert.atom.ts b/surfsense_web/atoms/chat/premium-alert.atom.ts index c0efc174f..1c837dd65 100644 --- a/surfsense_web/atoms/chat/premium-alert.atom.ts +++ b/surfsense_web/atoms/chat/premium-alert.atom.ts @@ -14,13 +14,25 @@ export const setPremiumAlertForThreadAtom = atom( payload: { threadId: number; message: string; + userId?: string | null; } ) => { + const storageKey = `surfsense-premium-alert-seen-v1:${payload.userId ?? "anonymous"}`; + + if (typeof window !== "undefined") { + const hasSeen = localStorage.getItem(storageKey) === "true"; + if (hasSeen) return; + } + const current = get(premiumAlertByThreadAtom); set(premiumAlertByThreadAtom, { ...current, [payload.threadId]: { message: payload.message }, }); + + if (typeof window !== "undefined") { + localStorage.setItem(storageKey, "true"); + } } ); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index cb063fac3..3095556dc 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -161,15 +161,15 @@ const PremiumQuotaPinnedAlert: FC = () => { if (!alert) return null; return ( - <div className="mx-0 bg-amber-500/10 px-3 py-2 text-amber-100"> - <div className="flex items-start gap-2"> - <AlertCircle className="mt-0.5 size-4 shrink-0 text-amber-300" /> + <div className="mx-0 overflow-hidden rounded-2xl border-input bg-muted px-4 py-4 text-foreground select-none"> + <div className="flex items-center gap-2"> + <AlertCircle className="size-4 shrink-0 text-muted-foreground" /> <div className="min-w-0 flex-1"> <p className="text-sm">{alert.message}</p> </div> <button type="button" - className="inline-flex size-6 items-center justify-center text-amber-200 transition-colors hover:text-amber-50" + className="inline-flex size-6 items-center justify-center text-muted-foreground transition-colors hover:text-foreground" 
aria-label="Dismiss premium quota alert" onClick={() => clearPremiumAlertForThread(currentThreadId)} > From 222b27183fd9603637df2c31459dc74cc988ade9 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 29 Apr 2026 22:01:28 +0530 Subject: [PATCH 14/68] feat(chat): improve error handling and logging for premium quota exhaustion in chat operations --- .../new-chat/[[...chat_id]]/page.tsx | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 6ec587f91..a2985ab0c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -1018,8 +1018,12 @@ export default function NewChatPage() { } return; } - console.error("[NewChatPage] Chat error:", error); const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); + if (premiumQuotaAlertMessage) { + console.info("[NewChatPage] Premium quota exhausted:", error); + } else { + console.error("[NewChatPage] Chat error:", error); + } // Track chat error trackChatError( @@ -1329,8 +1333,12 @@ export default function NewChatPage() { if (error instanceof Error && error.name === "AbortError") { return; } - console.error("[NewChatPage] Resume error:", error); const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); + if (premiumQuotaAlertMessage) { + console.info("[NewChatPage] Premium quota exhausted during resume:", error); + } else { + console.error("[NewChatPage] Resume error:", error); + } if (premiumQuotaAlertMessage) { setPremiumAlertForThread({ threadId: resumeThreadId, @@ -1357,6 +1365,7 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, + currentUser?.id, tokenUsageStore, toolsWithUI, 
setPremiumAlertForThread, @@ -1683,8 +1692,12 @@ export default function NewChatPage() { return; } batcher.dispose(); - console.error("[NewChatPage] Regeneration error:", error); const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); + if (premiumQuotaAlertMessage) { + console.info("[NewChatPage] Premium quota exhausted during regeneration:", error); + } else { + console.error("[NewChatPage] Regeneration error:", error); + } trackChatError( searchSpaceId, threadId, @@ -1717,6 +1730,7 @@ export default function NewChatPage() { searchSpaceId, messages, disabledTools, + currentUser?.id, tokenUsageStore, toolsWithUI, setPremiumAlertForThread, From d64543686fe6304f99eac9e62bbb86944895840f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 11:56:41 +0530 Subject: [PATCH 15/68] feat(chat): unify error handling and logging for chat operations, enhancing clarity and consistency in error reporting --- .../app/routes/new_chat_routes.py | 1 + .../app/tasks/chat/stream_new_chat.py | 319 +++++++++++++++--- .../unit/test_stream_new_chat_contract.py | 119 +++++++ .../new-chat/[[...chat_id]]/page.tsx | 240 ++++++------- .../lib/chat/chat-error-classifier.ts | 273 +++++++++++++++ surfsense_web/lib/posthog/events.ts | 50 +++ 6 files changed, 831 insertions(+), 171 deletions(-) create mode 100644 surfsense_web/lib/chat/chat-error-classifier.ts diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index b5560d90d..0189dd139 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1524,6 +1524,7 @@ async def regenerate_response( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), user_image_data_urls=regenerate_image_urls or None, + flow="regenerate", ): yield chunk streaming_completed = True diff --git 
a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index ecc727b47..a0be55c1b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -19,7 +19,7 @@ import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal from uuid import UUID import anyio @@ -253,6 +253,98 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: ) +def _log_chat_stream_error( + *, + flow: Literal["new", "resume", "regenerate"], + error_kind: str, + error_code: str | None, + severity: Literal["info", "warn", "error"], + is_expected: bool, + request_id: str | None, + thread_id: int | None, + search_space_id: int | None, + user_id: str | None, + message: str, + extra: dict[str, Any] | None = None, +) -> None: + payload: dict[str, Any] = { + "event": "chat_stream_error", + "flow": flow, + "error_kind": error_kind, + "error_code": error_code, + "severity": severity, + "is_expected": is_expected, + "request_id": request_id or "unknown", + "thread_id": thread_id, + "search_space_id": search_space_id, + "user_id": user_id, + "message": message, + } + if extra: + payload.update(extra) + + logger = logging.getLogger(__name__) + rendered = json.dumps(payload, ensure_ascii=False) + if severity == "error": + logger.error("[chat_stream_error] %s", rendered) + elif severity == "warn": + logger.warning("[chat_stream_error] %s", rendered) + else: + logger.info("[chat_stream_error] %s", rendered) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return 
None + + +def _classify_stream_exception( + exc: Exception, + *, + flow_label: str, +) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: + raw = str(exc) + parsed = _parse_error_payload(raw) + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + + if provider_error_type == "rate_limit_error": + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + ) + + return ( + "server_error", + "SERVER_ERROR", + "error", + False, + f"Error during {flow_label}: {raw}", + ) + + async def _stream_agent_events( agent: Any, config: dict[str, Any], @@ -1397,6 +1489,7 @@ async def stream_new_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, user_image_data_urls: list[str] | None = None, + flow: Literal["new", "regenerate"] = "new", ) -> AsyncGenerator[str, None]: """ Stream chat responses from the new SurfSense deep agent. 
@@ -1448,6 +1541,30 @@ async def stream_new_chat( _premium_reserved = 0 _premium_request_id: str | None = None + def _emit_stream_error( + *, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, + ) -> str: + _log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code) + session = async_session_maker() try: # Mark AI as responding to this user for live collaboration @@ -1499,13 +1616,21 @@ async def stream_new_chat( ) ).resolved_llm_config_id except ValueError as pin_error: - yield streaming_service.format_error(str(pin_error)) + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) if llm_load_error: - yield streaming_service.format_error(llm_load_error) + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return _perf_log.info( @@ -1541,13 +1666,6 @@ async def stream_new_chat( ) _premium_reserved = reserve_amount if not quota_result.allowed: - logging.getLogger(__name__).info( - "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s", - chat_id, - search_space_id, - user_id, - llm_config_id, - ) if requested_llm_config_id == 0: try: llm_config_id = ( @@ -1561,34 +1679,66 @@ async def stream_new_chat( ) ).resolved_llm_config_id except ValueError as pin_error: - yield 
streaming_service.format_error(str(pin_error)) + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) if llm_load_error: - yield streaming_service.format_error(llm_load_error) + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return _premium_request_id = None _premium_reserved = 0 - logging.getLogger(__name__).info( - "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s", - chat_id, - search_space_id, - user_id, - llm_config_id, + _log_chat_stream_error( + flow=flow, + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Premium quota exhausted on pinned model; auto-fallback switched to a free model" + ), + extra={ + "fallback_config_id": llm_config_id, + "auto_fallback": True, + }, ) else: - yield streaming_service.format_error( - "Buy more tokens to continue with this model, or switch to a free model", + yield _emit_stream_error( + message=( + "Buy more tokens to continue with this model, or switch to a free model" + ), + error_kind="premium_quota_exhausted", error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + extra={ + "resolved_config_id": llm_config_id, + "auto_fallback": False, + }, ) yield streaming_service.format_done() return if not llm: - yield streaming_service.format_error("Failed to create LLM instance") + yield _emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return @@ -2097,12 +2247,25 @@ async def 
stream_new_chat( # Handle any errors import traceback + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + ) = _classify_stream_exception(e, flow_label="chat") error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") - yield streaming_service.format_error(error_message) + yield _emit_stream_error( + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + ) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2217,6 +2380,30 @@ async def stream_resume_chat( accumulator = start_turn() + def _emit_stream_error( + *, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, + ) -> str: + _log_chat_stream_error( + flow="resume", + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code) + session = async_session_maker() try: if user_id: @@ -2267,13 +2454,21 @@ async def stream_resume_chat( ) ).resolved_llm_config_id except ValueError as pin_error: - yield streaming_service.format_error(str(pin_error)) + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) if llm_load_error: - yield streaming_service.format_error(llm_load_error) + yield _emit_stream_error( + message=llm_load_error, 
+ error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return _perf_log.info( @@ -2309,13 +2504,6 @@ async def stream_resume_chat( ) _resume_premium_reserved = reserve_amount if not quota_result.allowed: - logging.getLogger(__name__).info( - "premium_quota_blocked_pinned_model thread_id=%s search_space_id=%s user_id=%s resolved_config_id=%s", - chat_id, - search_space_id, - user_id, - llm_config_id, - ) if requested_llm_config_id == 0: try: llm_config_id = ( @@ -2329,34 +2517,66 @@ async def stream_resume_chat( ) ).resolved_llm_config_id except ValueError as pin_error: - yield streaming_service.format_error(str(pin_error)) + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) if llm_load_error: - yield streaming_service.format_error(llm_load_error) + yield _emit_stream_error( + message=llm_load_error, + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return _resume_premium_request_id = None _resume_premium_reserved = 0 - logging.getLogger(__name__).info( - "premium_quota_auto_fallback_to_free thread_id=%s search_space_id=%s user_id=%s fallback_config_id=%s", - chat_id, - search_space_id, - user_id, - llm_config_id, + _log_chat_stream_error( + flow="resume", + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Premium quota exhausted on pinned model; auto-fallback switched to a free model" + ), + extra={ + "fallback_config_id": llm_config_id, + "auto_fallback": True, + }, ) else: - yield streaming_service.format_error( - "Buy more tokens to continue with this model, or switch to a free model", + yield 
_emit_stream_error( + message=( + "Buy more tokens to continue with this model, or switch to a free model" + ), + error_kind="premium_quota_exhausted", error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + extra={ + "resolved_config_id": llm_config_id, + "auto_fallback": False, + }, ) yield streaming_service.format_done() return if not llm: - yield streaming_service.format_error("Failed to create LLM instance") + yield _emit_stream_error( + message="Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) yield streaming_service.format_done() return @@ -2528,10 +2748,23 @@ async def stream_resume_chat( except Exception as e: import traceback + ( + error_kind, + error_code, + severity, + is_expected, + user_message, + ) = _classify_stream_exception(e, flow_label="resume") error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") - yield streaming_service.format_error(error_message) + yield _emit_stream_error( + message=user_message, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + ) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 034aa484c..1f8168837 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -1,9 +1,18 @@ +import inspect +import json +import logging +from pathlib import Path +import re + import pytest +import app.tasks.chat.stream_new_chat as stream_new_chat_module from app.tasks.chat.stream_new_chat import ( StreamResult, + _classify_stream_exception, _contract_enforcement_active, _evaluate_file_contract_outcome, + 
_log_chat_stream_error, _tool_output_has_error, ) @@ -45,3 +54,113 @@ def test_contract_enforcement_local_only(): result.filesystem_mode = "cloud" assert not _contract_enforcement_active(result) + + +def _extract_chat_stream_payload(record_message: str) -> dict: + prefix = "[chat_stream_error] " + assert record_message.startswith(prefix) + return json.loads(record_message[len(prefix) :]) + + +def test_unified_chat_stream_error_log_schema(caplog): + with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): + _log_chat_stream_error( + flow="new", + error_kind="server_error", + error_code="SERVER_ERROR", + severity="warn", + is_expected=False, + request_id="req-123", + thread_id=101, + search_space_id=202, + user_id="user-1", + message="Error during chat: boom", + ) + + record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) + payload = _extract_chat_stream_payload(record.message) + + required_keys = { + "event", + "flow", + "error_kind", + "error_code", + "severity", + "is_expected", + "request_id", + "thread_id", + "search_space_id", + "user_id", + "message", + } + assert required_keys.issubset(payload.keys()) + assert payload["event"] == "chat_stream_error" + assert payload["flow"] == "new" + assert payload["error_code"] == "SERVER_ERROR" + + +def test_premium_quota_uses_unified_chat_stream_log_shape(caplog): + with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"): + _log_chat_stream_error( + flow="resume", + error_kind="premium_quota_exhausted", + error_code="PREMIUM_QUOTA_EXHAUSTED", + severity="info", + is_expected=True, + request_id="req-premium", + thread_id=303, + search_space_id=404, + user_id="user-2", + message="Buy more tokens to continue with this model, or switch to a free model", + extra={"auto_fallback": False}, + ) + + record = next(r for r in caplog.records if "[chat_stream_error]" in r.message) + payload = _extract_chat_stream_payload(record.message) + assert payload["event"] == 
"chat_stream_error" + assert payload["error_kind"] == "premium_quota_exhausted" + assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED" + assert payload["flow"] == "resume" + assert payload["is_expected"] is True + assert payload["auto_fallback"] is False + + +def test_stream_error_emission_keeps_machine_error_codes(): + source = inspect.getsource(stream_new_chat_module) + format_error_calls = re.findall(r"format_error\(", source) + emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source)) + + # Both new/resume stream paths now route through local emitters that always + # pass a machine-readable error_code. + assert len(format_error_calls) == 2 + assert { + "PREMIUM_QUOTA_EXHAUSTED", + "SERVER_ERROR", + }.issubset(emitted_error_codes) + assert 'flow: Literal["new", "regenerate"] = "new"' in source + assert "flow=flow" in source + assert 'flow="resume"' in source + + +def test_stream_exception_classifies_rate_limited(): + exc = Exception( + '{"error":{"type":"rate_limit_error","message":"Rate limited. 
Please try again later."}}' + ) + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + + +def test_premium_classification_is_error_code_driven(): + classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + source = classifier_path.read_text(encoding="utf-8") + + assert "PREMIUM_KEYWORDS" not in source + assert "RATE_LIMIT_KEYWORDS" not in source + assert "normalized.includes(" not in source + assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index a2985ab0c..ffd58e660 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -49,6 +49,10 @@ import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; +import { + classifyChatError, + type ChatFlow, +} from "@/lib/chat/chat-error-classifier"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, @@ -84,7 +88,8 @@ import { import { NotFoundError } from "@/lib/error"; import { trackChatCreated, - trackChatError, + trackChatBlocked, + trackChatErrorDetailed, trackChatMessageSent, trackChatResponseReceived, } from "@/lib/posthog/events"; @@ -201,26 +206,6 @@ const BASE_TOOLS_WITH_UI = new Set([ // "write_todos", // Disabled for now ]); -const PREMIUM_QUOTA_ASSISTANT_MESSAGE = - "I can’t continue with the 
current premium model because your premium tokens are exhausted. Switch to a free model or buy more tokens to continue."; - -function getPinnedPremiumQuotaErrorMessage(error: unknown): string | null { - if (!(error instanceof Error)) return null; - const withCode = error as Error & { errorCode?: string }; - if (withCode.errorCode === "PREMIUM_QUOTA_EXHAUSTED") { - return error.message; - } - const normalized = error.message.toLowerCase(); - if ( - !normalized.includes("premium tokens exhausted") - && !normalized.includes("premium token quota exceeded") - && !normalized.includes("buy more tokens") - ) { - return null; - } - return error.message; -} - export default function NewChatPage() { const params = useParams(); const queryClient = useQueryClient(); @@ -378,6 +363,81 @@ export default function NewChatPage() { return Number.isNaN(parsed) ? 0 : parsed; }, [params.chat_id]); + const handleChatFailure = useCallback( + async ({ + error, + flow, + threadId, + assistantMsgId, + }: { + error: unknown; + flow: ChatFlow; + threadId: number | null; + assistantMsgId: string; + }) => { + const normalized = classifyChatError({ + error, + flow, + context: { + searchSpaceId, + threadId, + }, + }); + + const logger = + normalized.severity === "error" + ? console.error + : normalized.severity === "warn" + ? 
console.warn + : console.info; + logger(`[NewChatPage] ${flow} ${normalized.kind}:`, error); + + const telemetryPayload = { + flow, + kind: normalized.kind, + error_code: normalized.errorCode, + severity: normalized.severity, + is_expected: normalized.isExpected, + message: normalized.userMessage, + }; + if (normalized.telemetryEvent === "chat_blocked") { + trackChatBlocked(searchSpaceId, threadId, telemetryPayload); + } else { + trackChatErrorDetailed(searchSpaceId, threadId, telemetryPayload); + } + + if (normalized.channel === "silent") { + return; + } + + if (normalized.channel === "pinned_inline") { + if (threadId) { + setPremiumAlertForThread({ + threadId, + message: normalized.userMessage, + userId: currentUser?.id ?? null, + }); + } + if (normalized.assistantMessage) { + await persistAssistantErrorMessage({ + threadId, + assistantMsgId, + text: normalized.assistantMessage, + }); + } + return; + } + + toast.error(normalized.userMessage); + }, + [ + currentUser?.id, + persistAssistantErrorMessage, + searchSpaceId, + setPremiumAlertForThread, + ] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -1018,36 +1078,11 @@ export default function NewChatPage() { } return; } - const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); - if (premiumQuotaAlertMessage) { - console.info("[NewChatPage] Premium quota exhausted:", error); - } else { - console.error("[NewChatPage] Chat error:", error); - } - - // Track chat error - trackChatError( - searchSpaceId, - currentThreadId, - error instanceof Error ? error.message : "Unknown error" - ); - - if (premiumQuotaAlertMessage) { - setPremiumAlertForThread({ - threadId: currentThreadId, - message: premiumQuotaAlertMessage, - userId: currentUser?.id ?? null, - }); - } else { - toast.error("Failed to get response. 
Please try again."); - } - await persistAssistantErrorMessage({ + await handleChatFailure({ + error, + flow: "new", threadId: currentThreadId, assistantMsgId, - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? "Sorry, there was an error. Please try again.", }); } finally { setIsRunning(false); @@ -1071,8 +1106,7 @@ export default function NewChatPage() { pendingUserImageUrls, setPendingUserImageUrls, toolsWithUI, - setPremiumAlertForThread, - persistAssistantErrorMessage, + handleChatFailure, ] ); @@ -1333,28 +1367,11 @@ export default function NewChatPage() { if (error instanceof Error && error.name === "AbortError") { return; } - const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); - if (premiumQuotaAlertMessage) { - console.info("[NewChatPage] Premium quota exhausted during resume:", error); - } else { - console.error("[NewChatPage] Resume error:", error); - } - if (premiumQuotaAlertMessage) { - setPremiumAlertForThread({ - threadId: resumeThreadId, - message: premiumQuotaAlertMessage, - userId: currentUser?.id ?? null, - }); - } else { - toast.error("Failed to resume. Please try again."); - } - await persistAssistantErrorMessage({ + await handleChatFailure({ + error, + flow: "resume", threadId: resumeThreadId, assistantMsgId, - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? "Sorry, there was an error. 
Please try again.", }); } finally { setIsRunning(false); @@ -1365,11 +1382,9 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, - currentUser?.id, tokenUsageStore, toolsWithUI, - setPremiumAlertForThread, - persistAssistantErrorMessage, + handleChatFailure, ] ); @@ -1491,15 +1506,6 @@ export default function NewChatPage() { userQueryToDisplay = newUserQuery; } - // Remove the last two messages (user + assistant) from the UI immediately - // The backend will also delete them from the database - setMessages((prev) => { - if (prev.length >= 2) { - return prev.slice(0, -2); - } - return prev; - }); - // Start streaming setIsRunning(true); const controller = new AbortController(); @@ -1530,19 +1536,9 @@ export default function NewChatPage() { createdAt: new Date(), metadata: isEdit ? undefined : originalUserMessageMetadata, }; - setMessages((prev) => [...prev, userMessage]); - - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); - + const userContentToPersist = isEdit + ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) + : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; try { const selection = await getAgentFilesystemSelection(searchSpaceId); const requestBody: Record<string, unknown> = { @@ -1570,6 +1566,22 @@ export default function NewChatPage() { throw new Error(`Backend error: ${response.status}`); } + // Only switch UI to regenerated placeholder messages after the backend accepts + // regenerate. This avoids local message loss when regenerate fails early (e.g. 400). + setMessages((prev) => { + const base = prev.length >= 2 ? 
prev.slice(0, -2) : prev; + return [ + ...base, + userMessage, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]; + }); + const flushMessages = () => { setMessages((prev) => prev.map((m) => @@ -1654,10 +1666,6 @@ export default function NewChatPage() { if (contentParts.length > 0) { try { // Persist user message (for both edit and reload modes, since backend deleted it) - const userContentToPersist = isEdit - ? (editExtras?.userMessageContent ?? [{ type: "text", text: newUserQuery ?? "" }]) - : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; - const savedUserMessage = await appendMessage(threadId, { role: "user", content: userContentToPersist, @@ -1692,33 +1700,11 @@ export default function NewChatPage() { return; } batcher.dispose(); - const premiumQuotaAlertMessage = getPinnedPremiumQuotaErrorMessage(error); - if (premiumQuotaAlertMessage) { - console.info("[NewChatPage] Premium quota exhausted during regeneration:", error); - } else { - console.error("[NewChatPage] Regeneration error:", error); - } - trackChatError( - searchSpaceId, - threadId, - error instanceof Error ? error.message : "Unknown error" - ); - if (premiumQuotaAlertMessage) { - setPremiumAlertForThread({ - threadId, - message: premiumQuotaAlertMessage, - userId: currentUser?.id ?? null, - }); - } else { - toast.error("Failed to regenerate response. Please try again."); - } - await persistAssistantErrorMessage({ + await handleChatFailure({ + error, + flow: "regenerate", threadId, assistantMsgId, - text: - (premiumQuotaAlertMessage - ? PREMIUM_QUOTA_ASSISTANT_MESSAGE - : undefined) ?? "Sorry, there was an error. 
Please try again.", }); } finally { setIsRunning(false); @@ -1730,11 +1716,9 @@ export default function NewChatPage() { searchSpaceId, messages, disabledTools, - currentUser?.id, tokenUsageStore, toolsWithUI, - setPremiumAlertForThread, - persistAssistantErrorMessage, + handleChatFailure, ] ); diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts new file mode 100644 index 000000000..dc9bb09df --- /dev/null +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -0,0 +1,273 @@ +export type ChatFlow = "new" | "resume" | "regenerate"; + +export type ChatErrorKind = + | "premium_quota_exhausted" + | "auth_expired" + | "rate_limited" + | "network_offline" + | "stream_interrupted" + | "stream_parse_error" + | "tool_execution_error" + | "persist_message_failed" + | "server_error" + | "unknown"; + +export type ChatErrorChannel = "pinned_inline" | "toast" | "silent"; +export type ChatTelemetryEvent = "chat_blocked" | "chat_error"; +export type ChatErrorSeverity = "info" | "warn" | "error"; + +export interface NormalizedChatError { + kind: ChatErrorKind; + channel: ChatErrorChannel; + severity: ChatErrorSeverity; + telemetryEvent: ChatTelemetryEvent; + isExpected: boolean; + userMessage: string; + assistantMessage?: string; + rawMessage?: string; + errorCode?: string; + details?: Record<string, unknown>; +} + +export interface RawChatErrorInput { + error: unknown; + flow: ChatFlow; + context?: { + searchSpaceId?: number; + threadId?: number | null; + }; +} + +export const PREMIUM_QUOTA_ASSISTANT_MESSAGE = + "I can’t continue with the current premium model because your premium tokens are exhausted. 
Switch to a free model or buy more tokens to continue."; + +function getErrorMessage(error: unknown): string { + if (error instanceof Error) return error.message; + if (typeof error === "string") return error; + try { + return JSON.stringify(error); + } catch { + return "Unknown error"; + } +} + +function getErrorCode(error: unknown, parsedJson: Record<string, unknown> | null): string | undefined { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string }; + if (withCode.errorCode) return withCode.errorCode; + } + + if (typeof error === "object" && error !== null) { + const withCode = error as { errorCode?: unknown }; + if (typeof withCode.errorCode === "string" && withCode.errorCode) { + return withCode.errorCode; + } + } + + if (parsedJson) { + const topLevelCode = parsedJson.errorCode; + if (typeof topLevelCode === "string" && topLevelCode) { + return topLevelCode; + } + } + + return undefined; +} + +function parseEmbeddedJson(text: string): Record<string, unknown> | null { + const candidates = [text]; + const firstBraceIdx = text.indexOf("{"); + if (firstBraceIdx >= 0) { + candidates.push(text.slice(firstBraceIdx)); + } + for (const candidate of candidates) { + try { + const parsed = JSON.parse(candidate); + if (typeof parsed === "object" && parsed !== null) { + return parsed as Record<string, unknown>; + } + } catch { + // noop + } + } + return null; +} + +function inferProviderErrorType(parsedJson: Record<string, unknown> | null): string | undefined { + if (!parsedJson) return undefined; + const topLevelType = parsedJson.type; + if (typeof topLevelType === "string" && topLevelType) return topLevelType; + const nestedError = parsedJson.error; + if (typeof nestedError === "object" && nestedError !== null) { + const nestedType = (nestedError as Record<string, unknown>).type; + if (typeof nestedType === "string" && nestedType) return nestedType; + } + return undefined; +} + +export function classifyChatError(input: 
RawChatErrorInput): NormalizedChatError { + const { error } = input; + const rawMessage = getErrorMessage(error); + const parsedJson = parseEmbeddedJson(rawMessage); + const errorCode = getErrorCode(error, parsedJson); + const providerErrorType = inferProviderErrorType(parsedJson); + const providerTypeNormalized = providerErrorType?.toLowerCase() ?? ""; + const errorName = error instanceof Error ? error.name : undefined; + + if (errorName === "AbortError") { + return { + kind: "stream_interrupted", + channel: "silent", + severity: "info", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Request canceled.", + rawMessage, + errorCode, + details: { flow: input.flow }, + }; + } + + if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") { + return { + kind: "premium_quota_exhausted", + channel: "pinned_inline", + severity: "info", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "Buy more tokens to continue with this model, or switch to a free model.", + assistantMessage: PREMIUM_QUOTA_ASSISTANT_MESSAGE, + rawMessage, + errorCode: errorCode ?? "PREMIUM_QUOTA_EXHAUSTED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "AUTH_EXPIRED" || + errorCode === "UNAUTHORIZED" + ) { + return { + kind: "auth_expired", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Your session expired. Please sign in again.", + rawMessage, + errorCode: errorCode ?? "AUTH_EXPIRED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "RATE_LIMITED" || + providerTypeNormalized === "rate_limit_error" + ) { + return { + kind: "rate_limited", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + rawMessage, + errorCode: errorCode ?? 
"RATE_LIMITED", + details: { flow: input.flow, providerErrorType }, + }; + } + + if ( + errorCode === "NETWORK_ERROR" + ) { + return { + kind: "network_offline", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_error", + isExpected: true, + userMessage: "Connection issue detected. Check your internet and try again.", + rawMessage, + errorCode: errorCode ?? "NETWORK_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "STREAM_PARSE_ERROR" + ) { + return { + kind: "stream_parse_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We hit a response formatting issue. Please try again.", + rawMessage, + errorCode: errorCode ?? "STREAM_PARSE_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "TOOL_EXECUTION_ERROR" + ) { + return { + kind: "tool_execution_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "A tool failed while processing your request. Please try again.", + rawMessage, + errorCode: errorCode ?? "TOOL_EXECUTION_ERROR", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "PERSIST_MESSAGE_FAILED" + ) { + return { + kind: "persist_message_failed", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "Response generated, but saving failed. Please retry once.", + rawMessage, + errorCode: errorCode ?? "PERSIST_MESSAGE_FAILED", + details: { flow: input.flow }, + }; + } + + if ( + errorCode === "SERVER_ERROR" + ) { + return { + kind: "server_error", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We couldn’t complete this response right now. Please try again.", + rawMessage, + errorCode: errorCode ?? 
"SERVER_ERROR", + details: { flow: input.flow, providerErrorType }, + }; + } + + return { + kind: "unknown", + channel: "toast", + severity: "error", + telemetryEvent: "chat_error", + isExpected: false, + userMessage: "We couldn’t complete this response right now. Please try again.", + rawMessage, + errorCode, + details: { flow: input.flow, providerErrorType }, + }; +} diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 34ed3044d..30e58215a 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -1,5 +1,6 @@ import posthog from "posthog-js"; import { getConnectorTelemetryMeta } from "@/components/assistant-ui/connector-popup/constants/connector-constants"; +import type { ChatErrorKind, ChatFlow, ChatErrorSeverity } from "@/lib/chat/chat-error-classifier"; /** * PostHog Analytics Event Definitions @@ -139,6 +140,55 @@ export function trackChatError(searchSpaceId: number, chatId: number, error?: st }); } +export interface ChatFailureTelemetry { + flow: ChatFlow; + kind: ChatErrorKind; + error_code?: string; + severity: ChatErrorSeverity; + is_expected: boolean; + message?: string; +} + +export function trackChatBlocked( + searchSpaceId: number, + chatId: number | null, + payload: ChatFailureTelemetry +) { + safeCapture( + "chat_blocked", + compact({ + search_space_id: searchSpaceId, + chat_id: chatId ?? undefined, + flow: payload.flow, + kind: payload.kind, + error_code: payload.error_code, + severity: payload.severity, + is_expected: payload.is_expected, + message: payload.message, + }) + ); +} + +export function trackChatErrorDetailed( + searchSpaceId: number, + chatId: number | null, + payload: ChatFailureTelemetry +) { + safeCapture( + "chat_error", + compact({ + search_space_id: searchSpaceId, + chat_id: chatId ?? 
undefined, + flow: payload.flow, + kind: payload.kind, + error_code: payload.error_code, + severity: payload.severity, + is_expected: payload.is_expected, + message: payload.message, + }) + ); +} + /** * Track a message sent from the unauthenticated "free" / anonymous chat * flow. This is intentionally a separate event from `chat_message_sent` From fd4d0817d14939f0c2c9421dabc6b83213d7a17f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 12:38:11 +0530 Subject: [PATCH 16/68] feat(chat): implement comprehensive error handling for chat operations, including detailed response parsing and improved user message persistence --- .../new-chat/[[...chat_id]]/page.tsx | 262 ++++++++++++------ 1 file changed, 180 insertions(+), 82 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index ffd58e660..b6afaf131 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -222,6 +222,7 @@ export default function NewChatPage() { interruptData: Record<string, unknown>; } | null>(null); const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []); + const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const persistAssistantErrorMessage = useCallback( async ({ @@ -267,14 +268,107 @@ export default function NewChatPage() { [tokenUsageStore] ); + const persistUserTurn = useCallback( + async ({ + threadId, + userMsgId, + content, + mentionedDocs, + logContext, + }: { + threadId: number | null; + userMsgId: string; + content: unknown; + mentionedDocs?: MentionedDocumentInfo[]; + logContext: string; + }) => { + if (!threadId) return null; + try { + const normalizedContent = Array.isArray(content) + ? 
([...content] as unknown[]) + : [content]; + const hasMentionedDocumentsPart = normalizedContent.some((part) => + MentionedDocumentsPartSchema.safeParse(part).success + ); + if (mentionedDocs && mentionedDocs.length > 0 && !hasMentionedDocumentsPart) { + normalizedContent.push({ + type: "mentioned-documents", + documents: mentionedDocs, + }); + } + + const savedUserMessage = await appendMessage(threadId, { + role: "user", + content: normalizedContent as AppendMessage["content"], + }); + const newUserMsgId = `msg-${savedUserMessage.id}`; + setMessages((prev) => + prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) + ); + if (mentionedDocs && mentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => { + const { [userMsgId]: _, ...rest } = prev; + return { + ...rest, + [newUserMsgId]: mentionedDocs, + }; + }); + } + return newUserMsgId; + } catch (err) { + console.error(`Failed to persist ${logContext} user message:`, err); + return null; + } + }, + [setMessageDocumentsMap] + ); + + const persistAssistantTurn = useCallback( + async ({ + threadId, + assistantMsgId, + content, + tokenUsage, + logContext, + onRemapped, + }: { + threadId: number | null; + assistantMsgId: string; + content: unknown; + tokenUsage?: Record<string, unknown>; + logContext: string; + onRemapped?: (newMsgId: string) => void; + }) => { + if (!threadId) return null; + try { + const savedMessage = await appendMessage(threadId, { + role: "assistant", + content: content as AppendMessage["content"], + token_usage: tokenUsage, + }); + const newMsgId = `msg-${savedMessage.id}`; + tokenUsageStore.rename(assistantMsgId, newMsgId); + setMessages((prev) => + prev.map((m) => (m.id === assistantMsgId ? 
{ ...m, id: newMsgId } : m)) + ); + onRemapped?.(newMsgId); + return newMsgId; + } catch (err) { + console.error(`Failed to persist ${logContext} assistant message:`, err); + return null; + } + }, + [tokenUsageStore] + ); + // Get disabled tools from the tool toggle UI const disabledTools = useAtomValue(disabledToolsAtom); // Get mentioned document IDs from the composer. const mentionedDocumentIds = useAtomValue(mentionedDocumentIdsAtom); const mentionedDocuments = useAtomValue(mentionedDocumentsAtom); + const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const setMentionedDocuments = useSetAtom(mentionedDocumentsAtom); - const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const setCurrentThreadState = useSetAtom(currentThreadAtom); const setPremiumAlertForThread = useSetAtom(setPremiumAlertForThreadAtom); const setTargetCommentId = useSetAtom(setTargetCommentIdAtom); @@ -1023,29 +1117,20 @@ export default function NewChatPage() { // Skip persistence for interrupted messages -- handleResume will persist the final version const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0 && !wasInterrupted) { - try { - const savedMessage = await appendMessage(currentThreadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - }); - - // Update message ID from temporary to database ID so comments work immediately - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) - ); - - // Update pending interrupt with the new persisted message ID - setPendingInterrupt((prev) => - prev && prev.assistantMsgId === assistantMsgId - ? 
{ ...prev, assistantMsgId: newMsgId } - : prev - ); - } catch (err) { - console.error("Failed to persist assistant message:", err); - } + await persistAssistantTurn({ + threadId: currentThreadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? undefined, + logContext: "new chat", + onRemapped: (newMsgId) => { + setPendingInterrupt((prev) => + prev && prev.assistantMsgId === assistantMsgId + ? { ...prev, assistantMsgId: newMsgId } + : prev + ); + }, + }); // Track successful response trackChatResponseReceived(searchSpaceId, currentThreadId); @@ -1061,20 +1146,12 @@ export default function NewChatPage() { ); if (hasContent && currentThreadId) { const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - try { - const savedMessage = await appendMessage(currentThreadId, { - role: "assistant", - content: partialContent, - }); - - // Update message ID from temporary to database ID - const newMsgId = `msg-${savedMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) - ); - } catch (err) { - console.error("Failed to persist partial assistant message:", err); - } + await persistAssistantTurn({ + threadId: currentThreadId, + assistantMsgId, + content: partialContent, + logContext: "partial new chat", + }); } return; } @@ -1107,6 +1184,7 @@ export default function NewChatPage() { setPendingUserImageUrls, toolsWithUI, handleChatFailure, + persistAssistantTurn, ] ); @@ -1347,20 +1425,13 @@ export default function NewChatPage() { const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0) { - try { - const savedMessage = await appendMessage(resumeThreadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - }); - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? 
{ ...m, id: newMsgId } : m)) - ); - } catch (err) { - console.error("Failed to persist resumed assistant message:", err); - } + await persistAssistantTurn({ + threadId: resumeThreadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? undefined, + logContext: "resumed chat", + }); } } catch (error) { batcher.dispose(); @@ -1385,6 +1456,7 @@ export default function NewChatPage() { tokenUsageStore, toolsWithUI, handleChatFailure, + persistAssistantTurn, ] ); @@ -1462,6 +1534,7 @@ export default function NewChatPage() { editExtras?: { userMessageContent: ThreadMessageLike["content"]; userImages: NewChatUserImagePayload[]; + sourceUserMessageId?: string; } ) => { if (!threadId) { @@ -1487,11 +1560,13 @@ export default function NewChatPage() { let userQueryToDisplay: string | undefined; let originalUserMessageContent: ThreadMessageLike["content"] | null = null; let originalUserMessageMetadata: ThreadMessageLike["metadata"] | undefined; + let sourceUserMessageId: string | undefined = editExtras?.sourceUserMessageId; if (!isEdit) { // Reload mode - find and preserve the last user message content const lastUserMessage = [...messages].reverse().find((m) => m.role === "user"); if (lastUserMessage) { + sourceUserMessageId = lastUserMessage.id; originalUserMessageContent = lastUserMessage.content; originalUserMessageMetadata = lastUserMessage.metadata; // Extract text for the API request @@ -1524,6 +1599,8 @@ export default function NewChatPage() { const { contentParts, toolCallIndices } = contentPartsState; const batcher = new FrameBatchedUpdater(); let tokenUsageData: Record<string, unknown> | null = null; + let regenerateAccepted = false; + let userPersisted = false; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -1539,6 +1616,10 @@ export default function NewChatPage() { const userContentToPersist = isEdit ? (editExtras?.userMessageContent ?? 
[{ type: "text", text: newUserQuery ?? "" }]) : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; + const sourceMentionedDocs = + sourceUserMessageId && messageDocumentsMap[sourceUserMessageId] + ? messageDocumentsMap[sourceUserMessageId] + : []; try { const selection = await getAgentFilesystemSelection(searchSpaceId); const requestBody: Record<string, unknown> = { @@ -1565,6 +1646,7 @@ export default function NewChatPage() { if (!response.ok) { throw new Error(`Backend error: ${response.status}`); } + regenerateAccepted = true; // Only switch UI to regenerated placeholder messages after the backend accepts // regenerate. This avoids local message loss when regenerate fails early (e.g. 400). @@ -1581,6 +1663,12 @@ export default function NewChatPage() { }, ]; }); + if (sourceMentionedDocs.length > 0) { + setMessageDocumentsMap((prev) => ({ + ...prev, + [userMsgId]: sourceMentionedDocs, + })); + } const flushMessages = () => { setMessages((prev) => @@ -1664,47 +1752,45 @@ export default function NewChatPage() { // Persist messages after streaming completes const finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0) { - try { - // Persist user message (for both edit and reload modes, since backend deleted it) - const savedUserMessage = await appendMessage(threadId, { - role: "user", - content: userContentToPersist, - }); + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated", + }); + userPersisted = Boolean(persistedUserMsgId); - // Update user message ID to database ID - const newUserMsgId = `msg-${savedUserMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) - ); + await persistAssistantTurn({ + threadId, + assistantMsgId, + content: finalContent, + tokenUsage: tokenUsageData ?? 
undefined, + logContext: "regenerated", + }); - // Persist assistant message - const savedMessage = await appendMessage(threadId, { - role: "assistant", - content: finalContent, - token_usage: tokenUsageData ?? undefined, - }); - - const newMsgId = `msg-${savedMessage.id}`; - tokenUsageStore.rename(assistantMsgId, newMsgId); - setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) - ); - - trackChatResponseReceived(searchSpaceId, threadId); - } catch (err) { - console.error("Failed to persist regenerated message:", err); - } + trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { if (error instanceof Error && error.name === "AbortError") { return; } batcher.dispose(); + if (regenerateAccepted && !userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + } await handleChatFailure({ error, flow: "regenerate", threadId, - assistantMsgId, + assistantMsgId: regenerateAccepted ? assistantMsgId : "no-persist-assistant", }); } finally { setIsRunning(false); @@ -1716,9 +1802,13 @@ export default function NewChatPage() { searchSpaceId, messages, disabledTools, + messageDocumentsMap, + setMessageDocumentsMap, tokenUsageStore, toolsWithUI, handleChatFailure, + persistAssistantTurn, + persistUserTurn, ] ); @@ -1733,7 +1823,15 @@ export default function NewChatPage() { } const userMessageContent = message.content as unknown as ThreadMessageLike["content"]; - await handleRegenerate(queryForApi, { userMessageContent, userImages }); + const sourceUserMessageId = + typeof (message as { id?: unknown }).id === "string" + ? ((message as { id?: string }).id ?? 
undefined) + : undefined; + await handleRegenerate(queryForApi, { + userMessageContent, + userImages, + sourceUserMessageId, + }); }, [handleRegenerate] ); From 35ea0eae53a24875e368111f294b366f48f2d9fa Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:03:09 +0530 Subject: [PATCH 17/68] feat(chat): enhance error classification and handling for thread busy scenarios, improving user feedback and response management --- .../app/tasks/chat/stream_new_chat.py | 106 +++++---- .../unit/test_stream_new_chat_contract.py | 33 ++- .../new-chat/[[...chat_id]]/page.tsx | 209 +++++++++++++----- .../components/free-chat/anonymous-chat.tsx | 16 +- .../components/free-chat/free-chat-page.tsx | 52 ++++- .../lib/chat/chat-error-classifier.ts | 17 ++ 6 files changed, 322 insertions(+), 111 deletions(-) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index a0be55c1b..d6ca5418c 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -19,6 +19,7 @@ import re import time from collections.abc import AsyncGenerator from dataclasses import dataclass, field +from functools import partial from typing import Any, Literal from uuid import UUID @@ -30,6 +31,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer +from app.agents.new_chat.errors import BusyError from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, @@ -315,6 +317,15 @@ def _classify_stream_exception( flow_label: str, ) -> tuple[str, str, Literal["info", "warn", "error"], bool, str]: raw = str(exc) + if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + return ( + "thread_busy", + 
"THREAD_BUSY", + "warn", + True, + "Another response is still finishing for this thread. Please try again in a moment.", + ) + parsed = _parse_error_payload(raw) provider_error_type = "" if parsed: @@ -345,6 +356,37 @@ def _classify_stream_exception( ) +def _emit_stream_terminal_error( + *, + streaming_service: VercelStreamingService, + flow: str, + request_id: str | None, + thread_id: int, + search_space_id: int, + user_id: str | None, + message: str, + error_kind: str = "server_error", + error_code: str = "SERVER_ERROR", + severity: Literal["info", "warn", "error"] = "error", + is_expected: bool = False, + extra: dict[str, Any] | None = None, +) -> str: + _log_chat_stream_error( + flow=flow, + error_kind=error_kind, + error_code=error_code, + severity=severity, + is_expected=is_expected, + request_id=request_id, + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + message=message, + extra=extra, + ) + return streaming_service.format_error(message, error_code=error_code) + + async def _stream_agent_events( agent: Any, config: dict[str, Any], @@ -1541,29 +1583,15 @@ async def stream_new_chat( _premium_reserved = 0 _premium_request_id: str | None = None - def _emit_stream_error( - *, - message: str, - error_kind: str = "server_error", - error_code: str = "SERVER_ERROR", - severity: Literal["info", "warn", "error"] = "error", - is_expected: bool = False, - extra: dict[str, Any] | None = None, - ) -> str: - _log_chat_stream_error( - flow=flow, - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=message, - extra=extra, - ) - return streaming_service.format_error(message, error_code=error_code) + _emit_stream_error = partial( + _emit_stream_terminal_error, + streaming_service=streaming_service, + flow=flow, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + 
user_id=user_id, + ) session = async_session_maker() try: @@ -2380,29 +2408,15 @@ async def stream_resume_chat( accumulator = start_turn() - def _emit_stream_error( - *, - message: str, - error_kind: str = "server_error", - error_code: str = "SERVER_ERROR", - severity: Literal["info", "warn", "error"] = "error", - is_expected: bool = False, - extra: dict[str, Any] | None = None, - ) -> str: - _log_chat_stream_error( - flow="resume", - error_kind=error_kind, - error_code=error_code, - severity=severity, - is_expected=is_expected, - request_id=request_id, - thread_id=chat_id, - search_space_id=search_space_id, - user_id=user_id, - message=message, - extra=extra, - ) - return streaming_service.format_error(message, error_code=error_code) + _emit_stream_error = partial( + _emit_stream_terminal_error, + streaming_service=streaming_service, + flow="resume", + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + ) session = async_session_maker() try: diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 1f8168837..125177084 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -1,12 +1,13 @@ import inspect import json import logging -from pathlib import Path import re +from pathlib import Path import pytest import app.tasks.chat.stream_new_chat as stream_new_chat_module +from app.agents.new_chat.errors import BusyError from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, @@ -130,14 +131,14 @@ def test_stream_error_emission_keeps_machine_error_codes(): format_error_calls = re.findall(r"format_error\(", source) emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source)) - # Both new/resume stream paths now route through local emitters that always - # pass a machine-readable error_code. 
- assert len(format_error_calls) == 2 + # All stream paths should route through one shared terminal error emitter. + assert len(format_error_calls) == 1 assert { "PREMIUM_QUOTA_EXHAUSTED", "SERVER_ERROR", }.issubset(emitted_error_codes) assert 'flow: Literal["new", "regenerate"] = "new"' in source + assert "_emit_stream_terminal_error" in source assert "flow=flow" in source assert 'flow="resume"' in source @@ -156,6 +157,30 @@ def test_stream_exception_classifies_rate_limited(): assert "temporarily rate-limited" in user_message +def test_stream_exception_classifies_thread_busy(): + exc = BusyError(request_id="thread-123") + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "THREAD_BUSY" + assert severity == "warn" + assert is_expected is True + assert "still finishing for this thread" in user_message + + +def test_stream_exception_classifies_thread_busy_from_message(): + exc = Exception("Thread is busy with another request") + kind, code, severity, is_expected, user_message = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "THREAD_BUSY" + assert severity == "warn" + assert is_expected is True + assert "still finishing for this thread" in user_message + + def test_premium_classification_is_error_code_driven(): classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" source = classifier_path.read_text(encoding="utf-8") diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index b6afaf131..70e188612 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -67,6 +67,7 @@ import { type ContentPartsState, FrameBatchedUpdater, 
readSSEStream, + type SSEEvent, type ThinkingStepData, updateThinkingSteps, updateToolCall, @@ -136,6 +137,75 @@ function markInterruptsCompleted(contentParts: Array<{ type: string; result?: un } } +function toStreamTerminalError( + event: Extract<SSEEvent, { type: "error" }> +): Error & { errorCode?: string } { + return Object.assign(new Error(event.errorText || "Server error"), { + errorCode: event.errorCode, + }); +} + +async function toHttpResponseError(response: Response): Promise<Error & { errorCode?: string }> { + const statusDefaultCode = + response.status === 409 + ? "THREAD_BUSY" + : response.status === 429 + ? "RATE_LIMITED" + : response.status === 401 || response.status === 403 + ? "AUTH_EXPIRED" + : "SERVER_ERROR"; + + let rawBody = ""; + try { + rawBody = await response.text(); + } catch { + // noop + } + + let parsedBody: Record<string, unknown> | null = null; + if (rawBody) { + try { + const parsed = JSON.parse(rawBody); + if (typeof parsed === "object" && parsed !== null) { + parsedBody = parsed as Record<string, unknown>; + } + } catch { + // noop + } + } + + const detail = parsedBody?.detail; + const detailObject = + typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null; + const detailMessage = typeof detail === "string" ? detail : undefined; + const topLevelMessage = + typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; + const detailNestedMessage = + typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; + + const topLevelCode = + typeof parsedBody?.errorCode === "string" + ? parsedBody.errorCode + : typeof parsedBody?.error_code === "string" + ? parsedBody.error_code + : undefined; + const detailCode = + typeof detailObject?.errorCode === "string" + ? detailObject.errorCode + : typeof detailObject?.error_code === "string" + ? detailObject.error_code + : undefined; + + const errorCode = detailCode ?? topLevelCode ?? 
statusDefaultCode; + const message = + detailNestedMessage ?? + detailMessage ?? + topLevelMessage ?? + `Backend error: ${response.status}`; + + return Object.assign(new Error(message), { errorCode }); +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -532,6 +602,43 @@ export default function NewChatPage() { ] ); + const handleStreamTerminalError = useCallback( + async ({ + error, + flow, + threadId, + assistantMsgId, + accepted, + onAbort, + onAcceptedStreamError, + }: { + error: unknown; + flow: ChatFlow; + threadId: number | null; + assistantMsgId: string; + accepted: boolean; + onAbort?: () => Promise<void>; + onAcceptedStreamError?: () => Promise<void>; + }) => { + if (error instanceof Error && error.name === "AbortError") { + await onAbort?.(); + return; + } + + if (accepted) { + await onAcceptedStreamError?.(); + } + + await handleChatFailure({ + error, + flow, + threadId, + assistantMsgId: accepted ? assistantMsgId : "no-persist-assistant", + }); + }, + [handleChatFailure] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -880,6 +987,7 @@ export default function NewChatPage() { const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; let tokenUsageData: Record<string, unknown> | null = null; + let newAccepted = false; // Add placeholder assistant message setMessages((prev) => [ @@ -951,8 +1059,9 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } + newAccepted = true; const flushMessages = () => { setMessages((prev) => @@ -1106,9 +1215,7 @@ export default function NewChatPage() { break; case "error": - throw Object.assign(new Error(parsed.errorText || "Server error"), { - errorCode: parsed.errorCode, - }); + throw 
toStreamTerminalError(parsed); } } @@ -1137,29 +1244,29 @@ export default function NewChatPage() { } } catch (error) { batcher.dispose(); - if (error instanceof Error && error.name === "AbortError") { - // Request was cancelled by user - persist partial response if any content was received - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "tool-call" && toolsWithUI.has(part.toolName)) - ); - if (hasContent && currentThreadId) { - const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); - await persistAssistantTurn({ - threadId: currentThreadId, - assistantMsgId, - content: partialContent, - logContext: "partial new chat", - }); - } - return; - } - await handleChatFailure({ + await handleStreamTerminalError({ error, flow: "new", threadId: currentThreadId, assistantMsgId, + accepted: newAccepted, + onAbort: async () => { + // Request was cancelled by user - persist partial response if any content was received + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + ); + if (hasContent && currentThreadId) { + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId: currentThreadId, + assistantMsgId, + content: partialContent, + logContext: "partial new chat", + }); + } + }, }); } finally { setIsRunning(false); @@ -1183,7 +1290,7 @@ export default function NewChatPage() { pendingUserImageUrls, setPendingUserImageUrls, toolsWithUI, - handleChatFailure, + handleStreamTerminalError, persistAssistantTurn, ] ); @@ -1221,6 +1328,7 @@ export default function NewChatPage() { }; const { contentParts, toolCallIndices } = contentPartsState; let tokenUsageData: Record<string, unknown> | null = null; + let resumeAccepted = false; const existingMsg = messages.find((m) => m.id === assistantMsgId); if 
(existingMsg && Array.isArray(existingMsg.content)) { @@ -1302,8 +1410,9 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } + resumeAccepted = true; const flushMessages = () => { setMessages((prev) => @@ -1415,9 +1524,7 @@ export default function NewChatPage() { break; case "error": - throw Object.assign(new Error(parsed.errorText || "Server error"), { - errorCode: parsed.errorCode, - }); + throw toStreamTerminalError(parsed); } } @@ -1435,14 +1542,12 @@ export default function NewChatPage() { } } catch (error) { batcher.dispose(); - if (error instanceof Error && error.name === "AbortError") { - return; - } - await handleChatFailure({ + await handleStreamTerminalError({ error, flow: "resume", threadId: resumeThreadId, assistantMsgId, + accepted: resumeAccepted, }); } finally { setIsRunning(false); @@ -1455,7 +1560,7 @@ export default function NewChatPage() { searchSpaceId, tokenUsageStore, toolsWithUI, - handleChatFailure, + handleStreamTerminalError, persistAssistantTurn, ] ); @@ -1644,7 +1749,7 @@ export default function NewChatPage() { }); if (!response.ok) { - throw new Error(`Backend error: ${response.status}`); + throw await toHttpResponseError(response); } regenerateAccepted = true; @@ -1741,9 +1846,7 @@ export default function NewChatPage() { break; case "error": - throw Object.assign(new Error(parsed.errorText || "Server error"), { - errorCode: parsed.errorCode, - }); + throw toStreamTerminalError(parsed); } } @@ -1772,25 +1875,25 @@ export default function NewChatPage() { trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { - if (error instanceof Error && error.name === "AbortError") { - return; - } batcher.dispose(); - if (regenerateAccepted && !userPersisted) { - const persistedUserMsgId = await persistUserTurn({ - threadId, - userMsgId, - content: userContentToPersist, - mentionedDocs: sourceMentionedDocs, - 
logContext: "regenerated (stream error)", - }); - userPersisted = Boolean(persistedUserMsgId); - } - await handleChatFailure({ + await handleStreamTerminalError({ error, flow: "regenerate", threadId, - assistantMsgId: regenerateAccepted ? assistantMsgId : "no-persist-assistant", + assistantMsgId, + accepted: regenerateAccepted, + onAcceptedStreamError: async () => { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + } + }, }); } finally { setIsRunning(false); @@ -1806,7 +1909,7 @@ export default function NewChatPage() { setMessageDocumentsMap, tokenUsageStore, toolsWithUI, - handleChatFailure, + handleStreamTerminalError, persistAssistantTurn, persistUserTurn, ] diff --git a/surfsense_web/components/free-chat/anonymous-chat.tsx b/surfsense_web/components/free-chat/anonymous-chat.tsx index b286c5316..3de2ca434 100644 --- a/surfsense_web/components/free-chat/anonymous-chat.tsx +++ b/surfsense_web/components/free-chat/anonymous-chat.tsx @@ -104,7 +104,13 @@ export function AnonymousChat({ model }: AnonymousChatProps) { setMessages((prev) => prev.filter((m) => m.id !== assistantId)); return; } - throw new Error(`Stream error: ${response.status}`); + const body = await response.text().catch(() => ""); + const errorCode = response.status === 409 ? "THREAD_BUSY" : "SERVER_ERROR"; + const message = + errorCode === "THREAD_BUSY" + ? "A previous response is still stopping. Please try again in a moment." + : `Stream error: ${response.status}`; + throw Object.assign(new Error(body || message), { errorCode }); } for await (const event of readSSEStream(response)) { @@ -115,10 +121,12 @@ export function AnonymousChat({ model }: AnonymousChatProps) { prev.map((m) => (m.id === assistantId ? 
{ ...m, content: m.content + event.delta } : m)) ); } else if (event.type === "error") { + const message = + event.errorCode === "THREAD_BUSY" + ? "A previous response is still stopping. Please try again in a moment." + : event.errorText; setMessages((prev) => - prev.map((m) => - m.id === assistantId ? { ...m, content: m.content || event.errorText } : m - ) + prev.map((m) => (m.id === assistantId ? { ...m, content: m.content || message } : m)) ); } else if ("type" in event && event.type === "data-token-usage") { // After streaming completes, refresh quota diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index deac1fd00..dd6693b35 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -48,6 +48,48 @@ function parseCaptchaError(status: number, body: string): string | null { return null; } +function normalizeFreeChatErrorMessage(error: unknown): string { + if (!(error instanceof Error)) return "An unexpected error occurred"; + const code = (error as Error & { errorCode?: string }).errorCode; + if (code === "THREAD_BUSY") { + return "A previous response is still stopping. Please try again in a moment."; + } + return error.message || "An unexpected error occurred"; +} + +function toFreeChatHttpError(status: number, body: string): Error & { errorCode?: string } { + let errorCode: string | undefined; + let message = body || `Server error: ${status}`; + try { + const parsed = JSON.parse(body) as Record<string, unknown>; + const detail = + typeof parsed.detail === "object" && parsed.detail !== null + ? (parsed.detail as Record<string, unknown>) + : null; + errorCode = + (typeof detail?.error_code === "string" ? detail.error_code : undefined) ?? + (typeof detail?.errorCode === "string" ? detail.errorCode : undefined) ?? + (typeof parsed.error_code === "string" ? parsed.error_code : undefined) ?? 
+ (typeof parsed.errorCode === "string" ? parsed.errorCode : undefined); + message = + (typeof detail?.message === "string" ? detail.message : undefined) ?? + (typeof parsed.message === "string" ? parsed.message : undefined) ?? + (typeof parsed.detail === "string" ? parsed.detail : undefined) ?? + message; + } catch { + // non-json response + } + + if (!errorCode) { + if (status === 409) errorCode = "THREAD_BUSY"; + else if (status === 429) errorCode = "RATE_LIMITED"; + else if (status === 401 || status === 403) errorCode = "AUTH_EXPIRED"; + else errorCode = "SERVER_ERROR"; + } + + return Object.assign(new Error(message), { errorCode }); +} + export function FreeChatPage() { const anonMode = useAnonymousMode(); const modelSlug = anonMode.isAnonymous ? anonMode.modelSlug : ""; @@ -117,7 +159,7 @@ export function FreeChatPage() { const body = await response.text().catch(() => ""); const captchaCode = parseCaptchaError(response.status, body); if (captchaCode) return "captcha"; - throw new Error(body || `Server error: ${response.status}`); + throw toFreeChatHttpError(response.status, body); } const currentThinkingSteps = new Map<string, ThinkingStepData>(); @@ -187,7 +229,9 @@ export function FreeChatPage() { break; case "error": - throw new Error(parsed.errorText || "Server error"); + throw Object.assign(new Error(parsed.errorText || "Server error"), { + errorCode: parsed.errorCode, + }); } } batcher.flush(); @@ -277,7 +321,7 @@ export function FreeChatPage() { } catch (error) { if (error instanceof Error && error.name === "AbortError") return; console.error("[FreeChatPage] Chat error:", error); - const errorText = error instanceof Error ? 
error.message : "An unexpected error occurred"; + const errorText = normalizeFreeChatErrorMessage(error); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -336,7 +380,7 @@ export function FreeChatPage() { } catch (error) { if (error instanceof Error && error.name === "AbortError") return; console.error("[FreeChatPage] Retry error:", error); - const errorText = error instanceof Error ? error.message : "An unexpected error occurred"; + const errorText = normalizeFreeChatErrorMessage(error); setMessages((prev) => prev.map((m) => m.id === assistantMsgId diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index dc9bb09df..4341f7dc5 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -2,6 +2,7 @@ export type ChatFlow = "new" | "resume" | "regenerate"; export type ChatErrorKind = | "premium_quota_exhausted" + | "thread_busy" | "auth_expired" | "rate_limited" | "network_offline" @@ -144,6 +145,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if ( + errorCode === "THREAD_BUSY" + ) { + return { + kind: "thread_busy", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "A previous response is still stopping. Please try again in a moment.", + rawMessage, + errorCode: errorCode ?? 
"THREAD_BUSY", + details: { flow: input.flow }, + }; + } + if ( errorCode === "AUTH_EXPIRED" || errorCode === "UNAUTHORIZED" From f60e742facdd2933d56958fe82de923c7aefab0f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:58:56 +0530 Subject: [PATCH 18/68] feat(chat): implement pre-accept failure handling and unified retry messaging for chat operations, enhancing user experience and error management --- .../unit/test_stream_new_chat_contract.py | 72 ++++++++ .../new-chat/[[...chat_id]]/page.tsx | 173 ++++++++++++++---- .../lib/chat/chat-error-classifier.ts | 24 ++- 3 files changed, 229 insertions(+), 40 deletions(-) diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 125177084..9f4280063 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -189,3 +189,75 @@ def test_premium_classification_is_error_code_driven(): assert "RATE_LIMIT_KEYWORDS" not in source assert "normalized.includes(" not in source assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source + + +def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + assert "onPreAcceptFailure?: () => Promise<void>;" in source + assert "if (!accepted) {" in source + assert "await onPreAcceptFailure?.();" in source + assert "await onAcceptedStreamError?.();" in source + assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source + assert "setMessageDocumentsMap((prev) => {" in source + + +def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): + user_message_path = ( + Path(__file__).resolve().parents[3] / 
"surfsense_web/components/assistant-ui/user-message.tsx" + ) + source = user_message_path.read_text(encoding="utf-8") + + assert "Not sent. Edit and retry." not in source + assert "failed_pre_accept" not in source + + +def test_network_send_failures_use_unified_retry_toast_message(): + classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + classifier_source = classifier_path.read_text(encoding="utf-8") + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + page_source = page_path.read_text(encoding="utf-8") + + assert '"send_failed_pre_accept"' in classifier_source + assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert "if (withCode.code) return withCode.code;" in classifier_source + assert 'userMessage: "Message not sent. Please retry."' in classifier_source + assert 'userMessage: "Connection issue. Please try again."' in classifier_source + assert "tagPreAcceptSendFailure(error)" in page_source + assert 'existingCode === "THREAD_BUSY"' in page_source + assert 'existingCode === "AUTH_EXPIRED"' in page_source + assert 'existingCode === "UNAUTHORIZED"' in page_source + assert 'existingCode === "RATE_LIMITED"' in page_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source + assert 'errorCode: "NETWORK_ERROR"' not in page_source + assert "Failed to start chat. Please try again." not in page_source + + +def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): + page_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + ) + source = page_path.read_text(encoding="utf-8") + + # Each flow tracks accepted boundary and passes it into shared terminal handling. 
+ assert "let newAccepted = false;" in source + assert "let resumeAccepted = false;" in source + assert "let regenerateAccepted = false;" in source + assert "accepted: newAccepted," in source + assert "accepted: resumeAccepted," in source + assert "accepted: regenerateAccepted," in source + + # Pre-accept abort in resume/regenerate exits without persistence. + assert "if (!resumeAccepted) return;" in source + assert "if (!regenerateAccepted) return;" in source + + # New flow persists only when accepted and not already persisted. + assert "if (newAccepted && !userPersisted) {" in source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 70e188612..80ee9e9cd 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -206,6 +206,26 @@ async function toHttpResponseError(response: Response): Promise<Error & { errorC return Object.assign(new Error(message), { errorCode }); } +function tagPreAcceptSendFailure(error: unknown): unknown { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string; code?: string }; + const existingCode = withCode.errorCode ?? 
withCode.code; + if ( + existingCode === "THREAD_BUSY" || + existingCode === "AUTH_EXPIRED" || + existingCode === "UNAUTHORIZED" || + existingCode === "RATE_LIMITED" + ) { + return Object.assign(error, { errorCode: existingCode }); + } + return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); + } + + return Object.assign(new Error("Failed to send message before stream acceptance"), { + errorCode: "SEND_FAILED_PRE_ACCEPT", + }); +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -610,6 +630,7 @@ export default function NewChatPage() { assistantMsgId, accepted, onAbort, + onPreAcceptFailure, onAcceptedStreamError, }: { error: unknown; @@ -618,6 +639,7 @@ export default function NewChatPage() { assistantMsgId: string; accepted: boolean; onAbort?: () => Promise<void>; + onPreAcceptFailure?: () => Promise<void>; onAcceptedStreamError?: () => Promise<void>; }) => { if (error instanceof Error && error.name === "AbortError") { @@ -625,12 +647,14 @@ export default function NewChatPage() { return; } - if (accepted) { + if (!accepted) { + await onPreAcceptFailure?.(); + } else { await onAcceptedStreamError?.(); } await handleChatFailure({ - error, + error: !accepted ? tagPreAcceptSendFailure(error) : error, flow, threadId, assistantMsgId: accepted ? assistantMsgId : "no-persist-assistant", @@ -863,7 +887,12 @@ export default function NewChatPage() { ); } catch (error) { console.error("[NewChatPage] Failed to create thread:", error); - toast.error("Failed to start chat. 
Please try again."); + await handleChatFailure({ + error: tagPreAcceptSendFailure(error), + flow: "new", + threadId: currentThreadId, + assistantMsgId: "no-persist-assistant", + }); return; } } @@ -948,27 +977,6 @@ export default function NewChatPage() { }); } - appendMessage(currentThreadId, { - role: "user", - content: persistContent, - }) - .then((savedMessage) => { - const newUserMsgId = `msg-${savedMessage.id}`; - setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) - ); - setMessageDocumentsMap((prev) => { - const docs = prev[userMsgId]; - if (!docs) return prev; - const { [userMsgId]: _, ...rest } = prev; - return { ...rest, [newUserMsgId]: docs }; - }); - if (isNewThread) { - queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); - } - }) - .catch((err) => console.error("Failed to persist user message:", err)); - // Start streaming response setIsRunning(true); const controller = new AbortController(); @@ -988,17 +996,7 @@ export default function NewChatPage() { let wasInterrupted = false; let tokenUsageData: Record<string, unknown> | null = null; let newAccepted = false; - - // Add placeholder assistant message - setMessages((prev) => [ - ...prev, - { - id: assistantMsgId, - role: "assistant", - content: [{ type: "text", text: "" }], - createdAt: new Date(), - }, - ]); + let userPersisted = false; try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1062,6 +1060,15 @@ export default function NewChatPage() { throw await toHttpResponseError(response); } newAccepted = true; + setMessages((prev) => [ + ...prev, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]); const flushMessages = () => { setMessages((prev) => @@ -1224,6 +1231,20 @@ export default function NewChatPage() { // Skip persistence for interrupted messages -- handleResume will persist the final version const 
finalContent = buildContentForPersistence(contentPartsState, toolsWithUI); if (contentParts.length > 0 && !wasInterrupted) { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + await persistAssistantTurn({ threadId: currentThreadId, assistantMsgId, @@ -1251,6 +1272,20 @@ export default function NewChatPage() { assistantMsgId, accepted: newAccepted, onAbort: async () => { + if (newAccepted && !userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat (aborted)", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + // Request was cancelled by user - persist partial response if any content was received const hasContent = contentParts.some( (part) => @@ -1267,6 +1302,29 @@ export default function NewChatPage() { }); } }, + onAcceptedStreamError: async () => { + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId: currentThreadId, + userMsgId, + content: persistContent, + mentionedDocs: allMentionedDocs, + logContext: "new chat (stream error)", + }); + userPersisted = Boolean(persistedUserMsgId); + if (userPersisted && isNewThread) { + queryClient.invalidateQueries({ queryKey: ["threads", String(searchSpaceId)] }); + } + } + }, + onPreAcceptFailure: async () => { + setMessages((prev) => prev.filter((m) => m.id !== userMsgId)); + setMessageDocumentsMap((prev) => { + if (!(userMsgId in prev)) return prev; + const { [userMsgId]: _removed, 
...rest } = prev; + return rest; + }); + }, }); } finally { setIsRunning(false); @@ -1291,7 +1349,9 @@ export default function NewChatPage() { setPendingUserImageUrls, toolsWithUI, handleStreamTerminalError, + handleChatFailure, persistAssistantTurn, + persistUserTurn, ] ); @@ -1548,6 +1608,22 @@ export default function NewChatPage() { threadId: resumeThreadId, assistantMsgId, accepted: resumeAccepted, + onAbort: async () => { + if (!resumeAccepted) return; + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + ); + if (!hasContent) return; + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId: resumeThreadId, + assistantMsgId, + content: partialContent, + logContext: "partial resumed chat", + }); + }, }); } finally { setIsRunning(false); @@ -1882,6 +1958,33 @@ export default function NewChatPage() { threadId, assistantMsgId, accepted: regenerateAccepted, + onAbort: async () => { + if (!regenerateAccepted) return; + if (!userPersisted) { + const persistedUserMsgId = await persistUserTurn({ + threadId, + userMsgId, + content: userContentToPersist, + mentionedDocs: sourceMentionedDocs, + logContext: "regenerated (aborted)", + }); + userPersisted = Boolean(persistedUserMsgId); + } + const hasContent = contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + ); + if (!hasContent) return; + const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); + await persistAssistantTurn({ + threadId, + assistantMsgId, + content: partialContent, + tokenUsage: tokenUsageData ?? 
undefined, + logContext: "partial regenerated chat", + }); + }, onAcceptedStreamError: async () => { if (!userPersisted) { const persistedUserMsgId = await persistUserTurn({ diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 4341f7dc5..57341a4c3 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -3,6 +3,7 @@ export type ChatFlow = "new" | "resume" | "regenerate"; export type ChatErrorKind = | "premium_quota_exhausted" | "thread_busy" + | "send_failed_pre_accept" | "auth_expired" | "rate_limited" | "network_offline" @@ -54,8 +55,9 @@ function getErrorMessage(error: unknown): string { function getErrorCode(error: unknown, parsedJson: Record<string, unknown> | null): string | undefined { if (error instanceof Error) { - const withCode = error as Error & { errorCode?: string }; + const withCode = error as Error & { errorCode?: string; code?: string }; if (withCode.errorCode) return withCode.errorCode; + if (withCode.code) return withCode.code; } if (typeof error === "object" && error !== null) { @@ -161,6 +163,20 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if (errorCode === "SEND_FAILED_PRE_ACCEPT") { + return { + kind: "send_failed_pre_accept", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "Message not sent. Please retry.", + rawMessage, + errorCode: errorCode ?? 
"SEND_FAILED_PRE_ACCEPT", + details: { flow: input.flow }, + }; + } + if ( errorCode === "AUTH_EXPIRED" || errorCode === "UNAUTHORIZED" @@ -196,16 +212,14 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "NETWORK_ERROR" - ) { + if (errorCode === "NETWORK_ERROR") { return { kind: "network_offline", channel: "toast", severity: "warn", telemetryEvent: "chat_error", isExpected: true, - userMessage: "Connection issue detected. Check your internet and try again.", + userMessage: "Connection issue. Please try again.", rawMessage, errorCode: errorCode ?? "NETWORK_ERROR", details: { flow: input.flow }, From e651c41372b1d0946cd63ff96c224ce6beeb7acc Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 03:13:58 -0700 Subject: [PATCH 19/68] feat: enhance tool input streaming and agent action handling for improved chat experience --- .../app/services/new_streaming_service.py | 13 +- .../app/tasks/chat/stream_new_chat.py | 256 +++++++-- .../tasks/chat/test_extract_chunk_parts.py | 43 ++ .../tasks/chat/test_tool_input_streaming.py | 527 ++++++++++++++++++ .../new-chat/[[...chat_id]]/page.tsx | 277 +++++---- .../atoms/chat/agent-actions.atom.ts | 194 ------- .../agent-action-log/action-log-sheet.tsx | 33 +- .../assistant-ui/revert-turn-button.tsx | 60 +- .../components/assistant-ui/tool-fallback.tsx | 475 ++++++++++++---- .../components/free-chat/free-chat-page.tsx | 24 +- .../contracts/types/chat-messages.types.ts | 9 +- .../hooks/use-agent-actions-query.ts | 416 ++++++++++++++ surfsense_web/hooks/use-messages-sync.ts | 8 + surfsense_web/lib/chat/streaming-state.ts | 60 +- surfsense_web/zero/schema/chat.ts | 7 + 15 files changed, 1857 insertions(+), 545 deletions(-) create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py delete mode 100644 surfsense_web/atoms/chat/agent-actions.atom.ts create mode 100644 
surfsense_web/hooks/use-agent-actions-query.ts diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 5dbae91c5..3531d37af 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -595,8 +595,17 @@ class VercelStreamingService: Format the start of tool input streaming. Args: - tool_call_id: The unique tool call identifier (synthetic, derived - from LangGraph ``run_id`` so the frontend has a stable card id). + tool_call_id: The unique tool call identifier. May be EITHER the + synthetic ``call_<run_id>`` id derived from LangGraph + ``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2`` + OFF, or the unmatched-fallback path under parity_v2) OR + the authoritative LangChain ``tool_call.id`` (parity_v2 + path: when the provider streams ``tool_call_chunks`` we + register the ``index`` and reuse the lc-id as the card + id so live ``tool-input-delta`` events can be routed + without a downstream join). Either way, the same id is + preserved across ``tool-input-start`` / ``-delta`` / + ``-available`` / ``tool-output-available`` for one call. tool_name: The name of the tool being called. langchain_tool_call_id: Optional authoritative LangChain ``tool_call.id``. When set, surfaces as diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 1493c4326..c94945bb1 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -338,6 +338,42 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: ) +def _legacy_match_lc_id( + pending_tool_call_chunks: list[dict[str, Any]], + tool_name: str, + run_id: str, + lc_tool_call_id_by_run: dict[str, str], +) -> str | None: + """Best-effort match a buffered ``tool_call_chunk`` to a tool name. 
+ + Pure extract of the legacy in-line match used at ``on_tool_start`` for + parity_v2-OFF and unmatched (chunk path didn't register an index for + this call) tools. Pops the next id-bearing chunk whose ``name`` + matches ``tool_name`` (or any id-bearing chunk as a fallback) and + returns its id. Mutates ``pending_tool_call_chunks`` and + ``lc_tool_call_id_by_run`` in place. + """ + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + return None + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + return candidate + return None + + async def _stream_agent_events( agent: Any, config: dict[str, Any], @@ -403,10 +439,28 @@ async def _stream_agent_events( # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by # name, and pop the next unconsumed entry at ``on_tool_start``. The # authoritative id is later filled in at ``on_tool_end`` from - # ``ToolMessage.tool_call_id``. + # ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit + # this list for chunks that already registered into ``index_to_meta`` + # below — so this list is reserved for the parity_v2-OFF / unmatched + # fallback path only and never re-pops a chunk we already streamed. pending_tool_call_chunks: list[dict[str, Any]] = [] lc_tool_call_id_by_run: dict[str, str] = {} + # parity_v2 only: live tool-call argument streaming. ``index_to_meta`` + # is keyed by the chunk's ``index`` field — LangChain + # ``ToolCallChunk``s for the same call share an index but only the + # first chunk carries id+name (subsequent ones are id=None, + # name=None, args="<delta>"). 
We register an index when both id and + # name are observed on a chunk (per ToolCallChunk semantics they + # arrive together on the first chunk), then route every later chunk + # at that index to the same ``ui_id`` as a ``tool-input-delta``. + # ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the + # ``ui_id`` used for that call's ``tool-input-start`` so the matching + # ``tool-output-available`` (emitted from ``on_tool_end``) lands on + # the same card. + index_to_meta: dict[int, dict[str, str]] = {} + ui_tool_call_id_by_run: dict[str, str] = {} + # Per-tool-end mutable cache for the LangChain tool_call_id resolved # at ``on_tool_end``. ``_emit_tool_output`` reads this so every # ``format_tool_output_available`` call automatically carries the @@ -452,13 +506,6 @@ async def _stream_agent_events( continue parts = _extract_chunk_parts(chunk) - # Accumulate any tool_call_chunks for best-effort - # correlation with ``on_tool_start`` below. We don't emit - # anything here; the matching is done at tool-start time. - if parity_v2 and parts["tool_call_chunks"]: - for tcc in parts["tool_call_chunks"]: - pending_tool_call_chunks.append(tcc) - reasoning_delta = parts["reasoning"] text_delta = parts["text"] @@ -504,6 +551,71 @@ async def _stream_agent_events( yield streaming_service.format_text_delta(current_text_id, text_delta) accumulated_text += text_delta + # Live tool-call argument streaming. Runs AFTER text/reasoning + # processing so chunks containing both stay in their natural + # wire order (text → text-end → tool-input-start). Active + # text/reasoning are closed inside the registration branch + # before ``tool-input-start`` so the frontend sees a clean + # part boundary even when providers interleave. + if parity_v2 and parts["tool_call_chunks"]: + for tcc in parts["tool_call_chunks"]: + idx = tcc.get("index") + + # Register this index when we first see id+name + # TOGETHER. 
Per LangChain ToolCallChunk semantics the + # first chunk for a tool call carries both fields + # together; later chunks have id=None, name=None and + # only ``args``. Requiring BOTH keeps wire + # ``tool-input-start`` always carrying a real + # toolName (assistant-ui's typed tool-part dispatch + # keys off it). + if idx is not None and idx not in index_to_meta: + lc_id = tcc.get("id") + name = tcc.get("name") + if lc_id and name: + ui_id = lc_id + + # Close active text/reasoning so wire + # ordering stays clean even on providers + # that interleave text and tool-call chunks + # within the same stream window. + if current_text_id is not None: + yield streaming_service.format_text_end(current_text_id) + current_text_id = None + if current_reasoning_id is not None: + yield streaming_service.format_reasoning_end( + current_reasoning_id + ) + current_reasoning_id = None + + index_to_meta[idx] = { + "ui_id": ui_id, + "lc_id": lc_id, + "name": name, + } + yield streaming_service.format_tool_input_start( + ui_id, + name, + langchain_tool_call_id=lc_id, + ) + + # Emit args delta for any chunk at a registered + # index (including idless continuations). Once an + # index is owned by ``index_to_meta`` we DO NOT + # append to ``pending_tool_call_chunks`` — that list + # is reserved for the parity_v2-OFF / unmatched + # fallback path so it never re-pops chunks already + # consumed here (skip-append). 
+ meta = index_to_meta.get(idx) if idx is not None else None + if meta: + args_chunk = tcc.get("args") or "" + if args_chunk: + yield streaming_service.format_tool_input_delta( + meta["ui_id"], args_chunk + ) + else: + pending_tool_call_chunks.append(tcc) + elif event_type == "on_tool_start": active_tool_depth += 1 tool_name = event.get("name", "unknown_tool") @@ -834,44 +946,65 @@ async def _stream_agent_events( status="in_progress", ) - tool_call_id = ( - f"call_{run_id[:32]}" - if run_id - else streaming_service.generate_tool_call_id() - ) - - # Best-effort attach the LangChain ``tool_call_id``. We - # pop the first chunk in ``pending_tool_call_chunks`` whose - # name matches; if none match (the chunked args may not yet - # carry a ``name`` field, or the model skipped the chunked - # form) we leave ``langchainToolCallId`` unset for now and - # fill it in authoritatively at ``on_tool_end`` from - # ``ToolMessage.tool_call_id``. - langchain_tool_call_id: str | None = None - if parity_v2 and pending_tool_call_chunks: - matched_idx: int | None = None - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("name") == tool_name and tcc.get("id"): - matched_idx = idx + # Resolve the card identity. If the chunk-emission loop + # already registered an ``index`` for this tool call (parity_v2 + # path), reuse the same ui_id so the card sees: + # tool-input-start → deltas… → tool-input-available → + # tool-output-available all keyed by lc_id. Otherwise fall + # back to the synthetic ``call_<run_id>`` id and the legacy + # best-effort match against ``pending_tool_call_chunks``. + matched_meta: dict[str, str] | None = None + if parity_v2: + # FIFO over indices 0,1,2…; first unassigned same-name + # match wins. Handles parallel same-name calls (e.g. two + # write_file calls) deterministically as long as the + # model interleaves on_tool_start in the same order it + # streamed the args. 
+ taken_ui_ids = set(ui_tool_call_id_by_run.values()) + for meta in index_to_meta.values(): + if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids: + matched_meta = meta break - if matched_idx is None: - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("id"): - matched_idx = idx - break - if matched_idx is not None: - matched = pending_tool_call_chunks.pop(matched_idx) - candidate = matched.get("id") - if isinstance(candidate, str) and candidate: - langchain_tool_call_id = candidate - if run_id: - lc_tool_call_id_by_run[run_id] = candidate - yield streaming_service.format_tool_input_start( - tool_call_id, - tool_name, - langchain_tool_call_id=langchain_tool_call_id, - ) + tool_call_id: str + langchain_tool_call_id: str | None = None + if matched_meta is not None: + tool_call_id = matched_meta["ui_id"] + langchain_tool_call_id = matched_meta["lc_id"] + # ``tool-input-start`` already fired during chunk + # emission — skip the duplicate. No pruning is needed + # because the chunk-emission loop intentionally never + # appends registered-index chunks to + # ``pending_tool_call_chunks`` (skip-append). + if run_id: + lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"] + else: + tool_call_id = ( + f"call_{run_id[:32]}" + if run_id + else streaming_service.generate_tool_call_id() + ) + # Legacy fallback: parity_v2 OFF, or parity_v2 ON but the + # provider didn't stream tool_call_chunks for this call + # (no index registered). Run the existing best-effort + # match BEFORE emitting start so we still attach an + # authoritative ``langchainToolCallId`` when possible. 
+ if parity_v2: + langchain_tool_call_id = _legacy_match_lc_id( + pending_tool_call_chunks, + tool_name, + run_id, + lc_tool_call_id_by_run, + ) + yield streaming_service.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + ) + + if run_id: + ui_tool_call_id_by_run[run_id] = tool_call_id + # Sanitize tool_input: strip runtime-injected non-serializable # values (e.g. LangChain ToolRuntime) before sending over SSE. if isinstance(tool_input, dict): @@ -924,7 +1057,15 @@ async def _stream_agent_events( result.write_succeeded = True result.verification_succeeded = True - tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown" + # Look up the SAME card id used at on_tool_start (either the + # parity_v2 lc-id-derived ui_id or the legacy synthetic + # ``call_<run_id>``) so the output event always lands on the + # same card as start/delta/available. Fallback preserves the + # legacy synthetic shape for parity_v2-OFF / unknown-run paths. + tool_call_id = ui_tool_call_id_by_run.get( + run_id, + f"call_{run_id[:32]}" if run_id else "call_unknown", + ) original_step_id = tool_step_ids.get( run_id, f"{step_prefix}-unknown-{run_id[:8]}" ) @@ -935,17 +1076,22 @@ async def _stream_agent_events( # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) # if the output isn't a ToolMessage. The value is stored in # ``current_lc_tool_call_id`` so ``_emit_tool_output`` - # picks it up for every output emit below. Stays None when - # parity_v2 is off so legacy emit paths are untouched. + # picks it up for every output emit below. + # + # Emitted in BOTH parity_v2 and legacy modes: the chat tool + # card needs the LangChain id to match against the + # ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``) + # so the inline Revert button can light up. Reading + # ``raw_output.tool_call_id`` is a cheap, non-mutating attribute + # access that is safe regardless of feature-flag state. 
current_lc_tool_call_id["value"] = None - if parity_v2: - authoritative = getattr(raw_output, "tool_call_id", None) - if isinstance(authoritative, str) and authoritative: - current_lc_tool_call_id["value"] = authoritative - if run_id: - lc_tool_call_id_by_run[run_id] = authoritative - elif run_id and run_id in lc_tool_call_id_by_run: - current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] + authoritative = getattr(raw_output, "tool_call_id", None) + if isinstance(authoritative, str) and authoritative: + current_lc_tool_call_id["value"] = authoritative + if run_id: + lc_tool_call_id_by_run[run_id] = authoritative + elif run_id and run_id in lc_tool_call_id_by_run: + current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] if tool_name == "read_file": yield streaming_service.format_thinking_step( diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py index 7f32bf456..1263a5fe1 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py @@ -183,3 +183,46 @@ class TestDefensive: assert out["text"] == "" assert out["reasoning"] == "" assert out["tool_call_chunks"] == [] + + +class TestIdlessContinuationChunks: + """Per LangChain ``ToolCallChunk`` semantics, the FIRST chunk for a + tool call carries id+name; later chunks for the same call have + ``id=None, name=None`` and only ``args`` + ``index``. Live tool-call + argument streaming relies on those idless continuation chunks + flowing through ``_extract_chunk_parts`` UNTOUCHED so the upstream + chunk-emission loop can still route them by ``index``. 
+ """ + + def test_idless_continuation_chunk_preserved_verbatim(self) -> None: + chunk = _FakeChunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ] + ) + out = _extract_chunk_parts(chunk) + assert len(out["tool_call_chunks"]) == 1 + tcc = out["tool_call_chunks"][0] + assert tcc.get("id") is None + assert tcc.get("name") is None + assert tcc.get("args") == '_path":"/x"}' + assert tcc.get("index") == 0 + + def test_first_then_idless_sequence_preserves_index(self) -> None: + """Both chunks for the same call share an ``index`` key — the + index-routing loop in ``stream_new_chat`` depends on it.""" + first = _FakeChunk( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0} + ] + ) + cont = _FakeChunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ] + ) + out_first = _extract_chunk_parts(first) + out_cont = _extract_chunk_parts(cont) + assert out_first["tool_call_chunks"][0]["index"] == 0 + assert out_cont["tool_call_chunks"][0]["index"] == 0 + assert out_cont["tool_call_chunks"][0].get("id") is None diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py new file mode 100644 index 000000000..9258d5cfe --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -0,0 +1,527 @@ +"""Unit tests for live tool-call argument streaming. + +Pins the wire format that ``_stream_agent_events`` emits when +``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` → +``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available`` +all keyed by the same LangChain ``tool_call.id``. + +Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and +``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to +``_stream_agent_events`` so we exercise them via the public wire output. 
+ +These tests also lock in the legacy / parity_v2-OFF behaviour so the +synthetic ``call_<run_id>`` shape stays stable for older clients. +""" + +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from typing import Any + +import pytest + +import app.tasks.chat.stream_new_chat as stream_module +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.stream_new_chat import ( + StreamResult, + _legacy_match_lc_id, + _stream_agent_events, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _FakeChunk: + """Minimal stand-in for ``AIMessageChunk``.""" + + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +@dataclass +class _FakeToolMessage: + """Stand-in for ``ToolMessage`` returned by ``on_tool_end``.""" + + content: Any + tool_call_id: str | None = None + + +class _FakeAgentState: + """Stand-in for ``StateSnapshot`` returned by ``aget_state``.""" + + def __init__(self) -> None: + # Empty values keeps the cloud-fallback safety-net branch a no-op, + # and an empty ``tasks`` list keeps the post-stream interrupt + # check a no-op too. 
+ self.values: dict[str, Any] = {} + self.tasks: list[Any] = [] + + +class _FakeAgent: + """Replays a list of ``astream_events`` events.""" + + def __init__(self, events: list[dict[str, Any]]) -> None: + self._events = events + + async def astream_events( # type: ignore[no-untyped-def] + self, _input_data: Any, *, config: dict[str, Any], version: str + ) -> AsyncGenerator[dict[str, Any], None]: + del config, version # unused, contract-compatible + for ev in self._events: + yield ev + + async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState: + # Called once after astream_events drains so the cloud-fallback + # safety net can inspect staged filesystem work. The fake stays + # empty so the safety net is a no-op. + return _FakeAgentState() + + +def _model_stream( + *, + text: str = "", + reasoning: str = "", + tool_call_chunks: list[dict[str, Any]] | None = None, + tags: list[str] | None = None, +) -> dict[str, Any]: + return ( + { + "event": "on_chat_model_stream", + "tags": tags or [], + "data": { + "chunk": _FakeChunk( + content=text, + tool_call_chunks=list(tool_call_chunks or []), + ) + }, + # reasoning piggybacks via additional_kwargs path; if needed, + # override content to a typed-block list. Most tests just check + # tool_call_chunks routing so this is fine. 
+ } + if not reasoning + else { + "event": "on_chat_model_stream", + "tags": tags or [], + "data": { + "chunk": _FakeChunk( + content=text, + additional_kwargs={"reasoning_content": reasoning}, + tool_call_chunks=list(tool_call_chunks or []), + ) + }, + } + ) + + +def _tool_start( + *, + name: str, + run_id: str, + input_payload: dict[str, Any] | None = None, +) -> dict[str, Any]: + return { + "event": "on_tool_start", + "name": name, + "run_id": run_id, + "data": {"input": input_payload or {}}, + } + + +def _tool_end( + *, + name: str, + run_id: str, + tool_call_id: str | None = None, + output: Any = "ok", +) -> dict[str, Any]: + return { + "event": "on_tool_end", + "name": name, + "run_id": run_id, + "data": { + "output": _FakeToolMessage( + content=json.dumps(output) if not isinstance(output, str) else output, + tool_call_id=tool_call_id, + ) + }, + } + + +@pytest.fixture +def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + stream_module, + "get_flags", + lambda: AgentFeatureFlags(enable_stream_parity_v2=True), + ) + + +@pytest.fixture +def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + stream_module, + "get_flags", + lambda: AgentFeatureFlags(enable_stream_parity_v2=False), + ) + + +async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Run ``_stream_agent_events`` against a fake agent and return the + SSE payloads (parsed JSON) it yielded. 
+ """ + agent = _FakeAgent(events) + service = VercelStreamingService() + result = StreamResult() + config = {"configurable": {"thread_id": "test-thread"}} + sse_lines: list[str] = [] + async for sse in _stream_agent_events( + agent, config, {}, service, result, step_prefix="thinking" + ): + sse_lines.append(sse) + + parsed: list[dict[str, Any]] = [] + for line in sse_lines: + if not line.startswith("data: "): + continue + body = line[len("data: ") :].rstrip("\n") + if not body or body == "[DONE]": + continue + try: + parsed.append(json.loads(body)) + except json.JSONDecodeError: + continue + return parsed + + +def _types(payloads: list[dict[str, Any]]) -> list[str]: + return [p.get("type", "?") for p in payloads] + + +def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]: + return [p for p in payloads if p.get("type") == type_name] + + +# --------------------------------------------------------------------------- +# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour. 
+# --------------------------------------------------------------------------- + + +class TestLegacyMatch: + def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None: + chunks: list[dict[str, Any]] = [ + {"id": "x1", "name": "ls"}, + {"id": "y1", "name": "write_file"}, + ] + runs: dict[str, str] = {} + result = _legacy_match_lc_id(chunks, "write_file", "run-1", runs) + assert result == "y1" + assert chunks == [{"id": "x1", "name": "ls"}] + assert runs == {"run-1": "y1"} + + def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None: + chunks: list[dict[str, Any]] = [{"id": "anon", "name": None}] + runs: dict[str, str] = {} + out = _legacy_match_lc_id(chunks, "ls", "run-2", runs) + assert out == "anon" + assert chunks == [] + + def test_returns_none_when_no_id_bearing_chunk(self) -> None: + chunks: list[dict[str, Any]] = [{"id": None, "name": None}] + runs: dict[str, str] = {} + assert _legacy_match_lc_id(chunks, "ls", "run-3", runs) is None + assert chunks == [{"id": None, "name": None}] + assert runs == {} + + +# --------------------------------------------------------------------------- +# parity_v2 wire format tests. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None: + """First chunk carries id+name; later idless chunks at the same + ``index`` merge into the SAME ``tool-input-start`` ui id and emit + one ``tool-input-delta`` per chunk.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0} + ], + ), + _model_stream( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ], + ), + _tool_start( + name="write_file", run_id="run-A", input_payload={"file_path": "/x"} + ), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + deltas = _of_type(payloads, "tool-input-delta") + available = _of_type(payloads, "tool-input-available") + output = _of_type(payloads, "tool-output-available") + + assert len(starts) == 1 + assert starts[0]["toolCallId"] == "lc-1" + assert starts[0]["toolName"] == "write_file" + assert starts[0]["langchainToolCallId"] == "lc-1" + + assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}'] + assert all(d["toolCallId"] == "lc-1" for d in deltas) + + assert len(available) == 1 + assert available[0]["toolCallId"] == "lc-1" + + assert len(output) == 1 + assert output[0]["toolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_two_interleaved_tool_calls_route_by_index( + parity_v2_on: None, +) -> None: + """Two same-name calls with distinct indices keep their deltas + routed to the right card.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0}, + {"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1}, + ] + ), + _model_stream( + tool_call_chunks=[ + {"id": None, "name": None, "args": "}", "index": 0}, + {"id": None, "name": 
None, "args": "}", "index": 1}, + ] + ), + _tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"), + _tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}), + _tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"), + ] + + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + deltas = _of_type(payloads, "tool-input-delta") + output = _of_type(payloads, "tool-output-available") + + assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"} + + by_id: dict[str, list[str]] = {"lc-A": [], "lc-B": []} + for d in deltas: + by_id[d["toolCallId"]].append(d["inputTextDelta"]) + assert by_id["lc-A"] == ['{"a":1', "}"] + assert by_id["lc-B"] == ['{"b":2', "}"] + + assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"} + + +@pytest.mark.asyncio +async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None: + """Whatever id ``tool-input-start`` chose must be the SAME id used + on ``tool-input-available`` AND ``tool-output-available``.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0} + ] + ), + _tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"), + ] + payloads = await _drain(events) + relevant = [ + p + for p in payloads + if p.get("type") + in {"tool-input-start", "tool-input-available", "tool-output-available"} + ] + assert {p["toolCallId"] for p in relevant} == {"lc-9"} + + +@pytest.mark.asyncio +async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None: + """When the chunk-emission loop already fired ``tool-input-start`` + for this run, ``on_tool_start`` MUST NOT emit a second one.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0} + ] + ), + _tool_start(name="write_file", 
run_id="run-A", input_payload={}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + payloads = await _drain(events) + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_active_text_closes_before_early_tool_input_start( + parity_v2_on: None, +) -> None: + """Streaming a text-delta then a tool-call chunk in subsequent + chunks: the wire MUST contain ``text-end`` before the FIRST + ``tool-input-start`` (clean part boundary on the frontend).""" + events = [ + _model_stream(text="Working on it"), + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0} + ] + ), + _tool_start(name="write_file", run_id="run-A", input_payload={}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + types = _types(await _drain(events)) + text_end_idx = types.index("text-end") + start_idx = types.index("tool-input-start") + assert text_end_idx < start_idx + + +@pytest.mark.asyncio +async def test_mixed_text_and_tool_chunk_preserve_order( + parity_v2_on: None, +) -> None: + """One AIMessageChunk that carries BOTH ``text`` content AND + ``tool_call_chunks`` should emit the text delta FIRST, then close + text, then ``tool-input-start``+``tool-input-delta``.""" + events = [ + _model_stream( + text="I'll update it", + tool_call_chunks=[ + { + "id": "lc-1", + "name": "write_file", + "args": '{"file_path":"/x"}', + "index": 0, + } + ], + ), + _tool_start( + name="write_file", run_id="run-A", input_payload={"file_path": "/x"} + ), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + types = _types(await _drain(events)) + # text-start … text-delta … text-end … tool-input-start … tool-input-delta + assert types.index("text-start") < types.index("text-delta") + assert types.index("text-delta") < types.index("text-end") + assert types.index("text-end") < 
types.index("tool-input-start") + assert types.index("tool-input-start") < types.index("tool-input-delta") + + +@pytest.mark.asyncio +async def test_parity_v2_off_preserves_legacy_shape( + parity_v2_off: None, +) -> None: + """When the flag is OFF, no deltas are emitted and the ``toolCallId`` + is ``call_<run_id>`` (NOT the lc id).""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0} + ] + ), + _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"), + ] + payloads = await _drain(events) + + assert _of_type(payloads, "tool-input-delta") == [] + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"].startswith("call_run-A") + # No ``langchainToolCallId`` propagation on ``tool-input-start`` in + # legacy mode (the start event fires before the ToolMessage is + # available, so we can't extract the authoritative LangChain id yet). + assert "langchainToolCallId" not in starts[0] + output = _of_type(payloads, "tool-output-available") + assert output[0]["toolCallId"].startswith("call_run-A") + # ``tool-output-available`` MUST carry ``langchainToolCallId`` even + # in legacy mode: the chat tool card uses it to backfill the + # LangChain id and join against the ``data-action-log`` SSE event + # (keyed by ``lc_tool_call_id``) so the inline Revert button can + # light up. Sourced from the returned ``ToolMessage.tool_call_id``, + # which is populated regardless of feature-flag state. 
+ assert output[0]["langchainToolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_skip_append_prevents_stale_id_reuse( + parity_v2_on: None, +) -> None: + """Two same-name tools: the SECOND tool's ``langchainToolCallId`` + must NOT come from the first tool's chunk (``pending_tool_call_chunks`` + must stay empty for indexed-registered chunks).""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-A", "name": "write_file", "args": "{}", "index": 0}, + {"id": "lc-B", "name": "write_file", "args": "{}", "index": 1}, + ] + ), + _tool_start(name="write_file", run_id="run-1", input_payload={}), + _tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"), + _tool_start(name="write_file", run_id="run-2", input_payload={}), + _tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"), + ] + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + # Two distinct lc ids, each its own card. + assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"} + # Each tool-output-available landed on its respective card. 
+ output = _of_type(payloads, "tool-output-available") + assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"} + + +@pytest.mark.asyncio +async def test_registration_waits_for_both_id_and_name( + parity_v2_on: None, +) -> None: + """An id-only chunk (no name yet) must NOT emit ``tool-input-start``.""" + events = [ + _model_stream( + tool_call_chunks=[{"id": "lc-1", "name": None, "args": "", "index": 0}] + ), + ] + payloads = await _drain(events) + assert _of_type(payloads, "tool-input-start") == [] + + +@pytest.mark.asyncio +async def test_unmatched_fallback_still_attaches_lc_id( + parity_v2_on: None, +) -> None: + """parity_v2 ON, but the provider didn't include an ``index``: the + legacy fallback path must still emit ``tool-input-start`` with the + matching ``langchainToolCallId``.""" + events = [ + # No index on the chunk → not registered into index_to_meta; + # falls through to ``pending_tool_call_chunks`` so the legacy + # match path can pop it at on_tool_start. + _model_stream(tool_call_chunks=[{"id": "lc-orphan", "name": "ls", "args": ""}]), + _tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"), + ] + payloads = await _drain(events) + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"].startswith("call_run-1") + assert starts[0]["langchainToolCallId"] == "lc-orphan" diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index c2086e80a..e5ac61cd9 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -14,13 +14,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; import { disabledToolsAtom } from 
"@/atoms/agent-tools/agent-tools.atoms"; -import { - agentActionsByChatTurnIdAtom, - markAgentActionRevertedAtom, - resetAgentActionMapAtom, - updateAgentActionReversibleAtom, - upsertAgentActionAtom, -} from "@/atoms/chat/agent-actions.atom"; import { clearTargetCommentIdAtom, currentThreadAtom, @@ -55,6 +48,12 @@ import { type TokenUsageData, TokenUsageProvider, } from "@/components/assistant-ui/token-usage-context"; +import { + applyActionLogSse, + applyActionLogUpdatedSse, + markActionRevertedInCache, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; @@ -71,12 +70,12 @@ import { addToolCall, appendReasoning, appendText, + appendToolInputDelta, buildContentForPersistence, buildContentForUI, type ContentPartsState, endReasoning, FrameBatchedUpdater, - findToolCallIdByLcId, readSSEStream, type ThinkingStepData, type ToolUIGate, @@ -246,14 +245,6 @@ export default function NewChatPage() { const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); - // Agent action log SSE side-channel. - const upsertAgentAction = useSetAtom(upsertAgentActionAtom); - const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom); - const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom); - const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom); - // Chat-turn-keyed action map for the edit-from-position pre-flight - // that decides whether to show the confirmation dialog. - const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom); // Edit dialog state. 
Holds the message id being edited and // the (already extracted) regenerate args so we can resume the edit // after the user picks "revert all" / "continue" / "cancel". @@ -282,6 +273,11 @@ export default function NewChatPage() { content: unknown; author_id: string | null; created_at: string; + // Forwarded so ``convertToThreadMessage`` can rebuild the + // ``metadata.custom.chatTurnId`` on the + // ``ThreadMessageLike``. Required by the inline Revert + // button's per-turn fallback. + turn_id?: string | null; }[] ) => { if (isRunning) { @@ -314,6 +310,11 @@ export default function NewChatPage() { created_at: msg.created_at, author_display_name: member?.user_display_name ?? existingAuthor?.displayName ?? null, author_avatar_url: member?.user_avatar_url ?? existingAuthor?.avatarUrl ?? null, + // Forward the per-turn correlation id so the + // inline Revert button's ``(chat_turn_id, + // tool_name, position)`` fallback survives the + // post-stream Zero re-sync. + turn_id: msg.turn_id ?? null, }); }); }); @@ -330,6 +331,13 @@ export default function NewChatPage() { return Number.isNaN(parsed) ? 0 : parsed; }, [params.search_space_id]); + // Unified store for agent-action rows (the same react-query cache + // the agent-actions sheet, the inline Revert button, and the + // per-turn Revert button all read). Hydrates from + // ``GET /threads/{id}/actions`` and is updated incrementally by the + // SSE handlers + revert-batch results below — no atom side-channel. + const { items: agentActionItems } = useAgentActionsQuery(threadId); + // Extract chat_id from URL params const urlChatId = useMemo(() => { const id = params.chat_id; @@ -357,7 +365,8 @@ export default function NewChatPage() { clearPlanOwnerRegistry(); closeReportPanel(); closeEditorPanel(); - resetAgentActionMap(); + // Note: agent-action data is keyed by threadId in react-query so + // switching threads naturally swaps caches; no explicit reset. 
try { if (urlChatId > 0) { @@ -426,7 +435,6 @@ export default function NewChatPage() { removeChatTab, searchSpaceId, tokenUsageStore, - resetAgentActionMap, ]); // Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same) @@ -779,6 +787,15 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text + // streamed). ``scheduleFlush(); batcher.flush()`` sets + // the dirty bit FIRST so terminal events render + // promptly without the 50ms throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -815,13 +832,23 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); + break; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens + // of times per call, so use throttled + // scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); break; case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -834,8 +861,14 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + // addToolCall doesn't accept argsText today; + // backfill via updateToolCall so the new card + // renders pretty-printed JSON. 
+ updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; } @@ -854,7 +887,7 @@ export default function NewChatPage() { } } } - batcher.flush(); + forceFlush(); break; } @@ -950,34 +983,17 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId: currentThreadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + applyActionLogSse(queryClient, currentThreadId, searchSpaceId, parsed.data); break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + applyActionLogUpdatedSse( + queryClient, + currentThreadId, + parsed.data.id, + parsed.data.reversible + ); break; } @@ -1179,6 +1195,15 @@ export default function NewChatPage() { toolName: String(p.toolName), args: (p.args as Record<string, unknown>) ?? {}, result: p.result as unknown, + // Restore argsText so persisted pretty-printed + // JSON survives reloads (assistant-ui prefers + // supplied argsText over JSON.stringify(args)). + // langchainToolCallId restoration also fixes a + // pre-existing dropped-id bug on resume. + ...(typeof p.argsText === "string" ? { argsText: p.argsText } : {}), + ...(typeof p.langchainToolCallId === "string" + ? 
{ langchainToolCallId: p.langchainToolCallId } + : {}), }); contentPartsState.currentTextPartIndex = -1; } else if (p.type === "data-thinking-steps") { @@ -1200,7 +1225,12 @@ export default function NewChatPage() { const editedAction = decisions[0].edited_action; for (const part of contentParts) { if (part.type === "tool-call" && part.toolName === editedAction.name) { - part.args = { ...part.args, ...editedAction.args }; + const mergedArgs = { ...part.args, ...editedAction.args }; + part.args = mergedArgs; + // Sync argsText so the rendered card shows the + // edited inputs — assistant-ui prefers caller- + // supplied argsText over JSON.stringify(args). + part.argsText = JSON.stringify(mergedArgs, null, 2); break; } } @@ -1256,6 +1286,10 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -1292,13 +1326,20 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? 
{}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -1311,9 +1352,13 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { @@ -1321,7 +1366,7 @@ export default function NewChatPage() { langchainToolCallId: parsed.langchainToolCallId, }); markInterruptsCompleted(contentParts); - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { @@ -1381,34 +1426,17 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId: resumeThreadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + applyActionLogSse(queryClient, resumeThreadId, searchSpaceId, parsed.data); break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + applyActionLogUpdatedSse( + queryClient, + resumeThreadId, + parsed.data.id, + parsed.data.reversible + ); break; } @@ -1502,6 +1530,11 @@ export default function NewChatPage() { return { ...part, args: decision.edited_action.args, // Update displayed args + // Sync argsText so the rendered card shows + // the edited inputs — assistant-ui prefers + // 
caller-supplied argsText over + // JSON.stringify(args). + argsText: JSON.stringify(decision.edited_action.args, null, 2), result: { ...(part.result as Record<string, unknown>), __decided__: decisionType, @@ -1712,6 +1745,10 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -1748,13 +1785,20 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -1767,9 +1811,13 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { @@ -1786,7 +1834,7 @@ export default function NewChatPage() { } } } - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { @@ -1802,34 +1850,21 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? 
findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + if (threadId !== null) { + applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data); + } break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + if (threadId !== null) { + applyActionLogUpdatedSse( + queryClient, + threadId, + parsed.data.id, + parsed.data.reversible + ); + } break; } @@ -1866,12 +1901,16 @@ export default function NewChatPage() { : `Reverted ${summary.reverted} downstream actions before regenerating.` ); } - for (const r of summary.results) { - if (r.status === "reverted" || r.status === "already_reverted") { - markAgentActionReverted({ - id: r.action_id, - newActionId: r.new_action_id ?? null, - }); + if (threadId !== null) { + for (const r of summary.results) { + if (r.status === "reverted" || r.status === "already_reverted") { + markActionRevertedInCache( + queryClient, + threadId, + r.action_id, + r.new_action_id ?? null + ); + } } } break; @@ -2019,16 +2058,26 @@ export default function NewChatPage() { const downstream = messages.slice(editedIndex + 1); downstreamTotalCount = downstream.length; const seenTurns = new Set<string>(); + const downstreamTurnIds = new Set<string>(); for (const m of downstream) { const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } }; const tid = meta.custom?.chatTurnId; if (!tid || seenTurns.has(tid)) continue; seenTurns.add(tid); - const turnActions = agentActionsByChatTurnId.get(tid) ?? 
[]; - for (const a of turnActions) { - if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) { - downstreamReversibleCount += 1; - } + downstreamTurnIds.add(tid); + } + // Source of truth: the unified react-query cache. Every + // action whose ``chat_turn_id`` belongs to the slice we're + // about to drop counts toward the prompt. + for (const a of agentActionItems) { + if (!a.chat_turn_id || !downstreamTurnIds.has(a.chat_turn_id)) continue; + if ( + a.reversible && + (a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) && + !a.is_revert_action && + (a.error === null || a.error === undefined) + ) { + downstreamReversibleCount += 1; } } } @@ -2052,7 +2101,7 @@ export default function NewChatPage() { downstreamTotalCount, }); }, - [handleRegenerate, messages, agentActionsByChatTurnId] + [handleRegenerate, messages, agentActionItems] ); const handleEditDialogChoice = useCallback( diff --git a/surfsense_web/atoms/chat/agent-actions.atom.ts b/surfsense_web/atoms/chat/agent-actions.atom.ts deleted file mode 100644 index 7830c8751..000000000 --- a/surfsense_web/atoms/chat/agent-actions.atom.ts +++ /dev/null @@ -1,194 +0,0 @@ -"use client"; - -import { atom } from "jotai"; - -/** - * Minimal per-row projection of ``AgentActionLog`` that the tool card - * needs to decide whether to render a Revert button. - * - * Fields are deliberately a subset of the full ``AgentAction`` so the - * SSE side-channel (``data-action-log`` / ``data-action-log-updated``) - * can populate them without depending on the REST endpoint - * ``GET /threads/.../actions`` (which 503s when - * ``SURFSENSE_ENABLE_ACTION_LOG`` is off). 
- */ -export interface AgentActionLite { - id: number; - threadId: number | null; - lcToolCallId: string | null; - chatTurnId: string | null; - toolName: string; - reversible: boolean; - reverseDescriptorPresent: boolean; - error: boolean; - revertedByActionId: number | null; - isRevertAction: boolean; - createdAt: string | null; -} - -/** - * Map keyed off the LangChain ``tool_call.id`` (mirrors ``ContentPart - * tool-call.langchainToolCallId``). - */ -export const agentActionByLcIdAtom = atom<Map<string, AgentActionLite>>(new Map()); - -/** - * Parallel map keyed off the synthetic chat-card ``toolCallId`` - * (``call_<run-id>``) so ``ToolFallback`` (which only receives the - * synthetic id from assistant-ui) can join its card to the action log. - * - * Both maps are kept in sync by ``upsertAgentActionAtom``. - */ -export const agentActionByToolCallIdAtom = atom<Map<string, AgentActionLite>>(new Map()); - -/** - * Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer - * "how many reversible actions does this assistant turn contain?" in - * O(1). Each entry's array is ordered by insertion (which - * for a single turn matches ``created_at`` because action-log writes - * happen synchronously). - */ -export const agentActionsByChatTurnIdAtom = atom<Map<string, AgentActionLite[]>>(new Map()); - -/** - * Action to upsert one ``AgentActionLite`` row. - * - * ``toolCallId`` is the synthetic card id (``call_<run-id>`` from - * ``stream_new_chat.py``). When provided alongside ``lcToolCallId``, the - * action is indexed under BOTH ids so the tool card can perform the - * lookup without going via the streaming state. 
- */ -export const upsertAgentActionAtom = atom( - null, - (_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => { - const { action, toolCallId } = payload; - const upsertInto = ( - prev: Map<string, AgentActionLite>, - key: string - ): Map<string, AgentActionLite> => { - const next = new Map(prev); - const existing = next.get(key); - next.set(key, { - ...action, - // Preserve the local "reverted" bookkeeping if a reversibility - // flip arrives AFTER the user already reverted via the REST - // route. We never want a stale ``reversible=true`` event to - // resurrect a Reverted card. - revertedByActionId: existing?.revertedByActionId ?? action.revertedByActionId, - isRevertAction: existing?.isRevertAction ?? action.isRevertAction, - }); - return next; - }; - if (action.lcToolCallId) { - set(agentActionByLcIdAtom, (prev) => upsertInto(prev, action.lcToolCallId as string)); - } - if (toolCallId) { - set(agentActionByToolCallIdAtom, (prev) => upsertInto(prev, toolCallId)); - } - if (action.chatTurnId) { - set(agentActionsByChatTurnIdAtom, (prev) => { - const next = new Map(prev); - const turnId = action.chatTurnId as string; - const existing = next.get(turnId) ?? []; - const priorEntry = existing.find((row) => row.id === action.id); - const merged: AgentActionLite = { - ...action, - revertedByActionId: priorEntry?.revertedByActionId ?? action.revertedByActionId, - isRevertAction: priorEntry?.isRevertAction ?? action.isRevertAction, - }; - const others = existing.filter((row) => row.id !== action.id); - next.set(turnId, [...others, merged]); - return next; - }); - } - } -); - -function mutateById( - prev: Map<string, AgentActionLite>, - id: number, - mutator: (entry: AgentActionLite) => AgentActionLite -): Map<string, AgentActionLite> { - let mutated = false; - const next = new Map(prev); - for (const [key, value] of next) { - if (value.id === id) { - next.set(key, mutator(value)); - mutated = true; - } - } - return mutated ? 
next : prev; -} - -function mutateByIdInTurnIndex( - prev: Map<string, AgentActionLite[]>, - id: number, - mutator: (entry: AgentActionLite) => AgentActionLite -): Map<string, AgentActionLite[]> { - let mutated = false; - const next = new Map(prev); - for (const [key, list] of next) { - let listMutated = false; - const updated = list.map((row) => { - if (row.id === id) { - listMutated = true; - return mutator(row); - } - return row; - }); - if (listMutated) { - next.set(key, updated); - mutated = true; - } - } - return mutated ? next : prev; -} - -/** - * Action to flip an existing entry's ``reversible`` flag, keyed by the - * AgentActionLog row id (the SSE ``data-action-log-updated`` payload - * does NOT carry ``lcToolCallId``). - */ -export const updateAgentActionReversibleAtom = atom( - null, - (_get, set, payload: { id: number; reversible: boolean }) => { - const apply = (entry: AgentActionLite): AgentActionLite => ({ - ...entry, - reversible: payload.reversible, - }); - set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); - } -); - -/** Action to mark an existing entry as reverted (post-revert call). */ -export const markAgentActionRevertedAtom = atom( - null, - (_get, set, payload: { id: number; newActionId: number | null }) => { - const apply = (entry: AgentActionLite): AgentActionLite => ({ - ...entry, - revertedByActionId: payload.newActionId ?? -1, - }); - set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); - } -); - -/** Mark every action in a turn as reverted, given a list of (id, newActionId) pairs. 
*/ -export const markAgentActionsRevertedBatchAtom = atom( - null, - (_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => { - for (const entry of payload.entries) { - set(markAgentActionRevertedAtom, entry); - } - } -); - -/** Reset all maps (e.g. when the active thread changes). */ -export const resetAgentActionMapAtom = atom(null, (_get, set) => { - set(agentActionByLcIdAtom, new Map()); - set(agentActionByToolCallIdAtom, new Map()); - set(agentActionsByChatTurnIdAtom, new Map()); -}); diff --git a/surfsense_web/components/agent-action-log/action-log-sheet.tsx b/surfsense_web/components/agent-action-log/action-log-sheet.tsx index 68d2ffef3..32c25771a 100644 --- a/surfsense_web/components/agent-action-log/action-log-sheet.tsx +++ b/surfsense_web/components/agent-action-log/action-log-sheet.tsx @@ -1,9 +1,9 @@ "use client"; -import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useQueryClient } from "@tanstack/react-query"; import { useAtom, useAtomValue } from "jotai"; import { Activity, RefreshCcw } from "lucide-react"; -import { useCallback, useMemo } from "react"; +import { useCallback } from "react"; import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { Badge } from "@/components/ui/badge"; @@ -17,15 +17,12 @@ import { SheetTitle, } from "@/components/ui/sheet"; import { Skeleton } from "@/components/ui/skeleton"; -import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { + agentActionsQueryKey, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { ActionLogItem } from "./action-log-item"; -const ACTION_LOG_PAGE_SIZE = 50; - -function actionLogQueryKey(threadId: number) { - return ["agent-actions", threadId] as const; -} - function EmptyState() { return ( <div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center"> @@ -85,25 
+82,17 @@ export function ActionLogSheet() { const threadId = state.threadId; - const { data, isLoading, isFetching, isError, error, refetch } = useQuery({ - queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"], - queryFn: () => - agentActionsApiService.listForThread(threadId as number, { - page: 0, - pageSize: ACTION_LOG_PAGE_SIZE, - }), - enabled: state.open && threadId !== null && actionLogEnabled, - staleTime: 15 * 1000, - }); + const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery( + threadId, + { enabled: state.open && actionLogEnabled } + ); const handleRevertSuccess = useCallback(() => { if (threadId !== null) { - queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) }); + queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) }); } }, [queryClient, threadId]); - const items = useMemo(() => data?.items ?? [], [data]); - return ( <Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}> <SheetContent diff --git a/surfsense_web/components/assistant-ui/revert-turn-button.tsx b/surfsense_web/components/assistant-ui/revert-turn-button.tsx index af71299d0..733162c80 100644 --- a/surfsense_web/components/assistant-ui/revert-turn-button.tsx +++ b/surfsense_web/components/assistant-ui/revert-turn-button.tsx @@ -4,26 +4,22 @@ * "Revert turn" button rendered at the bottom of every completed * assistant turn that has at least one reversible action. * - * The button reads the action map keyed by ``chat_turn_id`` from the - * SSE side-channel (``data-action-log`` events). It shows a confirmation - * dialog summarising "N reversible / M total" and, on confirm, calls - * ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + * The button reads from the unified ``useAgentActionsQuery`` cache + * (the SAME react-query cache the agent-actions sheet and the inline + * Revert button consume) filtered by ``chat_turn_id``. 
It shows a + * confirmation dialog summarising "N reversible / M total" and, on + * confirm, calls ``POST /threads/{id}/revert-turn/{chat_turn_id}``. * * The route returns a per-action result list and never collapses the * batch into a 4xx — so we render any failed/not_reversible rows inline * with their messages. */ -import { useAtomValue, useSetAtom } from "jotai"; -import { selectAtom } from "jotai/utils"; +import { useQueryClient } from "@tanstack/react-query"; +import { useAtomValue } from "jotai"; import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { useMemo, useState } from "react"; import { toast } from "sonner"; -import { - type AgentActionLite, - agentActionsByChatTurnIdAtom, - markAgentActionsRevertedBatchAtom, -} from "@/atoms/chat/agent-actions.atom"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { AlertDialog, @@ -38,6 +34,10 @@ import { } from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; import { getToolDisplayName } from "@/contracts/enums/toolIcons"; +import { + applyRevertTurnResultsToCache, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { agentActionsApiService, type RevertTurnActionResult, @@ -49,49 +49,33 @@ interface RevertTurnButtonProps { chatTurnId: string | null | undefined; } -// Empty-array sentinel so the per-turn ``selectAtom`` slice returns a -// stable reference when the turn has no recorded actions yet. Without -// this every render allocates a fresh ``[]`` and Jotai's -// equality check would re-render the button on unrelated turn updates. -const EMPTY_ACTIONS: readonly AgentActionLite[] = Object.freeze([]); - export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { const session = useAtomValue(chatSessionStateAtom); - const markRevertedBatch = useSetAtom(markAgentActionsRevertedBatchAtom); + const threadId = session?.threadId ?? 
null; + const queryClient = useQueryClient(); + const { findByChatTurnId } = useAgentActionsQuery(threadId); const [isReverting, setIsReverting] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false); const [resultsOpen, setResultsOpen] = useState(false); const [results, setResults] = useState<RevertTurnActionResult[]>([]); - // Subscribe ONLY to the slice of the global action map that belongs - // to ``chatTurnId``. Previously the button read the whole - // ``agentActionsByChatTurnIdAtom``, which meant every action - // upsert (one per tool call) re-rendered every Revert button on - // the page. With ``selectAtom`` we re-render only when our turn's - // list reference changes — and the upsert/mark atoms produce a - // fresh list reference for the affected turn only. - const sliceAtom = useMemo( - () => - selectAtom( - agentActionsByChatTurnIdAtom, - (turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? EMPTY_ACTIONS - ), - [chatTurnId] - ); - const actions = useAtomValue(sliceAtom); + const actions = useMemo(() => findByChatTurnId(chatTurnId), [findByChatTurnId, chatTurnId]); const reversibleCount = useMemo( () => actions.filter( - (a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error + (a) => + a.reversible && + (a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) && + !a.is_revert_action && + (a.error === null || a.error === undefined) ).length, [actions] ); - const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]); + const totalCount = useMemo(() => actions.filter((a) => !a.is_revert_action).length, [actions]); if (!chatTurnId) return null; if (reversibleCount === 0) return null; - const threadId = session?.threadId; if (!threadId) return null; const handleRevertTurn = async () => { @@ -103,7 +87,7 @@ export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { .filter((r) => r.status === "reverted" || r.status === 
"already_reverted") .map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null })); if (revertedEntries.length > 0) { - markRevertedBatch({ entries: revertedEntries }); + applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries); } if (response.status === "ok") { toast.success( diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index cc7582695..66e2ebd4a 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,12 +1,12 @@ -import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; -import { useAtomValue, useSetAtom } from "jotai"; -import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react"; -import { useMemo, useState } from "react"; -import { toast } from "sonner"; import { - agentActionByToolCallIdAtom, - markAgentActionRevertedAtom, -} from "@/atoms/chat/agent-actions.atom"; + type ToolCallMessagePartComponent, + useAuiState, +} from "@assistant-ui/react"; +import { useQueryClient } from "@tanstack/react-query"; +import { useAtomValue } from "jotai"; +import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { DoomLoopApprovalToolUI, @@ -24,8 +24,17 @@ import { AlertDialogTitle, AlertDialogTrigger, } from "@/components/ui/alert-dialog"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; -import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; +import { Card } from "@/components/ui/card"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; +import { Separator } from "@/components/ui/separator"; +import { Spinner } from 
"@/components/ui/spinner"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; +import { + markActionRevertedInCache, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; @@ -34,31 +43,128 @@ import { cn } from "@/lib/utils"; /** * Inline Revert button rendered on a tool card when the matching * ``AgentActionLog`` row is reversible and hasn't been reverted yet. - * Reads from the SSE side-channel atom keyed by the synthetic - * ``toolCallId`` so it lights up even when ``GET /threads/.../actions`` - * is gated behind ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503). + * + * Reads from the unified ``useAgentActionsQuery`` cache — the SAME + * react-query cache the agent-actions sheet consumes. SSE events + * (``data-action-log`` / ``data-action-log-updated``) and + * ``POST /threads/{id}/revert/{id}`` responses both flow through the + * cache via ``setQueryData`` helpers, so the card and the sheet stay + * in lockstep on every code path: page reload, navigation, live + * stream, post-stream reversibility flip, and explicit revert clicks. + * + * Match key (in priority order): + * 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when + * the model streamed ``tool_call_chunks`` so the card's synthetic + * id IS the LangChain id. + * 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or + * parity_v2 with provider-side chunk emission) where the card's + * synthetic id is ``call_<run_id>`` and the LangChain id is + * backfilled onto the part by ``tool-output-available``. + * 3. 
``(chat_turn_id, tool_name, position-within-turn)`` — fallback + * for cards whose synthetic id is ``call_<run_id>`` AND whose + * ``langchainToolCallId`` never got backfilled (provider emitted + * the tool_call as a single payload with no chunks AND streaming + * pre-dated the ``tool-output-available langchainToolCallId`` + * backfill, e.g. older threads). Reads the parent message's + * ``chatTurnId`` and ``content`` via ``useAuiState`` so we can + * match position-by-tool-name within the turn against the + * action_log rows the server returned in ``created_at`` order. */ -function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { +function ToolCardRevertButton({ + toolCallId, + toolName, + langchainToolCallId, +}: { + toolCallId: string; + toolName: string; + langchainToolCallId?: string; +}) { const session = useAtomValue(chatSessionStateAtom); - const actionMap = useAtomValue(agentActionByToolCallIdAtom); - const markReverted = useSetAtom(markAgentActionRevertedAtom); - const action = actionMap.get(toolCallId); + const threadId = session?.threadId ?? null; + const queryClient = useQueryClient(); + const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId); + + // Parent message metadata, read via the narrowest possible + // selectors so this card doesn't re-render on every text-delta of + // every other part in the same message during streaming. + // + // IMPORTANT — ``useAuiState`` re-renders the component whenever the + // returned slice's identity changes. Returning ``message?.content`` + // (an array) would re-render on every token because the runtime + // rebuilds the parts array. Returning a PRIMITIVE (the position + // number) lets ``useAuiState``'s ``Object.is`` check short-circuit + // when the position hasn't actually moved — which is the common + // case during text streaming, when only ``text``/``reasoning`` + // parts are mutating and the same-toolName tool-call ordering is + // stable. 
(See Vercel React rule ``rerender-defer-reads``.) + const chatTurnId = useAuiState(({ message }) => { + const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined; + return meta?.custom?.chatTurnId ?? null; + }); + const positionInTurn = useAuiState(({ message }) => { + const content = message?.content; + if (!Array.isArray(content)) return -1; + let n = -1; + for (const part of content) { + if ( + part && + typeof part === "object" && + (part as { type?: string }).type === "tool-call" && + (part as { toolName?: string }).toolName === toolName + ) { + n += 1; + if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n; + } + } + return -1; + }); + + const action = useMemo(() => { + // Tier 1 + 2: O(1) Map-backed direct id match. Covers + // ~all parity_v2 streams and any legacy stream that backfilled + // ``langchainToolCallId`` via ``tool-output-available``. + const direct = + findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); + if (direct) return direct; + // Tier 3: position-within-turn fallback. Only kicks in when the + // card has a synthetic ``call_<run_id>`` id AND no + // ``langchainToolCallId`` was ever backfilled — i.e. the tool + // was emitted as a single non-chunked payload AND streaming + // pre-dated the on_tool_end backfill. + if (!chatTurnId || positionInTurn < 0) return null; + const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName); + return turnSameTool[positionInTurn] ?? 
null; + }, [ + findByToolCallId, + findByChatTurnAndTool, + toolCallId, + langchainToolCallId, + chatTurnId, + toolName, + positionInTurn, + ]); + const [isReverting, setIsReverting] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false); if (!action) return null; if (!action.reversible) return null; - if (action.revertedByActionId !== null) return null; - if (action.isRevertAction) return null; - if (action.error) return null; - const threadId = session?.threadId; + if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined) + return null; + if (action.is_revert_action) return null; + if (action.error !== null && action.error !== undefined) return null; if (!threadId) return null; const handleRevert = async () => { setIsReverting(true); try { const response = await agentActionsApiService.revert(threadId, action.id); - markReverted({ id: action.id, newActionId: response.new_action_id ?? null }); + markActionRevertedInCache( + queryClient, + threadId, + action.id, + response.new_action_id ?? null + ); toast.success(response.message || "Action reverted."); } catch (err) { // 503 means revert is gated off on this deployment — hide the @@ -91,8 +197,17 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { e.stopPropagation(); setConfirmOpen(true); }} + disabled={isReverting} > - <RotateCcw className="size-3.5" /> + {isReverting ? ( + // Spinner's typed props don't accept ``data-icon`` and + // it renders an <output>, not an <svg>, so Button's + // auto-sizing rule doesn't apply. Bare spinner + + // Button's gap handle layout. 
+ <Spinner size="xs" /> + ) : ( + <RotateCcw data-icon="inline-start" /> + )} Revert </Button> </AlertDialogTrigger> @@ -101,7 +216,7 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { <AlertDialogTitle>Revert this action?</AlertDialogTitle> <AlertDialogDescription> This will undo{" "} - <span className="font-medium">{getToolDisplayName(action.toolName)}</span> and add a + <span className="font-medium">{getToolDisplayName(action.tool_name)}</span> and add a new entry to the history. Your chat is preserved — only the changes the agent made to your knowledge base or connected apps will be rolled back where possible. </AlertDialogDescription> @@ -114,8 +229,10 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { handleRevert(); }} disabled={isReverting} + className="gap-1.5" > - {isReverting ? "Reverting…" : "Revert"} + {isReverting && <Spinner size="xs" />} + Revert </AlertDialogAction> </AlertDialogFooter> </AlertDialogContent> @@ -123,18 +240,49 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { ); } -const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ - toolCallId, - toolName, - argsText, - result, - status, -}) => { - const [isExpanded, setIsExpanded] = useState(false); +/** + * Compact tool-call card. + * + * shadcn composition note: we intentionally use ``Card`` as a visual + * frame WITHOUT ``CardHeader / CardContent``. The full composition's + * ``p-6`` padding doesn't fit a compact collapsible header that IS the + * trigger; using ``Card`` alone preserves the rounded border, shadow, + * and ``bg-card`` token (semantic colors) without forcing a layout + * that doesn't fit. All status colors use semantic tokens — no manual + * dark-mode overrides, no raw hex. 
+ */ +const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { + const { toolCallId, toolName, argsText, result, status } = props; + // ``langchainToolCallId`` is a SurfSense-specific extension the + // streaming pipeline attaches to the tool-call content part so + // the Revert button can resolve its ``AgentActionLog`` row even + // when only the LC id is known. assistant-ui's + // ``ToolCallMessagePartProps`` doesn't list it, but the runtime + // spreads ``{...part}`` so the prop reaches us at runtime. + const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId; const isCancelled = status?.type === "incomplete" && status.reason === "cancelled"; const isError = status?.type === "incomplete" && status.reason === "error"; const isRunning = status?.type === "running" || status?.type === "requires-action"; + + /* + Per-card expansion state. Initial value is ``isRunning`` so a + card streaming in mounts already-expanded (no flash of + collapsed → expanded on first paint), while a card loaded from + history (status="complete") mounts collapsed. The useEffect + below keeps this in lockstep with this card's own ``isRunning`` + when it transitions: false → true auto-expands (e.g. a tool + that re-runs after edit), true → false auto-collapses once the + tool finishes. Because the dep is per-card ``isRunning`` and + not the chat-level streaming flag, sibling cards on the same + assistant turn each manage their own expansion independently. + Once ``isRunning`` is false the user controls expansion via + ``onOpenChange``. + */ + const [isExpanded, setIsExpanded] = useState(isRunning); + useEffect(() => { + setIsExpanded(isRunning); + }, [isRunning]); const errorData = status?.type === "incomplete" ? status.error : undefined; const serializedError = useMemo( () => (errorData && typeof errorData !== "string" ? 
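JSON.stringify(errorData) : null), @@ -160,108 +308,207 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ : serializedError : null; - const Icon = getToolIcon(toolName); const displayName = getToolDisplayName(toolName); + const subtitle = errorReason ?? cancelledReason; return ( - <div + <Card className={cn( - "my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none", + "my-4 max-w-lg overflow-hidden", isCancelled && "opacity-60", - isError && "border-destructive/20 bg-destructive/5" + isError && "border-destructive/30" )} > - <button - type="button" - onClick={() => setIsExpanded((prev) => !prev)} - className="flex w-full items-center gap-3 px-5 py-4 text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none" + {/* + ``group`` lets the chevron (rendered as a sibling of the + main trigger button) read the Collapsible Root's + ``data-[state=open]`` for rotation. The Collapsible is + fully controlled via ``isExpanded`` — the useEffect + above syncs it to ``isRunning`` so the card auto-opens + while a tool streams in and auto-collapses once it + finishes. We deliberately DON'T pass ``disabled`` so + both triggers stay clickable; ``onOpenChange`` is wired + to a setter that no-ops while ``isRunning`` (see + the inline handler below) which keeps the card pinned + open mid-stream without losing keyboard / pointer + affordance the moment streaming ends. + */} + <Collapsible + className="group" + open={isExpanded} + onOpenChange={(next) => { + // Block manual collapse while the tool is still + // streaming — otherwise a stray click on either + // trigger would close the card and hide the live + // ``argsText`` panel mid-run. After streaming the + // user has full control again. + if (isRunning) return; + setIsExpanded(next); + }} > - <div - className={cn( - "flex size-8 shrink-0 items-center justify-center rounded-lg", - isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10" - )} - > - {isError ? ( - <XCircleIcon className="size-4 text-destructive" /> - ) : isCancelled ? ( - <XCircleIcon className="size-4 text-muted-foreground" /> - ) : isRunning ? ( - <Icon className="size-4 text-primary animate-pulse" /> - ) : ( - <CheckIcon className="size-4 text-primary" /> - )} - </div> + {/* + Header row: main trigger on the left (icon + title + col), Revert + chevron-trigger on the right as + siblings of the main trigger. The chevron is wrapped + in its OWN ``CollapsibleTrigger`` (Radix supports + multiple triggers per Root) so clicking the chevron + toggles the same state as clicking the title row. + The Revert button stays a separate AlertDialog + trigger and stops propagation in its onClick so it + doesn't toggle the collapsible while opening the + confirm dialog. Keeping these as flat siblings — + rather than nesting Revert / chevron inside the + title trigger — avoids invalid HTML + (button-in-button) and lets the Revert button + render in BOTH the collapsed and expanded states. + */} + <div className="flex items-stretch transition-colors hover:bg-muted/50"> + <CollapsibleTrigger asChild> + <button + type="button" + className={cn( + "flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left", + // Inset ring — Card's ``overflow-hidden`` would + // clip an ``offset-2`` ring; ``ring-inset`` + // paints inside the button box. + "focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset", + "disabled:cursor-default" + )} + > + <div + className={cn( + "flex size-8 shrink-0 items-center justify-center rounded-lg", + isError ? "bg-destructive/10" : isCancelled ?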
"bg-muted" : "bg-primary/10" - )} - > - {isError ? ( - <XCircleIcon className="size-4 text-destructive" /> - ) : isCancelled ? ( - <XCircleIcon className="size-4 text-muted-foreground" /> - ) : isRunning ? ( - <Icon className="size-4 text-primary animate-pulse" /> - ) : ( - <CheckIcon className="size-4 text-primary" /> - )} - </div> + {/* + Header row: main trigger on the left (icon + title + col), Revert + chevron-trigger on the right as + siblings of the main trigger. The chevron is wrapped + in its OWN ``CollapsibleTrigger`` (Radix supports + multiple triggers per Root) so clicking the chevron + toggles the same state as clicking the title row. + The Revert button stays a separate AlertDialog + trigger and stops propagation in its onClick so it + doesn't toggle the collapsible while opening the + confirm dialog. Keeping these as flat siblings — + rather than nesting Revert / chevron inside the + title trigger — avoids invalid HTML + (button-in-button) and lets the Revert button + render in BOTH the collapsed and expanded states. + */} + <div className="flex items-stretch transition-colors hover:bg-muted/50"> + <CollapsibleTrigger asChild> + <button + type="button" + className={cn( + "flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left", + // Inset ring — Card's ``overflow-hidden`` would + // clip an ``offset-2`` ring; ``ring-inset`` + // paints inside the button box. + "focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset", + "disabled:cursor-default" + )} + > + <div + className={cn( + "flex size-8 shrink-0 items-center justify-center rounded-lg", + isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10" + )} + > + {isError ? ( + <XCircleIcon className="size-4 text-destructive" /> + ) : isCancelled ? ( + <XCircleIcon className="size-4 text-muted-foreground" /> + ) : isRunning ? 
( + <Spinner size="sm" className="text-primary" /> + ) : ( + <CheckIcon className="size-4 text-primary" /> + )} + </div> - <div className="flex-1 min-w-0"> - <p - className={cn( - "text-sm font-semibold", - isError - ? "text-destructive" - : isCancelled - ? "text-muted-foreground line-through" - : "text-foreground" - )} - > - {isRunning - ? displayName - : isCancelled - ? `Cancelled: ${displayName}` - : isError - ? `Failed: ${displayName}` - : displayName} - </p> - {isRunning && <p className="text-xs text-muted-foreground mt-0.5">Working…</p>} - {cancelledReason && ( - <p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p> - )} - {errorReason && ( - <p className="text-xs text-destructive/80 mt-0.5 truncate">{errorReason}</p> - )} - </div> + <div className="flex flex-1 min-w-0 flex-col gap-0.5"> + <div className="flex items-center gap-2"> + <p + className={cn( + "text-sm font-semibold truncate", + isCancelled && "text-muted-foreground line-through", + isError && "text-destructive" + )} + > + {displayName} + </p> + {isRunning && <Badge variant="secondary">Running</Badge>} + {isError && <Badge variant="destructive">Failed</Badge>} + {isCancelled && <Badge variant="outline">Cancelled</Badge>} + </div> + {subtitle && ( + <p + className={cn( + "text-xs truncate", + isError ? "text-destructive/80" : "text-muted-foreground" + )} + > + {subtitle} + </p> + )} + </div> + </button> + </CollapsibleTrigger> - {!isRunning && ( - <div className="shrink-0 text-muted-foreground"> - {isExpanded ? ( - <ChevronDownIcon className="size-4" /> - ) : ( - <ChevronUpIcon className="size-4" /> - )} + {/* + Right-side controls. The Revert button is + visible whenever the matching action is + reversible — including the collapsed state — + but ``ToolCardRevertButton`` itself returns + ``null`` while a tool is still running because + no action-log row exists yet, so it doesn't + need an explicit ``isRunning`` gate here. 
+ */} + <div className="flex shrink-0 items-center gap-2 pl-2 pr-5"> + <ToolCardRevertButton + toolCallId={toolCallId} + toolName={toolName} + langchainToolCallId={langchainToolCallId} + /> + <CollapsibleTrigger asChild> + <button + type="button" + aria-label={isExpanded ? "Collapse details" : "Expand details"} + className={cn( + "flex size-7 shrink-0 items-center justify-center rounded-md", + "text-muted-foreground hover:bg-muted hover:text-foreground", + "focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset", + "disabled:cursor-default" + )} + > + <ChevronDownIcon + className={cn( + "size-4 transition-transform duration-200", + "group-data-[state=open]:rotate-180" + )} + /> + </button> + </CollapsibleTrigger> </div> - )} - </button> + </div> - {isExpanded && !isRunning && ( - <> - <div className="mx-5 h-px bg-border/50" /> - <div className="px-5 py-3 space-y-3"> - {argsText && ( - <div> - <p className="text-xs font-medium text-muted-foreground mb-1">Inputs</p> - <pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> - {argsText} - </pre> + {/* + CollapsibleContent body — auto-open while streaming + (see ``open`` prop above) so the live ``argsText`` + streams into the Inputs panel directly, no need for + a separate "Live input" panel. Native + ``overflow-auto`` instead of ``ScrollArea`` because + Radix's Viewport can let content bleed past + ``max-h-*`` in dynamic flex layouts. ``min-w-0`` on + the column wrappers guarantees ``break-all`` wraps + correctly within the bounded ``max-w-lg`` Card. + */} + <CollapsibleContent> + <Separator /> + <div className="flex flex-col gap-3 px-5 py-3"> + {(argsText || isRunning) && ( + <div className="flex flex-col gap-1 min-w-0"> + <p className="text-xs font-medium text-muted-foreground">Inputs</p> + <div className="max-h-48 overflow-auto rounded-md bg-muted/40"> + {argsText ? 
( + <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono"> + {argsText} + </pre> + ) : ( + // Bridges the brief gap between + // ``tool-input-start`` (creates the + // card, ``argsText`` undefined) and + // the first ``tool-input-delta``. + <p className="px-3 py-2 text-xs italic text-muted-foreground"> + Waiting for input… + </p> + )} + </div> </div> )} {!isCancelled && result !== undefined && ( <> - <div className="h-px bg-border/30" /> - <div> - <p className="text-xs font-medium text-muted-foreground mb-1">Result</p> - <pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> - {typeof result === "string" ? result : serializedResult} - </pre> + <Separator /> + <div className="flex flex-col gap-1 min-w-0"> + <p className="text-xs font-medium text-muted-foreground">Result</p> + <div className="max-h-64 overflow-auto rounded-md bg-muted/40"> + <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono"> + {typeof result === "string" ? 
result : serializedResult} + </pre> + </div> </div> </> )} - <div className="flex justify-end"> - <ToolCardRevertButton toolCallId={toolCallId} /> - </div> </div> - </> - )} - </div> + </CollapsibleContent> + </Collapsible> + </Card> ); }; diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index bfdd613e2..05db99407 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -22,6 +22,7 @@ import { addToolCall, appendReasoning, appendText, + appendToolInputDelta, buildContentForUI, type ContentPartsState, endReasoning, @@ -146,6 +147,10 @@ export function FreeChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; try { for await (const parsed of readSSEStream(response)) { @@ -183,13 +188,20 @@ export function FreeChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? 
{}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -202,16 +214,20 @@ export function FreeChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output, langchainToolCallId: parsed.langchainToolCallId, }); - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { diff --git a/surfsense_web/contracts/types/chat-messages.types.ts b/surfsense_web/contracts/types/chat-messages.types.ts index 0859f9f3b..ef16bb366 100644 --- a/surfsense_web/contracts/types/chat-messages.types.ts +++ b/surfsense_web/contracts/types/chat-messages.types.ts @@ -1,7 +1,13 @@ import { z } from "zod"; /** - * Raw message from database (real-time sync) + * Raw message from database (real-time sync). + * + * ``turn_id`` is included so consumers (e.g. ``convertToThreadMessage``) + * can populate ``metadata.custom.chatTurnId`` on the + * ``ThreadMessageLike`` even after the live-collab Zero re-sync. The + * inline Revert button's ``(chat_turn_id, tool_name, position)`` + * fallback in tool-fallback.tsx depends on it. 
*/ export const rawMessage = z.object({ id: z.number(), @@ -10,6 +16,7 @@ export const rawMessage = z.object({ content: z.unknown(), author_id: z.string().nullable(), created_at: z.string(), + turn_id: z.string().nullable().optional(), }); export type RawMessage = z.infer<typeof rawMessage>; diff --git a/surfsense_web/hooks/use-agent-actions-query.ts b/surfsense_web/hooks/use-agent-actions-query.ts new file mode 100644 index 000000000..9a722fb2e --- /dev/null +++ b/surfsense_web/hooks/use-agent-actions-query.ts @@ -0,0 +1,416 @@ +"use client"; + +import { type QueryClient, useQuery } from "@tanstack/react-query"; +import { useCallback, useEffect, useMemo, useRef } from "react"; +import { + type AgentAction, + type AgentActionListResponse, + agentActionsApiService, +} from "@/lib/apis/agent-actions-api.service"; + +// ============================================================================= +// DIAGNOSTIC LOGGING — gated behind a single switch. Flip ``RevertDebug`` +// to ``true`` to trace the full SSE → cache → card → button pipeline in +// the browser console. Off by default so we don't spam production. The +// infrastructure stays in place because the underlying id-mismatch +// failure mode is rare-but-real and surfaces only at runtime. +// ============================================================================= +const RevertDebug = false; +const dbg = (...args: unknown[]) => { + if (RevertDebug && typeof window !== "undefined") { + // eslint-disable-next-line no-console + console.log("[RevertDebug]", ...args); + } +}; + +/** + * Unified store for ``AgentActionLog`` rows scoped to one thread. + * + * Replaces the previous SSE side-channel atom mess + * (``agentActionByLcIdAtom`` / ``agentActionByToolCallIdAtom`` / + * ``agentActionsByChatTurnIdAtom``) and the standalone hydration hook. 
+ * One react-query cache entry is now the single source of truth for: + * + * * the inline Revert button on every tool-call card + * * the per-turn "Revert turn" button under each assistant message + * * the edit-from-position pre-flight that decides whether to show + * the confirmation dialog + * * the agent-actions sheet + * + * The cache is hydrated by ``GET /threads/{id}/actions`` (sized to + * 200, the server max) and updated incrementally by helpers that turn + * SSE events / revert RPC responses into ``setQueryData`` mutations. + * That keeps the card and the sheet in lockstep on every code path — + * page reload, navigation, live stream, post-stream reversibility flip, + * and explicit revert clicks. + */ + +export const ACTION_LOG_PAGE_SIZE = 200; + +/** Stable react-query key for the per-thread action list. */ +export function agentActionsQueryKey(threadId: number | null) { + return threadId !== null + ? (["agent-actions", threadId] as const) + : (["agent-actions", "none"] as const); +} + +/** Subset of the SSE ``data-action-log`` payload we care about. */ +export interface ActionLogSseEvent { + id: number; + lc_tool_call_id: string | null; + chat_turn_id: string | null; + tool_name: string; + reversible: boolean; + reverse_descriptor_present: boolean; + error: boolean; + created_at: string | null; +} + +/** + * Append or upsert a freshly-emitted ``AgentActionLog`` row into the + * thread-scoped query cache. + * + * The SSE payload is a strict subset of ``AgentAction``; missing + * fields (``args``, ``reverse_descriptor``, ``user_id``) are filled + * with ``null`` placeholders. The next refetch (sheet open, user + * focus, route stale) backfills them — but the inline Revert button + * only reads the fields the SSE payload carries, so it lights up + * immediately. 
+ */ +export function applyActionLogSse( + queryClient: QueryClient, + threadId: number, + searchSpaceId: number, + event: ActionLogSseEvent +): void { + dbg("applyActionLogSse: incoming SSE event", { + threadId, + searchSpaceId, + event, + }); + queryClient.setQueryData<AgentActionListResponse>( + agentActionsQueryKey(threadId), + (prev) => { + const placeholder: AgentAction = { + id: event.id, + thread_id: threadId, + user_id: null, + search_space_id: searchSpaceId, + tool_name: event.tool_name, + args: null, + result_id: null, + reversible: event.reversible, + reverse_descriptor: event.reverse_descriptor_present ? {} : null, + error: event.error ? {} : null, + reverse_of: null, + reverted_by_action_id: null, + is_revert_action: false, + tool_call_id: event.lc_tool_call_id, + chat_turn_id: event.chat_turn_id, + created_at: event.created_at ?? new Date().toISOString(), + }; + if (!prev) { + return { + items: [placeholder], + total: 1, + page: 0, + page_size: ACTION_LOG_PAGE_SIZE, + has_more: false, + }; + } + const existingIdx = prev.items.findIndex((a) => a.id === event.id); + if (existingIdx >= 0) { + const merged = [...prev.items]; + const existing = merged[existingIdx]; + if (existing) { + merged[existingIdx] = { + ...existing, + reversible: event.reversible, + tool_call_id: event.lc_tool_call_id ?? existing.tool_call_id, + chat_turn_id: event.chat_turn_id ?? existing.chat_turn_id, + }; + } + dbg("applyActionLogSse: merged into existing entry", { + id: event.id, + tool_call_id: merged[existingIdx]?.tool_call_id, + reversible: merged[existingIdx]?.reversible, + }); + return { ...prev, items: merged }; + } + dbg("applyActionLogSse: appended new placeholder", { + id: event.id, + tool_call_id: placeholder.tool_call_id, + tool_name: placeholder.tool_name, + reversible: placeholder.reversible, + cacheSizeAfter: prev.items.length + 1, + }); + // REST returns newest-first — keep that ordering when + // the server eventually refetches by prepending. 
+ return { + ...prev, + items: [placeholder, ...prev.items], + total: prev.total + 1, + }; + } + ); +} + +/** + * Apply a post-SAVEPOINT reversibility flip + * (``data-action-log-updated`` SSE event) to the cache. + */ +export function applyActionLogUpdatedSse( + queryClient: QueryClient, + threadId: number, + id: number, + reversible: boolean +): void { + dbg("applyActionLogUpdatedSse: reversibility flip", { + threadId, + id, + reversible, + }); + queryClient.setQueryData<AgentActionListResponse>( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) { + dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", { + threadId, + id, + }); + return prev; + } + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + return { ...a, reversible }; + }); + if (!mutated) { + dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", { + threadId, + id, + cacheSize: prev.items.length, + cacheIds: prev.items.map((a) => a.id), + }); + } + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Optimistically mark ``id`` as reverted. + * + * Used by the inline / per-turn Revert button immediately after the + * server returns success so the UI flips to "Reverted" without + * waiting for a refetch. ``newActionId`` is the id of the new + * ``is_revert_action`` row the server inserted; pass ``null`` if the + * server didn't return it. + */ +export function markActionRevertedInCache( + queryClient: QueryClient, + threadId: number, + id: number, + newActionId: number | null +): void { + queryClient.setQueryData<AgentActionListResponse>( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) return prev; + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + // ``-1`` is a sentinel meaning "we know it was reverted + // but the server didn't tell us the new row's id". + return { + ...a, + reverted_by_action_id: newActionId ?? 
-1, + }; + }); + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Apply a batch of revert results (per-turn revert response) to the + * cache. Anything in the ``reverted`` / ``already_reverted`` buckets + * gets its ``reverted_by_action_id`` set; other rows are left alone. + */ +export function applyRevertTurnResultsToCache( + queryClient: QueryClient, + threadId: number, + entries: Array<{ id: number; newActionId: number | null }> +): void { + if (entries.length === 0) return; + queryClient.setQueryData<AgentActionListResponse>( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) return prev; + const lookup = new Map(entries.map((e) => [e.id, e.newActionId])); + let mutated = false; + const items = prev.items.map((a) => { + if (!lookup.has(a.id)) return a; + mutated = true; + const newActionId = lookup.get(a.id) ?? null; + return { ...a, reverted_by_action_id: newActionId ?? -1 }; + }); + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Read-side hook used by the card, the turn button, the sheet, and + * the edit-from-position pre-flight. + * + * Returns the raw query state plus convenience selectors so consumers + * don't reach into ``data.items`` directly. ``enabled`` is the only + * knob — pass ``false`` to keep the query dormant when the consumer + * doesn't yet have a thread id. + */ +export function useAgentActionsQuery( + threadId: number | null, + options: { enabled?: boolean } = {} +) { + const enabled = (options.enabled ?? 
true) && threadId !== null; + const query = useQuery({ + queryKey: agentActionsQueryKey(threadId), + queryFn: async () => { + dbg("useAgentActionsQuery: REST fetch START", { + threadId, + pageSize: ACTION_LOG_PAGE_SIZE, + }); + const res = await agentActionsApiService.listForThread(threadId as number, { + page: 0, + pageSize: ACTION_LOG_PAGE_SIZE, + }); + dbg("useAgentActionsQuery: REST fetch DONE", { + threadId, + total: res.total, + returned: res.items.length, + items: res.items.map((a) => ({ + id: a.id, + tool_name: a.tool_name, + tool_call_id: a.tool_call_id, + reversible: a.reversible, + reverted_by_action_id: a.reverted_by_action_id, + is_revert_action: a.is_revert_action, + })), + }); + return res; + }, + enabled, + staleTime: 15 * 1000, + }); + + const items = useMemo(() => query.data?.items ?? [], [query.data]); + + // Index ``items`` once per change so the lookups below are O(1) + // instead of O(N) per card per render. With the cache sized to 200 + // rows and many tool cards visible at once, the unindexed scan was + // the hottest path on every assistant text-delta. (Vercel React + // rule ``js-index-maps`` / ``js-set-map-lookups``.) + const byToolCallId = useMemo(() => { + const m = new Map<string, AgentAction>(); + for (const a of items) { + if (a.tool_call_id) m.set(a.tool_call_id, a); + } + return m; + }, [items]); + + // Pre-grouped + pre-sorted (oldest-first, the order the agent + // actually executed them in) so the (chat_turn_id, tool_name, + // position) fallback in ``tool-fallback.tsx`` is also O(1) per + // card. Excludes ``is_revert_action`` rows so the position index + // matches the agent's original execution order. 
+ const byTurnAndTool = useMemo(() => { + const m = new Map<string, AgentAction[]>(); + for (const a of items) { + if (!a.chat_turn_id || a.is_revert_action) continue; + const key = `${a.chat_turn_id}::${a.tool_name}`; + const bucket = m.get(key); + if (bucket) bucket.push(a); + else m.set(key, [a]); + } + for (const bucket of m.values()) { + bucket.sort( + (a, b) => + new Date(a.created_at).getTime() - new Date(b.created_at).getTime() + ); + } + return m; + }, [items]); + + // Snapshot the cache shape when its size changes — easiest way to + // spot when the cache is empty or stale at the moment a card + // mounts. Tracked on a ref so we don't re-run the diff on + // reference-equal cache reads. + const lastSnapshotRef = useRef<{ threadId: number | null; size: number } | null>(null); + useEffect(() => { + const last = lastSnapshotRef.current; + if (!last || last.threadId !== threadId || last.size !== items.length) { + dbg("useAgentActionsQuery: cache snapshot", { + threadId, + enabled, + itemCount: items.length, + itemKeys: items.slice(0, 8).map((a) => ({ + id: a.id, + tool_name: a.tool_name, + tool_call_id: a.tool_call_id, + chat_turn_id: a.chat_turn_id, + reversible: a.reversible, + })), + }); + lastSnapshotRef.current = { threadId, size: items.length }; + } + }, [threadId, enabled, items]); + + const findByToolCallId = useCallback( + (toolCallId: string | null | undefined): AgentAction | null => { + if (!toolCallId) return null; + const found = byToolCallId.get(toolCallId) ?? 
null; + if (!found && items.length > 0) { + dbg("findByToolCallId: MISS", { + queriedToolCallId: toolCallId, + itemCount: items.length, + availableToolCallIds: Array.from(byToolCallId.keys()), + }); + } + return found; + }, + [byToolCallId, items.length] + ); + + const findByChatTurnId = useCallback( + (chatTurnId: string | null | undefined): AgentAction[] => { + if (!chatTurnId) return []; + // Per-turn aggregation is uncommon enough (only the + // "Revert turn" button uses it) that re-scanning is fine; + // indexing it would just bloat memory. + return items.filter((a) => a.chat_turn_id === chatTurnId); + }, + [items] + ); + + const findByChatTurnAndTool = useCallback( + ( + chatTurnId: string | null | undefined, + toolName: string | null | undefined + ): AgentAction[] => { + if (!chatTurnId || !toolName) return []; + return byTurnAndTool.get(`${chatTurnId}::${toolName}`) ?? []; + }, + [byTurnAndTool] + ); + + return { + ...query, + items, + findByToolCallId, + findByChatTurnId, + findByChatTurnAndTool, + }; +} diff --git a/surfsense_web/hooks/use-messages-sync.ts b/surfsense_web/hooks/use-messages-sync.ts index ddbe8a757..5ccda23a5 100644 --- a/surfsense_web/hooks/use-messages-sync.ts +++ b/surfsense_web/hooks/use-messages-sync.ts @@ -31,6 +31,14 @@ export function useMessagesSync( content: msg.content, author_id: msg.authorId ?? null, created_at: new Date(msg.createdAt).toISOString(), + // Forward the per-turn correlation id so post-stream Zero + // re-syncs preserve ``metadata.custom.chatTurnId`` on the + // converted ``ThreadMessageLike``. Without this the inline + // Revert button's ``(chat_turn_id, tool_name, position)`` + // fallback breaks the moment Zero overwrites the messages + // state after a live stream completes (see + // ``handleSyncedMessagesUpdate`` in the chat page). + turn_id: msg.turnId ?? 
null, })); onMessagesUpdateRef.current(mapped); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 26fd7b98c..54faf7e7c 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -16,6 +16,23 @@ export type ContentPart = toolName: string; args: Record<string, unknown>; result?: unknown; + /** + * Live / finalized JSON text for the tool's input arguments. + * + * - During streaming: accumulated partial JSON text from + * ``tool-input-delta`` events (may be invalid JSON + * mid-stream). assistant-ui's argsText parser tolerates + * invalid JSON gracefully (changelog 0.7.32 / 0.7.78). + * - On completion (``tool-input-available``): replaced with + * ``JSON.stringify(input, null, 2)`` so the post-stream + * card renders pretty-printed JSON instead of the + * model's possibly-fragmented formatting. + * + * Per assistant-ui ``ThreadMessageLike`` precedence + * (changelog 0.11.6 ``d318c83``), when ``argsText`` is + * supplied it wins over ``JSON.stringify(args)``. 
+ */ + argsText?: string; /** * Authoritative LangChain ``tool_call.id`` propagated by the backend * via ``langchainToolCallId`` on tool-input-start/available and @@ -282,12 +299,22 @@ export function findToolCallIdByLcId( export function updateToolCall( state: ContentPartsState, toolCallId: string, - update: { args?: Record<string, unknown>; result?: unknown; langchainToolCallId?: string } + update: { + args?: Record<string, unknown>; + argsText?: string; + result?: unknown; + langchainToolCallId?: string; + } ): void { const index = state.toolCallIndices.get(toolCallId); if (index !== undefined && state.contentParts[index]?.type === "tool-call") { const tc = state.contentParts[index] as ContentPart & { type: "tool-call" }; if (update.args) tc.args = update.args; + // ``!== undefined`` (NOT a truthy check): an explicit empty + // string CAN clear, and a finalization with + // ``JSON.stringify({}, null, 2) === "{}"`` (truthy but + // represents an empty-input call) still applies. + if (update.argsText !== undefined) tc.argsText = update.argsText; if (update.result !== undefined) tc.result = update.result; // Only backfill langchainToolCallId if not already set — the // authoritative ``on_tool_end`` value should override an earlier @@ -299,6 +326,25 @@ export function updateToolCall( } } +/** + * Append a streamed args-delta chunk to the active tool call's + * ``argsText``. No-ops when no card has been registered yet for the + * given ``toolCallId`` (the matching ``tool-input-start`` either lost + * the wire race or this id never had a card — either way the deltas + * have nowhere safe to land). + */ +export function appendToolInputDelta( + state: ContentPartsState, + toolCallId: string, + delta: string +): void { + const idx = state.toolCallIndices.get(toolCallId); + if (idx === undefined) return; + const tc = state.contentParts[idx]; + if (tc?.type !== "tool-call") return; + tc.argsText = (tc.argsText ?? 
"") + delta; +} + function _hasInterruptResult(part: ContentPart): boolean { if (part.type !== "tool-call") return false; const r = (part as { result?: unknown }).result; @@ -371,6 +417,18 @@ export type SSEEvent = /** Authoritative LangChain ``tool_call.id``. Optional. */ langchainToolCallId?: string; } + | { + /** + * Live tool-call argument delta. Concatenated into + * ``argsText`` on the matching ``tool-call`` content part + * by ``appendToolInputDelta``. parity_v2 only — the legacy + * code path emits ``tool-input-available`` without prior + * deltas. + */ + type: "tool-input-delta"; + toolCallId: string; + inputTextDelta: string; + } | { type: "tool-input-available"; toolCallId: string; diff --git a/surfsense_web/zero/schema/chat.ts b/surfsense_web/zero/schema/chat.ts index 0293059fd..fb3d7651e 100644 --- a/surfsense_web/zero/schema/chat.ts +++ b/surfsense_web/zero/schema/chat.ts @@ -8,6 +8,13 @@ export const newChatMessageTable = table("new_chat_messages") threadId: number().from("thread_id"), authorId: string().optional().from("author_id"), createdAt: number().from("created_at"), + // Per-turn correlation id sourced from ``configurable.turn_id`` + // at streaming time. Required by the inline Revert button's + // (chat_turn_id, tool_name, position) fallback in tool-fallback.tsx + // — without it the live-collab Zero sync would clobber the + // metadata we set during streaming and the button would vanish + // the moment Zero re-syncs after the stream finishes. 
+ turnId: string().optional().from("turn_id"), }) .primaryKey("id"); From 1ce122cc99cab31fde5323692b1c835f3a763cd5 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 16:05:58 +0530 Subject: [PATCH 20/68] feat(database): change alembic number and add idempotency --- ...34_add_thread_auto_model_pinning_fields.py | 63 ---------------- ...38_add_thread_auto_model_pinning_fields.py | 72 +++++++++++++++++++ 2 files changed, 72 insertions(+), 63 deletions(-) delete mode 100644 surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py create mode 100644 surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py diff --git a/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py deleted file mode 100644 index ab1643b02..000000000 --- a/surfsense_backend/alembic/versions/134_add_thread_auto_model_pinning_fields.py +++ /dev/null @@ -1,63 +0,0 @@ -"""134_add_thread_auto_model_pinning_fields - -Revision ID: 134 -Revises: 133 -Create Date: 2026-04-29 - -Add thread-level fields to persist Auto (Fastest) model pinning metadata: -- pinned_llm_config_id: concrete resolved config id used for this thread -- pinned_auto_mode: auto policy identifier (currently "auto_fastest") -- pinned_at: timestamp when the pin was created/refreshed -""" - -from __future__ import annotations - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import op - -revision: str = "134" -down_revision: str | None = "133" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - op.add_column( - "new_chat_threads", - sa.Column("pinned_llm_config_id", sa.Integer(), nullable=True), - ) - op.add_column( - "new_chat_threads", - sa.Column("pinned_auto_mode", sa.String(length=32), nullable=True), - ) - 
op.add_column( - "new_chat_threads", - sa.Column("pinned_at", sa.TIMESTAMP(timezone=True), nullable=True), - ) - - op.create_index( - "ix_new_chat_threads_pinned_llm_config_id", - "new_chat_threads", - ["pinned_llm_config_id"], - unique=False, - ) - op.create_index( - "ix_new_chat_threads_pinned_auto_mode", - "new_chat_threads", - ["pinned_auto_mode"], - unique=False, - ) - - -def downgrade() -> None: - op.drop_index("ix_new_chat_threads_pinned_auto_mode", table_name="new_chat_threads") - op.drop_index( - "ix_new_chat_threads_pinned_llm_config_id", table_name="new_chat_threads" - ) - - op.drop_column("new_chat_threads", "pinned_at") - op.drop_column("new_chat_threads", "pinned_auto_mode") - op.drop_column("new_chat_threads", "pinned_llm_config_id") diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py new file mode 100644 index 000000000..6e4b77cc7 --- /dev/null +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -0,0 +1,72 @@ +"""138_add_thread_auto_model_pinning_fields + +Revision ID: 138 +Revises: 137 +Create Date: 2026-04-30 + +Add thread-level fields to persist Auto (Fastest) model pinning metadata: +- pinned_llm_config_id: concrete resolved config id used for this thread +- pinned_auto_mode: auto policy identifier (currently "auto_fastest") +- pinned_at: timestamp when the pin was created/refreshed + +Idempotent: this migration was originally numbered 134 on the +``feat/split-auto-free-premium`` branch and was renumbered to 138 during +the merge with ``upstream/dev`` (which claimed 134-137). Some databases +already have these columns/indexes from when the original 134 ran, so we +use ``IF NOT EXISTS`` to make re-application a no-op for those DBs while +still creating the schema on fresh databases. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op + +revision: str = "138" +down_revision: str | None = "137" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER" + ) + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)" + ) + op.execute( + "ALTER TABLE new_chat_threads " + "ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE" + ) + + op.execute( + "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id " + "ON new_chat_threads (pinned_llm_config_id)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode " + "ON new_chat_threads (pinned_auto_mode)" + ) + + +def downgrade() -> None: + op.execute( + "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode" + ) + op.execute( + "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id" + ) + + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at" + ) + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode" + ) + op.execute( + "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id" + ) From 2a01711bc9f966f25fe652fb2063357cb74ec99b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 16:20:14 +0530 Subject: [PATCH 21/68] feat(chat): expand error handling for chat operations by introducing a passthrough code set, improving response management and user feedback --- ...138_add_thread_auto_model_pinning_fields.py | 7 ------- .../unit/test_stream_new_chat_contract.py | 16 ++++++++++++---- .../new-chat/[[...chat_id]]/page.tsx | 18 ++++++++++++++---- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git 
a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index 6e4b77cc7..1ea549975 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -8,13 +8,6 @@ Add thread-level fields to persist Auto (Fastest) model pinning metadata: - pinned_llm_config_id: concrete resolved config id used for this thread - pinned_auto_mode: auto policy identifier (currently "auto_fastest") - pinned_at: timestamp when the pin was created/refreshed - -Idempotent: this migration was originally numbered 134 on the -``feat/split-auto-free-premium`` branch and was renumbered to 138 during -the merge with ``upstream/dev`` (which claimed 134-137). Some databases -already have these columns/indexes from when the original 134 ran, so we -use ``IF NOT EXISTS`` to make re-application a no-op for those DBs while -still creating the schema on fresh databases. """ from __future__ import annotations diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 9f4280063..86ea7edd1 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -231,10 +231,18 @@ def test_network_send_failures_use_unified_retry_toast_message(): assert 'userMessage: "Message not sent. Please retry."' in classifier_source assert 'userMessage: "Connection issue. 
Please try again."' in classifier_source assert "tagPreAcceptSendFailure(error)" in page_source - assert 'existingCode === "THREAD_BUSY"' in page_source - assert 'existingCode === "AUTH_EXPIRED"' in page_source - assert 'existingCode === "UNAUTHORIZED"' in page_source - assert 'existingCode === "RATE_LIMITED"' in page_source + assert "const passthroughCodes = new Set([" in page_source + assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source + assert '"THREAD_BUSY"' in page_source + assert '"AUTH_EXPIRED"' in page_source + assert '"UNAUTHORIZED"' in page_source + assert '"RATE_LIMITED"' in page_source + assert '"NETWORK_ERROR"' in page_source + assert '"STREAM_PARSE_ERROR"' in page_source + assert '"TOOL_EXECUTION_ERROR"' in page_source + assert '"PERSIST_MESSAGE_FAILED"' in page_source + assert '"SERVER_ERROR"' in page_source + assert "passthroughCodes.has(existingCode)" in page_source assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source assert 'errorCode: "NETWORK_ERROR"' not in page_source assert "Failed to start chat. Please try again." not in page_source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index f21a0a30b..239afaf73 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -227,11 +227,21 @@ function tagPreAcceptSendFailure(error: unknown): unknown { if (error instanceof Error) { const withCode = error as Error & { errorCode?: string; code?: string }; const existingCode = withCode.errorCode ?? 
withCode.code; + const passthroughCodes = new Set([ + "PREMIUM_QUOTA_EXHAUSTED", + "THREAD_BUSY", + "AUTH_EXPIRED", + "UNAUTHORIZED", + "RATE_LIMITED", + "NETWORK_ERROR", + "STREAM_PARSE_ERROR", + "TOOL_EXECUTION_ERROR", + "PERSIST_MESSAGE_FAILED", + "SERVER_ERROR", + ]); if ( - existingCode === "THREAD_BUSY" || - existingCode === "AUTH_EXPIRED" || - existingCode === "UNAUTHORIZED" || - existingCode === "RATE_LIMITED" + existingCode && + passthroughCodes.has(existingCode) ) { return Object.assign(error, { errorCode: existingCode }); } From 1d6d7e3eb10f814aafdff5430af4033af04bb176 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 16:33:13 +0530 Subject: [PATCH 22/68] refactor(chat): remove unused agent action handlers from NewChatPage component to streamline code and improve maintainability --- .../[search_space_id]/new-chat/[[...chat_id]]/page.tsx | 7 ------- 1 file changed, 7 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index df1290971..fe625f169 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -1512,8 +1512,6 @@ export default function NewChatPage() { tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, - upsertAgentAction, - updateAgentActionReversible, handleStreamTerminalError, handleChatFailure, persistAssistantTurn, @@ -1894,8 +1892,6 @@ export default function NewChatPage() { messages, searchSpaceId, tokenUsageStore, - upsertAgentAction, - updateAgentActionReversible, handleStreamTerminalError, persistAssistantTurn, ] @@ -2433,9 +2429,6 @@ export default function NewChatPage() { messageDocumentsMap, setMessageDocumentsMap, tokenUsageStore, - upsertAgentAction, - updateAgentActionReversible, - 
markAgentActionReverted, handleStreamTerminalError, persistAssistantTurn, persistUserTurn, From 6465ea181a25a8c6d003572ea4707aa9e1dcf3cc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:09:18 +0530 Subject: [PATCH 23/68] refactor(chat): streamline NewChatPage component by removing unused functions and integrating new stream handling utilities for improved performance --- .../new-chat/[[...chat_id]]/page.tsx | 625 +++++++----------- 1 file changed, 255 insertions(+), 370 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index fe625f169..d1dd14e06 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -252,6 +252,168 @@ function tagPreAcceptSendFailure(error: unknown): unknown { }); } +type SharedStreamEventContext = { + contentPartsState: ContentPartsState; + toolsWithUI: ToolUIGate; + currentThinkingSteps: Map<string, ThinkingStepData>; + scheduleFlush: () => void; + forceFlush: () => void; + onTokenUsage?: (data: TokenUsageData) => void; + onToolOutputAvailable?: ( + event: Extract<SSEEvent, { type: "tool-output-available" }>, + context: { + contentPartsState: ContentPartsState; + toolCallIndices: Map<string, number>; + } + ) => void; +}; + +function createStreamFlushHelpers(flushMessages: () => void): { + batcher: FrameBatchedUpdater; + scheduleFlush: () => void; + forceFlush: () => void; +} { + const batcher = new FrameBatchedUpdater(); + const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text streamed). 
+ // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so + // terminal events render promptly without the throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; + return { batcher, scheduleFlush, forceFlush }; +} + +function hasPersistableContent(contentParts: ContentPartsState["contentParts"], toolsWithUI: ToolUIGate) { + return contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); +} + +function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { + const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; + const { contentParts, toolCallIndices } = contentPartsState; + + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + return true; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + return true; + + case "finish-step": + return true; + + case "tool-input-start": + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); + forceFlush(); + return true; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens of times per call, + // so use throttled scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + return true; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? 
{}, null, 2); + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + argsText: finalArgsText, + langchainToolCallId: parsed.langchainToolCallId, + }); + } else { + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {}, + false, + parsed.langchainToolCallId + ); + // addToolCall doesn't accept argsText today; backfill via + // updateToolCall so the new card renders pretty-printed JSON. + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); + } + forceFlush(); + return true; + } + + case "tool-output-available": + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); + markInterruptsCompleted(contentParts); + context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); + forceFlush(); + return true; + + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } + } + return true; + } + + case "data-token-usage": + context.onTokenUsage?.(parsed.data as TokenUsageData); + return true; + + case "error": + throw toStreamTerminalError(parsed); + + default: + return false; + } +} + +async function consumeSseEvents( + response: Response, + onEvent: (event: SSEEvent) => void | Promise<void> +): Promise<void> { + for await (const parsed of readSSEStream(response)) { + await onEvent(parsed); + } +} + /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -456,7 +618,7 @@ export default function NewChatPage() { threadId: number | null; assistantMsgId: string; content: unknown; - tokenUsage?: Record<string, unknown>; + tokenUsage?: TokenUsageData; turnId?: string | null; 
logContext: string; onRemapped?: (newMsgId: string) => void; @@ -1055,8 +1217,6 @@ export default function NewChatPage() { // Prepare assistant message const assistantMsgId = `msg-assistant-${Date.now()}`; const currentThinkingSteps = new Map<string, ThinkingStepData>(); - const batcher = new FrameBatchedUpdater(); - const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, @@ -1065,11 +1225,12 @@ export default function NewChatPage() { }; const { contentParts, toolCallIndices } = contentPartsState; let wasInterrupted = false; - let tokenUsageData: Record<string, unknown> | null = null; + let tokenUsageData: TokenUsageData | null = null; let newAccepted = false; let userPersisted = false; // Captured from ``data-turn-info`` at stream start. let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1152,123 +1313,37 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - // Force-flush helper: ``batcher.flush()`` is a no-op when - // ``dirty=false`` (e.g. a tool starts before any text - // streamed). ``scheduleFlush(); batcher.flush()`` sets - // the dirty bit FIRST so terminal events render - // promptly without the 50ms throttle delay. 
- const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - // High-frequency event: deltas can fire dozens - // of times per call, so use throttled - // scheduleFlush (NOT forceFlush) to coalesce. - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - // addToolCall doesn't accept argsText today; - // backfill via updateToolCall so the new card - // renders pretty-printed JSON. 
- updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": { - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } } } - } - forceFlush(); - break; - } - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-thread-title-update": { const titleData = parsed.data as { threadId: number; title: string }; if (titleData?.title && titleData?.threadId === currentThreadId) { @@ -1374,16 +1449,8 @@ export default function NewChatPage() { } 
break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -1425,7 +1492,7 @@ export default function NewChatPage() { trackChatResponseReceived(searchSpaceId, currentThreadId); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "new", @@ -1448,13 +1515,7 @@ export default function NewChatPage() { } } - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (hasContent && currentThreadId) { const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -1543,7 +1604,6 @@ export default function NewChatPage() { abortControllerRef.current = controller; const currentThinkingSteps = new Map<string, ThinkingStepData>(); - const batcher = new FrameBatchedUpdater(); const contentPartsState: ContentPartsState = { contentParts: [], @@ -1552,10 +1612,11 @@ export default function NewChatPage() { toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; - let tokenUsageData: Record<string, unknown> | null = null; + let tokenUsageData: TokenUsageData | null = null; let resumeAccepted = false; // Captured from ``data-turn-info`` at stream start. 
let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1664,102 +1725,26 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + }) + ) { + return; + } switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? 
{}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - forceFlush(); - break; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - case "data-interrupt-request": { const interruptData = parsed.data as Record<string, unknown>; const actionRequests = (interruptData.action_requests ?? 
[]) as Array<{ @@ -1830,16 +1815,8 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -1855,7 +1832,7 @@ export default function NewChatPage() { }); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "resume", @@ -1864,13 +1841,7 @@ export default function NewChatPage() { accepted: resumeAccepted, onAbort: async () => { if (!resumeAccepted) return; - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (!hasContent) return; const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -1891,6 +1862,7 @@ export default function NewChatPage() { pendingInterrupt, messages, searchSpaceId, + queryClient, tokenUsageStore, handleStreamTerminalError, persistAssistantTurn, @@ -2045,15 +2017,15 @@ export default function NewChatPage() { currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; - const { contentParts, toolCallIndices } = contentPartsState; - const batcher = new FrameBatchedUpdater(); - let tokenUsageData: Record<string, unknown> | null = null; + const { contentParts } = contentPartsState; + let tokenUsageData: TokenUsageData | null = null; let regenerateAccepted = false; let userPersisted = false; // Captured from ``data-turn-info`` at stream start; stamped // onto persisted messages so future edits can locate the // right LangGraph checkpoint. 
let streamedChatTurnId: string | null = null; + let streamBatcher: FrameBatchedUpdater | null = null; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -2155,111 +2127,37 @@ export default function NewChatPage() { ) ); }; - const scheduleFlush = () => batcher.schedule(flushMessages); - const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; + const { batcher, scheduleFlush, forceFlush } = createStreamFlushHelpers(flushMessages); + streamBatcher = batcher; - for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - break; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - break; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - break; - - case "finish-step": - break; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - break; - - case "tool-input-delta": - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - break; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? 
{}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - break; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + await consumeSseEvents(response, async (parsed) => { + if ( + processSharedStreamEvent(parsed, { + contentPartsState, + toolsWithUI, + currentThinkingSteps, + scheduleFlush, + forceFlush, + onTokenUsage: (data) => { + tokenUsageData = data; + tokenUsageStore.set(assistantMsgId, data); + }, + onToolOutputAvailable: (event, sharedCtx) => { + if (event.output?.status === "pending" && event.output?.podcast_id) { + const idx = sharedCtx.toolCallIndices.get(event.toolCallId); + if (idx !== undefined) { + const part = sharedCtx.contentPartsState.contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(event.output.podcast_id)); + } } } - } - forceFlush(); - break; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = 
updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - break; - } - + }, + }) + ) { + return; + } + switch (parsed.type) { case "data-action-log": { if (threadId !== null) { applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data); @@ -2326,16 +2224,8 @@ export default function NewChatPage() { } break; } - - case "data-token-usage": - tokenUsageData = parsed.data; - tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); - break; - - case "error": - throw toStreamTerminalError(parsed); } - } + }); batcher.flush(); @@ -2364,7 +2254,7 @@ export default function NewChatPage() { trackChatResponseReceived(searchSpaceId, threadId); } } catch (error) { - batcher.dispose(); + streamBatcher?.dispose(); await handleStreamTerminalError({ error, flow: "regenerate", @@ -2384,13 +2274,7 @@ export default function NewChatPage() { }); userPersisted = Boolean(persistedUserMsgId); } - const hasContent = contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && - (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); + const hasContent = hasPersistableContent(contentParts, toolsWithUI); if (!hasContent) return; const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); await persistAssistantTurn({ @@ -2428,6 +2312,7 @@ export default function NewChatPage() { disabledTools, messageDocumentsMap, setMessageDocumentsMap, + queryClient, tokenUsageStore, handleStreamTerminalError, persistAssistantTurn, From 86f6b285ce9cedbf529a7d8325f4457f602f997a Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:09:34 +0530 Subject: [PATCH 24/68] refactor(chat): introduce new stream handling utilities and restructure event processing for improved performance and maintainability --- 
.../new-chat/[[...chat_id]]/page.tsx | 205 +----------------- surfsense_web/lib/chat/stream-flush.ts | 19 ++ surfsense_web/lib/chat/stream-pipeline.ts | 191 ++++++++++++++++ 3 files changed, 217 insertions(+), 198 deletions(-) create mode 100644 surfsense_web/lib/chat/stream-flush.ts create mode 100644 surfsense_web/lib/chat/stream-pipeline.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index d1dd14e06..82a12b6b1 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -71,23 +71,21 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { - addStepSeparator, addToolCall, - appendReasoning, - appendText, - appendToolInputDelta, buildContentForPersistence, buildContentForUI, type ContentPartsState, - endReasoning, - FrameBatchedUpdater, - readSSEStream, - type SSEEvent, + type FrameBatchedUpdater, type ThinkingStepData, type ToolUIGate, - updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; +import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; +import { + consumeSseEvents, + hasPersistableContent, + processSharedStreamEvent, +} from "@/lib/chat/stream-pipeline"; import { appendMessage, createThread, @@ -134,33 +132,6 @@ const MobileReportPanel = dynamic( { ssr: false } ); -/** - * After a tool produces output, mark any previously-decided interrupt tool - * calls as completed so the ApprovalCard can transition from shimmer to done. 
- */ -function markInterruptsCompleted(contentParts: Array<{ type: string; result?: unknown }>): void { - for (const part of contentParts) { - if ( - part.type === "tool-call" && - typeof part.result === "object" && - part.result !== null && - (part.result as Record<string, unknown>).__interrupt__ === true && - (part.result as Record<string, unknown>).__decided__ && - !(part.result as Record<string, unknown>).__completed__ - ) { - part.result = { ...(part.result as Record<string, unknown>), __completed__: true }; - } - } -} - -function toStreamTerminalError( - event: Extract<SSEEvent, { type: "error" }> -): Error & { errorCode?: string } { - return Object.assign(new Error(event.errorText || "Server error"), { - errorCode: event.errorCode, - }); -} - async function toHttpResponseError(response: Response): Promise<Error & { errorCode?: string }> { const statusDefaultCode = response.status === 409 @@ -252,168 +223,6 @@ function tagPreAcceptSendFailure(error: unknown): unknown { }); } -type SharedStreamEventContext = { - contentPartsState: ContentPartsState; - toolsWithUI: ToolUIGate; - currentThinkingSteps: Map<string, ThinkingStepData>; - scheduleFlush: () => void; - forceFlush: () => void; - onTokenUsage?: (data: TokenUsageData) => void; - onToolOutputAvailable?: ( - event: Extract<SSEEvent, { type: "tool-output-available" }>, - context: { - contentPartsState: ContentPartsState; - toolCallIndices: Map<string, number>; - } - ) => void; -}; - -function createStreamFlushHelpers(flushMessages: () => void): { - batcher: FrameBatchedUpdater; - scheduleFlush: () => void; - forceFlush: () => void; -} { - const batcher = new FrameBatchedUpdater(); - const scheduleFlush = () => batcher.schedule(flushMessages); - // Force-flush helper: ``batcher.flush()`` is a no-op when - // ``dirty=false`` (e.g. a tool starts before any text streamed). - // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so - // terminal events render promptly without the throttle delay. 
- const forceFlush = () => { - scheduleFlush(); - batcher.flush(); - }; - return { batcher, scheduleFlush, forceFlush }; -} - -function hasPersistableContent(contentParts: ContentPartsState["contentParts"], toolsWithUI: ToolUIGate) { - return contentParts.some( - (part) => - (part.type === "text" && part.text.length > 0) || - (part.type === "reasoning" && part.text.length > 0) || - (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) - ); -} - -function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { - const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; - const { contentParts, toolCallIndices } = contentPartsState; - - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - return true; - - case "reasoning-delta": - appendReasoning(contentPartsState, parsed.delta); - scheduleFlush(); - return true; - - case "reasoning-end": - endReasoning(contentPartsState); - scheduleFlush(); - return true; - - case "start-step": - addStepSeparator(contentPartsState); - scheduleFlush(); - return true; - - case "finish-step": - return true; - - case "tool-input-start": - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - {}, - false, - parsed.langchainToolCallId - ); - forceFlush(); - return true; - - case "tool-input-delta": - // High-frequency event: deltas can fire dozens of times per call, - // so use throttled scheduleFlush (NOT forceFlush) to coalesce. - appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); - scheduleFlush(); - return true; - - case "tool-input-available": { - const finalArgsText = JSON.stringify(parsed.input ?? 
{}, null, 2); - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, - argsText: finalArgsText, - langchainToolCallId: parsed.langchainToolCallId, - }); - } else { - addToolCall( - contentPartsState, - toolsWithUI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {}, - false, - parsed.langchainToolCallId - ); - // addToolCall doesn't accept argsText today; backfill via - // updateToolCall so the new card renders pretty-printed JSON. - updateToolCall(contentPartsState, parsed.toolCallId, { - argsText: finalArgsText, - }); - } - forceFlush(); - return true; - } - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - langchainToolCallId: parsed.langchainToolCallId, - }); - markInterruptsCompleted(contentParts); - context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); - forceFlush(); - return true; - - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); - } - } - return true; - } - - case "data-token-usage": - context.onTokenUsage?.(parsed.data as TokenUsageData); - return true; - - case "error": - throw toStreamTerminalError(parsed); - - default: - return false; - } -} - -async function consumeSseEvents( - response: Response, - onEvent: (event: SSEEvent) => void | Promise<void> -): Promise<void> { - for await (const parsed of readSSEStream(response)) { - await onEvent(parsed); - } -} - /** * Zod schema for mentioned document info (for type-safe parsing) */ diff --git a/surfsense_web/lib/chat/stream-flush.ts b/surfsense_web/lib/chat/stream-flush.ts new file mode 100644 index 000000000..6d13c9237 --- /dev/null +++ b/surfsense_web/lib/chat/stream-flush.ts @@ -0,0 +1,19 @@ +import { 
FrameBatchedUpdater } from "@/lib/chat/streaming-state"; + +export function createStreamFlushHelpers(flushMessages: () => void): { + batcher: FrameBatchedUpdater; + scheduleFlush: () => void; + forceFlush: () => void; +} { + const batcher = new FrameBatchedUpdater(); + const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text streamed). + // ``scheduleFlush(); batcher.flush()`` sets the dirty bit first so + // terminal events render promptly without the throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; + return { batcher, scheduleFlush, forceFlush }; +} diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts new file mode 100644 index 000000000..8957bdea3 --- /dev/null +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -0,0 +1,191 @@ +import { + addStepSeparator, + addToolCall, + appendReasoning, + appendText, + appendToolInputDelta, + type ContentPartsState, + endReasoning, + readSSEStream, + type SSEEvent, + type ThinkingStepData, + type ToolUIGate, + updateThinkingSteps, + updateToolCall, +} from "@/lib/chat/streaming-state"; + +export type SharedStreamEventContext = { + contentPartsState: ContentPartsState; + toolsWithUI: ToolUIGate; + currentThinkingSteps: Map<string, ThinkingStepData>; + scheduleFlush: () => void; + forceFlush: () => void; + onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void; + onToolOutputAvailable?: ( + event: Extract<SSEEvent, { type: "tool-output-available" }>, + context: { + contentPartsState: ContentPartsState; + toolCallIndices: Map<string, number>; + } + ) => void; +}; + +/** + * After a tool produces output, mark any previously-decided interrupt tool + * calls as completed so the ApprovalCard can transition from shimmer to done. 
+ */ +export function markInterruptsCompleted( + contentParts: Array<{ type: string; result?: unknown }> +): void { + for (const part of contentParts) { + if ( + part.type === "tool-call" && + typeof part.result === "object" && + part.result !== null && + (part.result as Record<string, unknown>).__interrupt__ === true && + (part.result as Record<string, unknown>).__decided__ && + !(part.result as Record<string, unknown>).__completed__ + ) { + part.result = { ...(part.result as Record<string, unknown>), __completed__: true }; + } + } +} + +export function hasPersistableContent( + contentParts: ContentPartsState["contentParts"], + toolsWithUI: ToolUIGate +) { + return contentParts.some( + (part) => + (part.type === "text" && part.text.length > 0) || + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) + ); +} + +function toStreamTerminalError( + event: Extract<SSEEvent, { type: "error" }> +): Error & { errorCode?: string } { + return Object.assign(new Error(event.errorText || "Server error"), { + errorCode: event.errorCode, + }); +} + +export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { + const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; + const { contentParts, toolCallIndices } = contentPartsState; + + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + return true; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + return true; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + return true; + + case "finish-step": + return true; + + case "tool-input-start": + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + 
parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); + forceFlush(); + return true; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens of times per call, + // so use throttled scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + return true; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + argsText: finalArgsText, + langchainToolCallId: parsed.langchainToolCallId, + }); + } else { + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {}, + false, + parsed.langchainToolCallId + ); + // addToolCall doesn't accept argsText today; backfill via + // updateToolCall so the new card renders pretty-printed JSON. + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); + } + forceFlush(); + return true; + } + + case "tool-output-available": + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); + markInterruptsCompleted(contentParts); + context.onToolOutputAvailable?.(parsed, { contentPartsState, toolCallIndices }); + forceFlush(); + return true; + + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } + } + return true; + } + + case "data-token-usage": + context.onTokenUsage?.(parsed.data); + return true; + + case "error": + throw toStreamTerminalError(parsed); + + default: + return false; + } +} + +export async function consumeSseEvents( + response: Response, + 
onEvent: (event: SSEEvent) => void | Promise<void> +): Promise<void> { + for await (const parsed of readSSEStream(response)) { + await onEvent(parsed); + } +} From d65a3fdf76364b0705eaff0953f4d7283ecafde2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:22:34 +0530 Subject: [PATCH 25/68] refactor(chat): implement new error handling utilities and streamline interrupt request processing in NewChatPage for improved performance and maintainability --- .../new-chat/[[...chat_id]]/page.tsx | 238 +++--------------- surfsense_web/lib/chat/chat-request-errors.ts | 89 +++++++ surfsense_web/lib/chat/stream-side-effects.ts | 127 ++++++++++ 3 files changed, 246 insertions(+), 208 deletions(-) create mode 100644 surfsense_web/lib/chat/chat-request-errors.ts create mode 100644 surfsense_web/lib/chat/stream-side-effects.ts diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 82a12b6b1..02c2914be 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -64,6 +64,10 @@ import { classifyChatError, type ChatFlow, } from "@/lib/chat/chat-error-classifier"; +import { + tagPreAcceptSendFailure, + toHttpResponseError, +} from "@/lib/chat/chat-request-errors"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, @@ -71,14 +75,12 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { - addToolCall, buildContentForPersistence, buildContentForUI, type ContentPartsState, type FrameBatchedUpdater, type ThinkingStepData, type ToolUIGate, - updateToolCall, } from "@/lib/chat/streaming-state"; import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; import { @@ -86,6 +88,14 @@ import { 
hasPersistableContent, processSharedStreamEvent, } from "@/lib/chat/stream-pipeline"; +import { + applyTurnIdToAssistantMessageList, + applyInterruptRequestToContentParts, + mergeChatTurnIdIntoMessage, + mergeEditedInterruptAction, + markInterruptDecisionOnContentParts, + readStreamedChatTurnId, +} from "@/lib/chat/stream-side-effects"; import { appendMessage, createThread, @@ -132,97 +142,6 @@ const MobileReportPanel = dynamic( { ssr: false } ); -async function toHttpResponseError(response: Response): Promise<Error & { errorCode?: string }> { - const statusDefaultCode = - response.status === 409 - ? "THREAD_BUSY" - : response.status === 429 - ? "RATE_LIMITED" - : response.status === 401 || response.status === 403 - ? "AUTH_EXPIRED" - : "SERVER_ERROR"; - - let rawBody = ""; - try { - rawBody = await response.text(); - } catch { - // noop - } - - let parsedBody: Record<string, unknown> | null = null; - if (rawBody) { - try { - const parsed = JSON.parse(rawBody); - if (typeof parsed === "object" && parsed !== null) { - parsedBody = parsed as Record<string, unknown>; - } - } catch { - // noop - } - } - - const detail = parsedBody?.detail; - const detailObject = - typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null; - const detailMessage = typeof detail === "string" ? detail : undefined; - const topLevelMessage = - typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; - const detailNestedMessage = - typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; - - const topLevelCode = - typeof parsedBody?.errorCode === "string" - ? parsedBody.errorCode - : typeof parsedBody?.error_code === "string" - ? parsedBody.error_code - : undefined; - const detailCode = - typeof detailObject?.errorCode === "string" - ? detailObject.errorCode - : typeof detailObject?.error_code === "string" - ? detailObject.error_code - : undefined; - - const errorCode = detailCode ?? topLevelCode ?? 
statusDefaultCode; - const message = - detailNestedMessage ?? - detailMessage ?? - topLevelMessage ?? - `Backend error: ${response.status}`; - - return Object.assign(new Error(message), { errorCode }); -} - -function tagPreAcceptSendFailure(error: unknown): unknown { - if (error instanceof Error) { - const withCode = error as Error & { errorCode?: string; code?: string }; - const existingCode = withCode.errorCode ?? withCode.code; - const passthroughCodes = new Set([ - "PREMIUM_QUOTA_EXHAUSTED", - "THREAD_BUSY", - "AUTH_EXPIRED", - "UNAUTHORIZED", - "RATE_LIMITED", - "NETWORK_ERROR", - "STREAM_PARSE_ERROR", - "TOOL_EXECUTION_ERROR", - "PERSIST_MESSAGE_FAILED", - "SERVER_ERROR", - ]); - if ( - existingCode && - passthroughCodes.has(existingCode) - ) { - return Object.assign(error, { errorCode: existingCode }); - } - return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); - } - - return Object.assign(new Error("Failed to send message before stream acceptance"), { - errorCode: "SEND_FAILED_PRE_ACCEPT", - }); -} - /** * Zod schema for mentioned document info (for type-safe parsing) */ @@ -264,29 +183,6 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { */ const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; -/** - * When a streamed message is persisted, the backend returns the durable - * ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it - * into the assistant-ui message metadata so the per-turn "Revert turn" - * button can scope to this turn's actions even after a full chat reload. - */ -function mergeChatTurnIdIntoMessage( - msg: ThreadMessageLike, - turnId: string | null | undefined -): ThreadMessageLike { - if (!turnId) return msg; - const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> }; - const existingCustom = existingMeta.custom ?? 
{}; - if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; - return { - ...msg, - metadata: { - ...existingMeta, - custom: { ...existingCustom, chatTurnId: turnId }, - }, - }; -} - export default function NewChatPage() { const params = useParams(); const queryClient = useQueryClient(); @@ -1032,7 +928,7 @@ export default function NewChatPage() { currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; - const { contentParts, toolCallIndices } = contentPartsState; + const { contentParts } = contentPartsState; let wasInterrupted = false; let tokenUsageData: TokenUsageData | null = null; let newAccepted = false; @@ -1194,27 +1090,7 @@ export default function NewChatPage() { case "data-interrupt-request": { wasInterrupted = true; const interruptData = parsed.data as Record<string, unknown>; - const actionRequests = (interruptData.action_requests ?? []) as Array<{ - name: string; - args: Record<string, unknown>; - }>; - for (const action of actionRequests) { - const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { - const part = contentParts[idx]; - return part?.type === "tool-call" && part.toolName === action.name; - }); - if (existingIdx) { - updateToolCall(contentPartsState, existingIdx[0], { - result: { __interrupt__: true, ...interruptData }, - }); - } else { - const tcId = `interrupt-${action.name}`; - addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); - updateToolCall(contentPartsState, tcId, { - result: { __interrupt__: true, ...interruptData }, - }); - } - } + applyInterruptRequestToContentParts(contentPartsState, toolsWithUI, interruptData); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -1248,12 +1124,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) 
{ setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; @@ -1469,37 +1344,12 @@ export default function NewChatPage() { } // Merge edited args if present to fix race condition - if (decisions.length > 0 && decisions[0].type === "edit" && decisions[0].edited_action) { - const editedAction = decisions[0].edited_action; - for (const part of contentParts) { - if (part.type === "tool-call" && part.toolName === editedAction.name) { - const mergedArgs = { ...part.args, ...editedAction.args }; - part.args = mergedArgs; - // Sync argsText so the rendered card shows the - // edited inputs — assistant-ui prefers caller- - // supplied argsText over JSON.stringify(args). - part.argsText = JSON.stringify(mergedArgs, null, 2); - break; - } - } + if (decisions.length > 0 && decisions[0].type === "edit") { + mergeEditedInterruptAction(contentParts, decisions[0].edited_action); } const decisionType = decisions[0]?.type as "approve" | "reject" | undefined; - if (decisionType) { - for (const part of contentParts) { - if ( - part.type === "tool-call" && - typeof part.result === "object" && - part.result !== null && - "__interrupt__" in (part.result as Record<string, unknown>) - ) { - part.result = { - ...(part.result as Record<string, unknown>), - __decided__: decisionType, - }; - } - } - } + markInterruptDecisionOnContentParts(contentParts, decisionType); try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1556,33 +1406,7 @@ export default function NewChatPage() { switch (parsed.type) { case "data-interrupt-request": { const interruptData = parsed.data as Record<string, unknown>; - const actionRequests = (interruptData.action_requests ?? 
[]) as Array<{ - name: string; - args: Record<string, unknown>; - }>; - for (const action of actionRequests) { - const existingIdx = Array.from(toolCallIndices.entries()).find(([, idx]) => { - const part = contentParts[idx]; - return part?.type === "tool-call" && part.toolName === action.name; - }); - if (existingIdx) { - updateToolCall(contentPartsState, existingIdx[0], { - result: { - __interrupt__: true, - ...interruptData, - }, - }); - } else { - const tcId = `interrupt-${action.name}`; - addToolCall(contentPartsState, toolsWithUI, tcId, action.name, action.args, true); - updateToolCall(contentPartsState, tcId, { - result: { - __interrupt__: true, - ...interruptData, - }, - }); - } - } + applyInterruptRequestToContentParts(contentPartsState, toolsWithUI, interruptData); setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -1614,12 +1438,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; @@ -1987,12 +1810,11 @@ export default function NewChatPage() { } case "data-turn-info": { - streamedChatTurnId = parsed.data.chat_turn_id || null; - if (streamedChatTurnId) { + const turnId = readStreamedChatTurnId(parsed.data); + streamedChatTurnId = turnId; + if (turnId) { setMessages((prev) => - prev.map((m) => - m.id === assistantMsgId ? 
mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m - ) + applyTurnIdToAssistantMessageList(prev, assistantMsgId, turnId) ); } break; diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts new file mode 100644 index 000000000..3026e8203 --- /dev/null +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -0,0 +1,89 @@ +export async function toHttpResponseError( + response: Response +): Promise<Error & { errorCode?: string }> { + const statusDefaultCode = + response.status === 409 + ? "THREAD_BUSY" + : response.status === 429 + ? "RATE_LIMITED" + : response.status === 401 || response.status === 403 + ? "AUTH_EXPIRED" + : "SERVER_ERROR"; + + let rawBody = ""; + try { + rawBody = await response.text(); + } catch { + // noop + } + + let parsedBody: Record<string, unknown> | null = null; + if (rawBody) { + try { + const parsed = JSON.parse(rawBody); + if (typeof parsed === "object" && parsed !== null) { + parsedBody = parsed as Record<string, unknown>; + } + } catch { + // noop + } + } + + const detail = parsedBody?.detail; + const detailObject = + typeof detail === "object" && detail !== null ? (detail as Record<string, unknown>) : null; + const detailMessage = typeof detail === "string" ? detail : undefined; + const topLevelMessage = + typeof parsedBody?.message === "string" ? (parsedBody.message as string) : undefined; + const detailNestedMessage = + typeof detailObject?.message === "string" ? (detailObject.message as string) : undefined; + + const topLevelCode = + typeof parsedBody?.errorCode === "string" + ? parsedBody.errorCode + : typeof parsedBody?.error_code === "string" + ? parsedBody.error_code + : undefined; + const detailCode = + typeof detailObject?.errorCode === "string" + ? detailObject.errorCode + : typeof detailObject?.error_code === "string" + ? detailObject.error_code + : undefined; + + const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + const message = + detailNestedMessage ?? 
+ detailMessage ?? + topLevelMessage ?? + `Backend error: ${response.status}`; + + return Object.assign(new Error(message), { errorCode }); +} + +export function tagPreAcceptSendFailure(error: unknown): unknown { + if (error instanceof Error) { + const withCode = error as Error & { errorCode?: string; code?: string }; + const existingCode = withCode.errorCode ?? withCode.code; + const passthroughCodes = new Set([ + "PREMIUM_QUOTA_EXHAUSTED", + "THREAD_BUSY", + "AUTH_EXPIRED", + "UNAUTHORIZED", + "RATE_LIMITED", + "NETWORK_ERROR", + "STREAM_PARSE_ERROR", + "TOOL_EXECUTION_ERROR", + "PERSIST_MESSAGE_FAILED", + "SERVER_ERROR", + ]); + if (existingCode && passthroughCodes.has(existingCode)) { + return Object.assign(error, { errorCode: existingCode }); + } + return Object.assign(error, { errorCode: "SEND_FAILED_PRE_ACCEPT" }); + } + + return Object.assign(new Error("Failed to send message before stream acceptance"), { + errorCode: "SEND_FAILED_PRE_ACCEPT", + }); +} diff --git a/surfsense_web/lib/chat/stream-side-effects.ts b/surfsense_web/lib/chat/stream-side-effects.ts new file mode 100644 index 000000000..9cb349458 --- /dev/null +++ b/surfsense_web/lib/chat/stream-side-effects.ts @@ -0,0 +1,127 @@ +import type { ThreadMessageLike } from "@assistant-ui/react"; +import { + addToolCall, + type ContentPartsState, + type ToolUIGate, + updateToolCall, +} from "@/lib/chat/streaming-state"; + +type InterruptActionRequest = { + name: string; + args: Record<string, unknown>; +}; + +export type EditedInterruptAction = { + name: string; + args: Record<string, unknown>; +}; + +function readInterruptActions( + interruptData: Record<string, unknown> +): InterruptActionRequest[] { + return (interruptData.action_requests ?? []) as InterruptActionRequest[]; +} + +/** + * Applies an interrupt request payload to tool-call parts. Existing tool cards + * are updated in-place; missing ones are upserted so approval UI always shows. 
+ */ +export function applyInterruptRequestToContentParts( + contentPartsState: ContentPartsState, + toolsWithUI: ToolUIGate, + interruptData: Record<string, unknown> +): void { + const { contentParts, toolCallIndices } = contentPartsState; + const actionRequests = readInterruptActions(interruptData); + for (const action of actionRequests) { + const existingEntry = Array.from(toolCallIndices.entries()).find(([, idx]) => { + const part = contentParts[idx]; + return part?.type === "tool-call" && part.toolName === action.name; + }); + + if (existingEntry) { + updateToolCall(contentPartsState, existingEntry[0], { + result: { __interrupt__: true, ...interruptData }, + }); + } else { + const toolCallId = `interrupt-${action.name}`; + addToolCall(contentPartsState, toolsWithUI, toolCallId, action.name, action.args, true); + updateToolCall(contentPartsState, toolCallId, { + result: { __interrupt__: true, ...interruptData }, + }); + } + } +} + +export function mergeEditedInterruptAction( + contentParts: ContentPartsState["contentParts"], + editedAction: EditedInterruptAction | undefined +): void { + if (!editedAction) return; + for (const part of contentParts) { + if (part.type === "tool-call" && part.toolName === editedAction.name) { + const mergedArgs = { ...part.args, ...editedAction.args }; + part.args = mergedArgs; + // assistant-ui prefers argsText over JSON.stringify(args) + part.argsText = JSON.stringify(mergedArgs, null, 2); + break; + } + } +} + +export function markInterruptDecisionOnContentParts( + contentParts: ContentPartsState["contentParts"], + decisionType: "approve" | "reject" | undefined +): void { + if (!decisionType) return; + for (const part of contentParts) { + if ( + part.type === "tool-call" && + typeof part.result === "object" && + part.result !== null && + "__interrupt__" in (part.result as Record<string, unknown>) + ) { + part.result = { + ...(part.result as Record<string, unknown>), + __decided__: decisionType, + }; + } + } +} + +/** + * When a 
streamed message is persisted, the backend returns the durable + * turn_id; merge it into assistant-ui metadata for turn-scoped actions. + */ +export function mergeChatTurnIdIntoMessage( + msg: ThreadMessageLike, + turnId: string | null | undefined +): ThreadMessageLike { + if (!turnId) return msg; + const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> }; + const existingCustom = existingMeta.custom ?? {}; + if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; + return { + ...msg, + metadata: { + ...existingMeta, + custom: { ...existingCustom, chatTurnId: turnId }, + }, + }; +} + +export function readStreamedChatTurnId(data: unknown): string | null { + if (typeof data !== "object" || data === null) return null; + const value = (data as { chat_turn_id?: unknown }).chat_turn_id; + return typeof value === "string" && value.length > 0 ? value : null; +} + +export function applyTurnIdToAssistantMessageList( + messages: ThreadMessageLike[], + assistantMsgId: string, + turnId: string +): ThreadMessageLike[] { + return messages.map((m) => + m.id === assistantMsgId ? 
mergeChatTurnIdIntoMessage(m, turnId) : m + ); +} From 4056bd1d6947703652e612ac425dabc3ec3c67da Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 30 Apr 2026 22:37:11 +0530 Subject: [PATCH 26/68] refactor(chat): update resetCurrentThreadAtom to include shareToken and contentType for enhanced report panel state management --- surfsense_web/atoms/chat/current-thread.atom.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/surfsense_web/atoms/chat/current-thread.atom.ts b/surfsense_web/atoms/chat/current-thread.atom.ts index d781df8d2..131c98309 100644 --- a/surfsense_web/atoms/chat/current-thread.atom.ts +++ b/surfsense_web/atoms/chat/current-thread.atom.ts @@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat export const resetCurrentThreadAtom = atom(null, (_, set) => { set(currentThreadAtom, initialState); - set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null }); + set(reportPanelAtom, { + isOpen: false, + reportId: null, + title: null, + wordCount: null, + shareToken: null, + contentType: "markdown", + }); }); /** Target comment ID to scroll to (from URL navigation or inbox click) */ From af66fbf106921822a895536c358f2b1a9b93b7a8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 01:47:52 +0530 Subject: [PATCH 27/68] refactor(chat): implement turn cancellation and status management in new chat routes for improved user experience and error handling --- .../agents/new_chat/middleware/busy_mutex.py | 56 ++++- .../app/routes/new_chat_routes.py | 169 ++++++++++++++- surfsense_backend/app/schemas/new_chat.py | 18 ++ .../app/services/new_streaming_service.py | 11 +- .../app/tasks/chat/stream_new_chat.py | 75 ++++++- .../unit/agents/new_chat/test_busy_mutex.py | 30 +++ .../unit/test_stream_new_chat_contract.py | 139 ++++++++++--- 
.../new-chat/[[...chat_id]]/page.tsx | 194 +++++++++++++----- .../lib/chat/chat-error-classifier.ts | 18 +- surfsense_web/lib/chat/chat-request-errors.ts | 29 ++- surfsense_web/lib/chat/stream-pipeline.ts | 5 + surfsense_web/lib/chat/streaming-state.ts | 8 + 12 files changed, 671 insertions(+), 81 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py index c57d85004..d61a56533 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -33,6 +33,7 @@ from __future__ import annotations import asyncio import logging +import time import weakref from typing import Any @@ -58,6 +59,8 @@ class _ThreadLockManager: weakref.WeakValueDictionary() ) self._cancel_events: dict[str, asyncio.Event] = {} + self._cancel_requested_at_ms: dict[str, int] = {} + self._cancel_attempt_count: dict[str, int] = {} def lock_for(self, thread_id: str) -> asyncio.Lock: lock = self._locks.get(thread_id) @@ -76,14 +79,45 @@ class _ThreadLockManager: def request_cancel(self, thread_id: str) -> bool: event = self._cancel_events.get(thread_id) if event is None: - return False + event = asyncio.Event() + self._cancel_events[thread_id] = event event.set() + now_ms = int(time.time() * 1000) + self._cancel_requested_at_ms[thread_id] = now_ms + self._cancel_attempt_count[thread_id] = ( + self._cancel_attempt_count.get(thread_id, 0) + 1 + ) return True + def is_cancel_requested(self, thread_id: str) -> bool: + event = self._cancel_events.get(thread_id) + return bool(event and event.is_set()) + + def cancel_state(self, thread_id: str) -> tuple[int, int] | None: + if not self.is_cancel_requested(thread_id): + return None + attempts = self._cancel_attempt_count.get(thread_id, 1) + requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0) + return attempts, requested_at_ms + def reset(self, thread_id: str) -> None: 
event = self._cancel_events.get(thread_id) if event is not None: event.clear() + self._cancel_requested_at_ms.pop(thread_id, None) + self._cancel_attempt_count.pop(thread_id, None) + + def end_turn(self, thread_id: str) -> None: + """Best-effort terminal cleanup for a thread turn. + + This is intentionally idempotent and safe to call from outer stream + finally-blocks where middleware teardown might be skipped due to abort + or disconnect edge-cases. + """ + lock = self._locks.get(thread_id) + if lock is not None and lock.locked(): + lock.release() + self.reset(thread_id) # Module-level singleton — process-local but reused across all agent @@ -98,15 +132,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event: def request_cancel(thread_id: str) -> bool: - """Trip the cancel event for ``thread_id``. Returns True if found.""" + """Trip the cancel event for ``thread_id``. Always returns True.""" return manager.request_cancel(thread_id) +def is_cancel_requested(thread_id: str) -> bool: + """Return whether ``thread_id`` currently has a pending cancel signal.""" + return manager.is_cancel_requested(thread_id) + + +def get_cancel_state(thread_id: str) -> tuple[int, int] | None: + """Return ``(attempt_count, requested_at_ms)`` for pending cancel state.""" + return manager.cancel_state(thread_id) + + def reset_cancel(thread_id: str) -> None: """Reset the cancel event for ``thread_id`` (called between turns).""" manager.reset(thread_id) +def end_turn(thread_id: str) -> None: + """Force end-of-turn cleanup for lock + cancel state.""" + manager.end_turn(thread_id) + + class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Block concurrent prompts on the same thread. 
@@ -229,7 +278,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo __all__ = [ "BusyMutexMiddleware", + "end_turn", "get_cancel_event", + "get_cancel_state", + "is_cancel_requested", "manager", "request_cancel", "reset_cancel", diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index e04cce1b5..28b197ca2 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -15,7 +15,7 @@ import json import logging from datetime import UTC, datetime -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import StreamingResponse from sqlalchemy import func, or_ from sqlalchemy.exc import IntegrityError, OperationalError @@ -29,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import ( FilesystemSelection, LocalFilesystemMount, ) +from app.agents.new_chat.middleware.busy_mutex import ( + get_cancel_state, + is_cancel_requested, + manager, + request_cancel, +) from app.config import config from app.db import ( ChatComment, @@ -44,6 +50,7 @@ from app.db import ( ) from app.schemas.new_chat import ( AgentToolInfo, + CancelActiveTurnResponse, LocalFilesystemMountPayload, NewChatMessageRead, NewChatRequest, @@ -60,6 +67,7 @@ from app.schemas.new_chat import ( ThreadListItem, ThreadListResponse, TokenUsageSummary, + TurnStatusResponse, ) from app.services.token_tracking_service import record_token_usage from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat @@ -72,6 +80,9 @@ from app.utils.user_message_multimodal import ( _logger = logging.getLogger(__name__) _background_tasks: set[asyncio.Task] = set() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 router = APIRouter() @@ -137,6 +148,72 @@ def _resolve_filesystem_selection( ) +def 
_compute_turn_cancelling_retry_delay(attempt: int) -> int: + """Bounded exponential delay for TURN_CANCELLING retry hints.""" + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) + + +def _build_turn_status_payload(thread_id: int) -> dict[str, object]: + lock = manager.lock_for(str(thread_id)) + if not lock.locked(): + return {"status": "idle"} + + if is_cancel_requested(str(thread_id)): + cancel_state = get_cancel_state(str(thread_id)) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms + return { + "status": "cancelling", + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + } + + return {"status": "busy"} + + +def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None: + response.headers["retry-after-ms"] = str(retry_after_ms) + response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000)) + + +def _raise_if_thread_busy_for_start(thread_id: int) -> None: + status_payload = _build_turn_status_payload(thread_id) + status = status_payload["status"] + if status == "idle": + return + if status == "cancelling": + retry_after_ms = int(status_payload.get("retry_after_ms") or 0) + detail = { + "errorCode": "TURN_CANCELLING", + "message": "A previous response is still stopping. 
Please try again in a moment.", + "retry_after_ms": retry_after_ms if retry_after_ms > 0 else None, + "retry_after_at": status_payload.get("retry_after_at"), + } + headers = ( + { + "retry-after-ms": str(retry_after_ms), + "Retry-After": str(max(1, (retry_after_ms + 999) // 1000)), + } + if retry_after_ms > 0 + else None + ) + raise HTTPException(status_code=409, detail=detail, headers=headers) + + raise HTTPException( + status_code=409, + detail={ + "errorCode": "THREAD_BUSY", + "message": "Another response is still finishing for this thread. Please try again in a moment.", + }, + ) + + def _find_pre_turn_checkpoint_id( checkpoint_tuples: list, *, @@ -1476,6 +1553,7 @@ async def handle_new_chat( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(request.chat_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -1550,6 +1628,93 @@ async def handle_new_chat( ) from None +@router.post( + "/threads/{thread_id}/cancel-active-turn", + response_model=CancelActiveTurnResponse, +) +async def cancel_active_turn( + thread_id: int, + response: Response, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Signal cancellation for the currently running turn on ``thread_id``.""" + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_UPDATE.value, + "You don't have permission to update chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + if status_payload["status"] == "idle": + return CancelActiveTurnResponse( + status="idle", + 
error_code="NO_ACTIVE_TURN", + ) + + request_cancel(str(thread_id)) + response.status_code = 202 + updated_payload = _build_turn_status_payload(thread_id) + retry_after_ms = int(updated_payload.get("retry_after_ms") or 0) + retry_after_at = ( + int(updated_payload["retry_after_at"]) + if "retry_after_at" in updated_payload + else None + ) + if retry_after_ms > 0: + _set_retry_after_headers(response, retry_after_ms) + return CancelActiveTurnResponse( + status="cancelling", + error_code="TURN_CANCELLING", + retry_after_ms=retry_after_ms if retry_after_ms > 0 else None, + retry_after_at=retry_after_at, + ) + + +@router.get( + "/threads/{thread_id}/turn-status", + response_model=TurnStatusResponse, +) +async def get_turn_status( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to view chats in this search space", + ) + await check_thread_access(session, thread, user) + + status_payload = _build_turn_status_payload(thread_id) + return TurnStatusResponse( + status=status_payload["status"], # type: ignore[arg-type] + active_turn_id=None, + retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type] + retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type] + ) + + # ============================================================================= # Chat Regeneration Endpoint (Edit/Reload) # ============================================================================= @@ -1605,6 +1770,7 @@ async def regenerate_response( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + 
_raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, @@ -2012,6 +2178,7 @@ async def resume_chat( ) await check_thread_access(session, thread, user) + _raise_if_thread_busy_for_start(thread_id) filesystem_selection = _resolve_filesystem_selection( mode=request.filesystem_mode, client_platform=request.client_platform, diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index c7284e901..ec5eefc07 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -335,6 +335,24 @@ class ResumeRequest(BaseModel): local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None +class CancelActiveTurnResponse(BaseModel): + """Response for canceling an active turn on a chat thread.""" + + status: Literal["cancelling", "idle"] + error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"] + retry_after_ms: int | None = None + retry_after_at: int | None = None + + +class TurnStatusResponse(BaseModel): + """Current turn execution status for a thread.""" + + status: Literal["idle", "busy", "cancelling"] + active_turn_id: str | None = None + retry_after_ms: int | None = None + retry_after_at: int | None = None + + # ============================================================================= # Public Chat Snapshot Schemas # ============================================================================= diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 842481f1c..55129668c 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -565,7 +565,12 @@ class VercelStreamingService: # Error Part # ========================================================================= - def format_error(self, error_text: str, error_code: str 
| None = None) -> str: + def format_error( + self, + error_text: str, + error_code: str | None = None, + extra: dict[str, object] | None = None, + ) -> str: """ Format an error message. @@ -579,9 +584,11 @@ class VercelStreamingService: Example output: data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"} """ - payload: dict[str, str] = {"type": "error", "errorText": error_text} + payload: dict[str, object] = {"type": "error", "errorText": error_text} if error_code: payload["errorCode"] = error_code + if extra: + payload.update(extra) return self._format_sse(payload) # ========================================================================= diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 2afa851b5..63c149771 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -45,6 +45,11 @@ from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) +from app.agents.new_chat.middleware.busy_mutex import ( + end_turn, + get_cancel_state, + is_cancel_requested, +) from app.agents.new_chat.middleware.kb_persistence import ( commit_staged_filesystem_state, ) @@ -72,6 +77,18 @@ from app.utils.user_message_multimodal import build_human_message_content _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() +TURN_CANCELLING_INITIAL_DELAY_MS = 200 +TURN_CANCELLING_BACKOFF_FACTOR = 2 +TURN_CANCELLING_MAX_DELAY_MS = 1500 + + +def _compute_turn_cancelling_retry_delay(attempt: int) -> int: + if attempt < 1: + attempt = 1 + delay = TURN_CANCELLING_INITIAL_DELAY_MS * ( + TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1) + ) + return min(delay, TURN_CANCELLING_MAX_DELAY_MS) def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: @@ -401,15 +418,35 @@ def _classify_stream_exception( exc: Exception, *, flow_label: str, -) -> tuple[str, str, Literal["info", 
"warn", "error"], bool, str]: +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +]: raw = str(exc) if isinstance(exc, BusyError) or "Thread is busy with another request" in raw: + busy_thread_id = str(exc.request_id) if isinstance(exc, BusyError) else None + if busy_thread_id and is_cancel_requested(busy_thread_id): + cancel_state = get_cancel_state(busy_thread_id) + attempt = cancel_state[0] if cancel_state else 1 + retry_after_ms = _compute_turn_cancelling_retry_delay(attempt) + retry_after_at = int(time.time() * 1000) + retry_after_ms + return ( + "thread_busy", + "TURN_CANCELLING", + "info", + True, + "A previous response is still stopping. Please try again in a moment.", + { + "retry_after_ms": retry_after_ms, + "retry_after_at": retry_after_at, + }, + ) return ( "thread_busy", "THREAD_BUSY", "warn", True, "Another response is still finishing for this thread. Please try again in a moment.", + None, ) parsed = _parse_error_payload(raw) @@ -431,6 +468,7 @@ def _classify_stream_exception( "warn", True, "This model is temporarily rate-limited. 
Please try again in a few seconds or switch models.", + None, ) return ( @@ -439,6 +477,7 @@ def _classify_stream_exception( "error", False, f"Error during {flow_label}: {raw}", + None, ) @@ -470,7 +509,7 @@ def _emit_stream_terminal_error( message=message, extra=extra, ) - return streaming_service.format_error(message, error_code=error_code) + return streaming_service.format_error(message, error_code=error_code, extra=extra) def _legacy_match_lc_id( @@ -2497,6 +2536,7 @@ async def stream_new_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) # Initial thinking step - analyzing the request if mentioned_surfsense_docs: @@ -2805,6 +2845,7 @@ async def stream_new_chat( task.add_done_callback(_background_tasks.discard) # Finish the step and message + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2819,11 +2860,19 @@ async def stream_new_chat( severity, is_expected, user_message, + error_extra, ) = _classify_stream_exception(e, flow_label="chat") error_message = f"Error during chat: {e!s}" print(f"[stream_new_chat] {error_message}") print(f"[stream_new_chat] Exception type: {type(e).__name__}") print(f"[stream_new_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) yield _emit_stream_error( message=user_message, @@ -2831,7 +2880,9 @@ async def stream_new_chat( error_code=error_code, severity=severity, is_expected=is_expected, + extra=error_extra, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield 
streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -2847,6 +2898,10 @@ async def stream_new_chat( # (CancelledError is a BaseException), and the rest of the # finally block — including session.close() — would never run. with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized if _premium_request_id and _premium_reserved > 0 and user_id: try: @@ -3206,6 +3261,7 @@ async def stream_resume_chat( "turn-info", {"chat_turn_id": stream_result.turn_id}, ) + yield streaming_service.format_data("turn-status", {"status": "busy"}) _t_stream_start = time.perf_counter() _first_event_logged = False @@ -3305,6 +3361,7 @@ async def stream_resume_chat( }, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() yield streaming_service.format_finish() yield streaming_service.format_done() @@ -3318,23 +3375,37 @@ async def stream_resume_chat( severity, is_expected, user_message, + error_extra, ) = _classify_stream_exception(e, flow_label="resume") error_message = f"Error during resume: {e!s}" print(f"[stream_resume_chat] {error_message}") print(f"[stream_resume_chat] Traceback:\n{traceback.format_exc()}") + if error_code == "TURN_CANCELLING": + status_payload: dict[str, Any] = {"status": "cancelling"} + if error_extra: + status_payload.update(error_extra) + yield streaming_service.format_data("turn-status", status_payload) + else: + yield streaming_service.format_data("turn-status", {"status": "busy"}) yield _emit_stream_error( message=user_message, error_kind=error_kind, error_code=error_code, severity=severity, is_expected=is_expected, + extra=error_extra, ) + yield streaming_service.format_data("turn-status", {"status": "idle"}) yield streaming_service.format_finish_step() 
yield streaming_service.format_finish() yield streaming_service.format_done() finally: with anyio.CancelScope(shield=True): + # Authoritative fallback cleanup for lock/cancel state. Middleware + # teardown can be skipped on some client-abort paths. + end_turn(str(chat_id)) + # Release premium reservation if not finalized if _resume_premium_request_id and _resume_premium_reserved > 0 and user_id: try: diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index 0c7bf17f6..c923dc499 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -7,7 +7,9 @@ import pytest from app.agents.new_chat.errors import BusyError from app.agents.new_chat.middleware.busy_mutex import ( BusyMutexMiddleware, + end_turn, get_cancel_event, + is_cancel_requested, manager, request_cancel, reset_cancel, @@ -88,3 +90,31 @@ async def test_no_thread_id_skipped_when_not_required() -> None: def test_reset_cancel_idempotent() -> None: # Should not raise even if event was never created reset_cancel("never-seen") + + +def test_request_cancel_creates_event_for_unseen_thread() -> None: + thread_id = "never-seen-cancel" + reset_cancel(thread_id) + + assert request_cancel(thread_id) is True + assert get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is True + + +@pytest.mark.asyncio +async def test_end_turn_force_clears_lock_and_cancel_state() -> None: + thread_id = "forced-end-turn" + mw = BusyMutexMiddleware() + runtime = _Runtime(thread_id) + + await mw.abefore_agent({}, runtime) + assert manager.lock_for(thread_id).locked() + + request_cancel(thread_id) + assert is_cancel_requested(thread_id) is True + + end_turn(thread_id) + + assert not manager.lock_for(thread_id).locked() + assert not get_cancel_event(thread_id).is_set() + assert is_cancel_requested(thread_id) is False diff --git 
a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 86ea7edd1..a1345c15c 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -8,6 +8,7 @@ import pytest import app.tasks.chat.stream_new_chat as stream_new_chat_module from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, @@ -147,7 +148,7 @@ def test_stream_exception_classifies_rate_limited(): exc = Exception( '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}' ) - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "rate_limited" @@ -155,11 +156,12 @@ def test_stream_exception_classifies_rate_limited(): assert severity == "warn" assert is_expected is True assert "temporarily rate-limited" in user_message + assert extra is None def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( exc, flow_label="chat" ) assert kind == "thread_busy" @@ -167,11 +169,12 @@ def test_stream_exception_classifies_thread_busy(): assert severity == "warn" assert is_expected is True assert "still finishing for this thread" in user_message + assert extra is None def test_stream_exception_classifies_thread_busy_from_message(): exc = Exception("Thread is busy with another request") - kind, code, severity, is_expected, user_message = _classify_stream_exception( + kind, code, severity, is_expected, user_message, extra = 
_classify_stream_exception( exc, flow_label="chat" ) assert kind == "thread_busy" @@ -179,6 +182,24 @@ def test_stream_exception_classifies_thread_busy_from_message(): assert severity == "warn" assert is_expected is True assert "still finishing for this thread" in user_message + assert extra is None + + +def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): + thread_id = "thread-cancelling-1" + reset_cancel(thread_id) + request_cancel(thread_id) + exc = BusyError(request_id=thread_id) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "thread_busy" + assert code == "TURN_CANCELLING" + assert severity == "info" + assert is_expected is True + assert "stopping" in user_message + assert isinstance(extra, dict) + assert "retry_after_ms" in extra def test_premium_classification_is_error_code_driven(): @@ -219,33 +240,33 @@ def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): def test_network_send_failures_use_unified_retry_toast_message(): classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" classifier_source = classifier_path.read_text(encoding="utf-8") - page_path = ( - Path(__file__).resolve().parents[3] - / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx" + request_errors_path = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts" ) - page_source = page_path.read_text(encoding="utf-8") + request_errors_source = request_errors_path.read_text(encoding="utf-8") assert '"send_failed_pre_accept"' in classifier_source assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source + assert 'errorCode === "TURN_CANCELLING"' in classifier_source assert "if (withCode.code) return withCode.code;" in classifier_source assert 'userMessage: "Message not sent. Please retry."' in classifier_source assert 'userMessage: "Connection issue. 
Please try again."' in classifier_source - assert "tagPreAcceptSendFailure(error)" in page_source - assert "const passthroughCodes = new Set([" in page_source - assert '"PREMIUM_QUOTA_EXHAUSTED"' in page_source - assert '"THREAD_BUSY"' in page_source - assert '"AUTH_EXPIRED"' in page_source - assert '"UNAUTHORIZED"' in page_source - assert '"RATE_LIMITED"' in page_source - assert '"NETWORK_ERROR"' in page_source - assert '"STREAM_PARSE_ERROR"' in page_source - assert '"TOOL_EXECUTION_ERROR"' in page_source - assert '"PERSIST_MESSAGE_FAILED"' in page_source - assert '"SERVER_ERROR"' in page_source - assert "passthroughCodes.has(existingCode)" in page_source - assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in page_source - assert 'errorCode: "NETWORK_ERROR"' not in page_source - assert "Failed to start chat. Please try again." not in page_source + assert "const passthroughCodes = new Set([" in request_errors_source + assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source + assert '"THREAD_BUSY"' in request_errors_source + assert '"TURN_CANCELLING"' in request_errors_source + assert '"AUTH_EXPIRED"' in request_errors_source + assert '"UNAUTHORIZED"' in request_errors_source + assert '"RATE_LIMITED"' in request_errors_source + assert '"NETWORK_ERROR"' in request_errors_source + assert '"STREAM_PARSE_ERROR"' in request_errors_source + assert '"TOOL_EXECUTION_ERROR"' in request_errors_source + assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source + assert '"SERVER_ERROR"' in request_errors_source + assert "passthroughCodes.has(existingCode)" in request_errors_source + assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source + assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source + assert "Failed to start chat. Please try again." 
not in classifier_source def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows(): @@ -269,3 +290,75 @@ def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows() # New flow persists only when accepted and not already persisted. assert "if (newAccepted && !userPersisted) {" in source + assert "const fetchWithTurnCancellingRetry = useCallback(" in source + assert "computeFallbackTurnCancellingRetryDelay" in source + assert 'withMeta.errorCode === "TURN_CANCELLING"' in source + assert 'withMeta.errorCode === "THREAD_BUSY"' in source + assert "await fetchWithTurnCancellingRetry(() =>" in source + + +def test_cancel_active_turn_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source + assert "response_model=CancelActiveTurnResponse" in source + assert 'status="cancelling",' in source + assert 'error_code="TURN_CANCELLING",' in source + assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source + assert "retry_after_at=" in source + assert 'status="idle",' in source + assert 'error_code="NO_ACTIVE_TURN",' in source + + +def test_turn_status_route_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + source = routes_path.read_text(encoding="utf-8") + + assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source + assert "response_model=TurnStatusResponse" in source + assert "_build_turn_status_payload(thread_id)" in source + assert "Permission.CHATS_READ.value" in source + assert "_raise_if_thread_busy_for_start(" in source + + +def test_turn_cancelling_retry_policy_contract_exists(): + routes_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/routes/new_chat_routes.py" + ) + 
source = routes_path.read_text(encoding="utf-8") + + assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source + assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source + assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source + assert "def _compute_turn_cancelling_retry_delay(" in source + assert "retry-after-ms" in source + assert '"Retry-After"' in source + assert '"errorCode": "TURN_CANCELLING"' in source + + +def test_turn_status_sse_contract_exists(): + stream_source = ( + Path(__file__).resolve().parents[3] + / "surfsense_backend/app/tasks/chat/stream_new_chat.py" + ).read_text(encoding="utf-8") + state_source = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts" + ).read_text(encoding="utf-8") + pipeline_source = ( + Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts" + ).read_text(encoding="utf-8") + + assert '"turn-status"' in stream_source + assert '"status": "busy"' in stream_source + assert '"status": "idle"' in stream_source + assert "type: \"data-turn-status\"" in state_source + assert "case \"data-turn-status\":" in pipeline_source + assert "end_turn(str(chat_id))" in stream_source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 02c2914be..1b25ca431 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -182,6 +182,20 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { * ``stream_new_chat.py``) keep the JSON from ballooning. 
*/ const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; +const TURN_CANCELLING_INITIAL_DELAY_MS = 200; +const TURN_CANCELLING_BACKOFF_FACTOR = 2; +const TURN_CANCELLING_MAX_DELAY_MS = 1500; +const RECENT_CANCEL_WINDOW_MS = 5_000; + +function sleep(ms: number): Promise<void> { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +function computeFallbackTurnCancellingRetryDelay(attempt: number): number { + const safeAttempt = Math.max(1, attempt); + const raw = TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); + return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS); +} export default function NewChatPage() { const params = useParams(); @@ -193,6 +207,7 @@ export default function NewChatPage() { const [isRunning, setIsRunning] = useState(false); const [tokenUsageStore] = useState(() => createTokenUsageStore()); const abortControllerRef = useRef<AbortController | null>(null); + const recentCancelRequestedAtRef = useRef(0); const [pendingInterrupt, setPendingInterrupt] = useState<{ threadId: number; assistantMsgId: string; @@ -598,6 +613,36 @@ export default function NewChatPage() { [handleChatFailure] ); + const fetchWithTurnCancellingRetry = useCallback( + async (runFetch: () => Promise<Response>) => { + const maxAttempts = 4; + for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { + const response = await runFetch(); + if (response.ok) { + return response; + } + const error = await toHttpResponseError(response); + const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; + const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; + const isRecentThreadBusyAfterCancel = + withMeta.errorCode === "THREAD_BUSY" && + Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; + if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) { + const waitMs = + withMeta.retryAfterMs ?? 
computeFallbackTurnCancellingRetryDelay(attempt); + await sleep(waitMs); + continue; + } + throw error; + } + + throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { + errorCode: "TURN_CANCELLING", + }); + }, + [] + ); + // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message const initializeThread = useCallback(async () => { @@ -767,12 +812,39 @@ export default function NewChatPage() { // Cancel ongoing request const cancelRun = useCallback(async () => { + if (threadId) { + const token = getBearerToken(); + if (token) { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + try { + const response = await fetch( + `${backendUrl}/api/v1/threads/${threadId}/cancel-active-turn`, + { + method: "POST", + headers: { + Authorization: `Bearer ${token}`, + }, + } + ); + if (response.ok) { + const payload = (await response.json()) as { + error_code?: string; + }; + if (payload.error_code === "TURN_CANCELLING") { + recentCancelRequestedAtRef.current = Date.now(); + } + } + } catch (error) { + console.warn("[NewChatPage] Failed to signal cancel-active-turn:", error); + } + } + } if (abortControllerRef.current) { abortControllerRef.current.abort(); abortControllerRef.current = null; } setIsRunning(false); - }, []); + }, [threadId]); // Handle new message from user const onNew = useCallback( @@ -971,29 +1043,33 @@ export default function NewChatPage() { setMentionedDocuments([]); } - const response = await fetch(`${backendUrl}/api/v1/new_chat`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - chat_id: currentThreadId, - user_query: userQuery.trim(), - search_space_id: searchSpaceId, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - messages: 
messageHistory, - mentioned_document_ids: hasDocumentIds ? mentionedDocumentIds.document_ids : undefined, - mentioned_surfsense_doc_ids: hasSurfsenseDocIds - ? mentionedDocumentIds.surfsense_doc_ids - : undefined, - disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, - ...(userImages.length > 0 ? { user_images: userImages } : {}), - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/new_chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + chat_id: currentThreadId, + user_query: userQuery.trim(), + search_space_id: searchSpaceId, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + messages: messageHistory, + mentioned_document_ids: hasDocumentIds + ? mentionedDocumentIds.document_ids + : undefined, + mentioned_surfsense_doc_ids: hasSurfsenseDocIds + ? mentionedDocumentIds.surfsense_doc_ids + : undefined, + disabled_tools: disabledTools.length > 0 ? disabledTools : undefined, + ...(userImages.length > 0 ? 
{ user_images: userImages } : {}), + }), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1033,6 +1109,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, onToolOutputAvailable: (event, sharedCtx) => { if (event.output?.status === "pending" && event.output?.podcast_id) { const idx = sharedCtx.toolCallIndices.get(event.toolCallId); @@ -1257,6 +1338,7 @@ export default function NewChatPage() { tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, + fetchWithTurnCancellingRetry, handleStreamTerminalError, handleChatFailure, persistAssistantTurn, @@ -1354,21 +1436,23 @@ export default function NewChatPage() { try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; const selection = await getAgentFilesystemSelection(searchSpaceId); - const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify({ - search_space_id: searchSpaceId, - decisions, - filesystem_mode: selection.filesystem_mode, - client_platform: selection.client_platform, - local_filesystem_mounts: selection.local_filesystem_mounts, - }), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + search_space_id: searchSpaceId, + decisions, + filesystem_mode: selection.filesystem_mode, + client_platform: selection.client_platform, + local_filesystem_mounts: selection.local_filesystem_mounts, + }), + signal: controller.signal, 
+ }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1399,6 +1483,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, }) ) { return; @@ -1496,6 +1585,7 @@ export default function NewChatPage() { searchSpaceId, queryClient, tokenUsageStore, + fetchWithTurnCancellingRetry, handleStreamTerminalError, persistAssistantTurn, ] @@ -1700,15 +1790,17 @@ export default function NewChatPage() { requestBody.revert_actions = true; } } - const response = await fetch(getRegenerateUrl(threadId), { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${token}`, - }, - body: JSON.stringify(requestBody), - signal: controller.signal, - }); + const response = await fetchWithTurnCancellingRetry(() => + fetch(getRegenerateUrl(threadId), { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify(requestBody), + signal: controller.signal, + }) + ); if (!response.ok) { throw await toHttpResponseError(response); @@ -1774,6 +1866,11 @@ export default function NewChatPage() { tokenUsageData = data; tokenUsageStore.set(assistantMsgId, data); }, + onTurnStatus: (data) => { + if (data.status === "cancelling") { + recentCancelRequestedAtRef.current = Date.now(); + } + }, onToolOutputAvailable: (event, sharedCtx) => { if (event.output?.status === "pending" && event.output?.podcast_id) { const idx = sharedCtx.toolCallIndices.get(event.toolCallId); @@ -1945,6 +2042,7 @@ export default function NewChatPage() { setMessageDocumentsMap, queryClient, tokenUsageStore, + fetchWithTurnCancellingRetry, handleStreamTerminalError, persistAssistantTurn, persistUserTurn, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 
57341a4c3..7dfbfc1a1 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -147,6 +147,22 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if ( + errorCode === "TURN_CANCELLING" + ) { + return { + kind: "thread_busy", + channel: "toast", + severity: "info", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: "A previous response is still stopping. Please try again in a moment.", + rawMessage, + errorCode: errorCode ?? "TURN_CANCELLING", + details: { flow: input.flow }, + }; + } + if ( errorCode === "THREAD_BUSY" ) { @@ -156,7 +172,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError severity: "warn", telemetryEvent: "chat_blocked", isExpected: true, - userMessage: "A previous response is still stopping. Please try again in a moment.", + userMessage: "Another response is still finishing for this thread. Please try again in a moment.", rawMessage, errorCode: errorCode ?? "THREAD_BUSY", details: { flow: input.flow }, diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index 3026e8203..708831354 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -1,6 +1,6 @@ export async function toHttpResponseError( response: Response -): Promise<Error & { errorCode?: string }> { +): Promise<Error & { errorCode?: string; retryAfterMs?: number }> { const statusDefaultCode = response.status === 409 ? "THREAD_BUSY" @@ -52,13 +52,37 @@ export async function toHttpResponseError( : undefined; const errorCode = detailCode ?? topLevelCode ?? statusDefaultCode; + + const detailRetryAfterMs = + typeof detailObject?.retry_after_ms === "number" + ? detailObject.retry_after_ms + : typeof detailObject?.retryAfterMs === "number" + ? 
detailObject.retryAfterMs + : undefined; + const topRetryAfterMs = + typeof parsedBody?.retry_after_ms === "number" + ? parsedBody.retry_after_ms + : typeof parsedBody?.retryAfterMs === "number" + ? parsedBody.retryAfterMs + : undefined; + const headerRetryAfterMsRaw = response.headers.get("retry-after-ms"); + const headerRetryAfterMs = headerRetryAfterMsRaw ? Number.parseFloat(headerRetryAfterMsRaw) : NaN; + const retryAfterHeader = response.headers.get("retry-after"); + const retryAfterSeconds = retryAfterHeader ? Number.parseFloat(retryAfterHeader) : NaN; + const retryAfterMsFromHeader = Number.isFinite(headerRetryAfterMs) + ? Math.max(0, Math.round(headerRetryAfterMs)) + : Number.isFinite(retryAfterSeconds) + ? Math.max(0, Math.round(retryAfterSeconds * 1000)) + : undefined; + const retryAfterMs = + detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined; const message = detailNestedMessage ?? detailMessage ?? topLevelMessage ?? `Backend error: ${response.status}`; - return Object.assign(new Error(message), { errorCode }); + return Object.assign(new Error(message), { errorCode, retryAfterMs }); } export function tagPreAcceptSendFailure(error: unknown): unknown { @@ -68,6 +92,7 @@ export function tagPreAcceptSendFailure(error: unknown): unknown { const passthroughCodes = new Set([ "PREMIUM_QUOTA_EXHAUSTED", "THREAD_BUSY", + "TURN_CANCELLING", "AUTH_EXPIRED", "UNAUTHORIZED", "RATE_LIMITED", diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts index 8957bdea3..c9118f949 100644 --- a/surfsense_web/lib/chat/stream-pipeline.ts +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -21,6 +21,7 @@ export type SharedStreamEventContext = { scheduleFlush: () => void; forceFlush: () => void; onTokenUsage?: (data: Extract<SSEEvent, { type: "data-token-usage" }>["data"]) => void; + onTurnStatus?: (data: Extract<SSEEvent, { type: "data-turn-status" }>["data"]) => void; onToolOutputAvailable?: ( event: 
Extract<SSEEvent, { type: "tool-output-available" }>, context: { @@ -173,6 +174,10 @@ export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStream context.onTokenUsage?.(parsed.data); return true; + case "data-turn-status": + context.onTurnStatus?.(parsed.data); + return true; + case "error": throw toStreamTerminalError(parsed); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 445bbe83d..80e7bffbe 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -528,6 +528,14 @@ export type SSEEvent = }>; }; } + | { + type: "data-turn-status"; + data: { + status: "idle" | "busy" | "cancelling"; + retry_after_ms?: number; + retry_after_at?: number; + }; + } | { type: "data-token-usage"; data: { From a66c1576b965acc50ae89d8a0f71ed3db1b64077 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:09:53 +0530 Subject: [PATCH 28/68] refactor(chat): introduce ChatViewport and NestedScroll components for improved chat UI structure and functionality --- .../components/assistant-ui/chat-viewport.tsx | 44 +++++++ .../components/assistant-ui/nested-scroll.tsx | 24 ++++ .../assistant-ui/thread-scroll-to-bottom.tsx | 18 --- .../components/assistant-ui/thread.tsx | 108 +++--------------- .../components/assistant-ui/tool-fallback.tsx | 9 +- .../components/free-chat/free-thread.tsx | 43 ++----- .../components/public-chat/public-thread.tsx | 9 +- 7 files changed, 99 insertions(+), 156 deletions(-) create mode 100644 surfsense_web/components/assistant-ui/chat-viewport.tsx create mode 100644 surfsense_web/components/assistant-ui/nested-scroll.tsx delete mode 100644 surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx new file mode 100644 index 000000000..f91a8916a --- 
/dev/null +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -0,0 +1,44 @@ +"use client"; + +import { ThreadPrimitive } from "@assistant-ui/react"; +import { ArrowDownIcon } from "lucide-react"; +import type { FC, ReactNode } from "react"; +import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; + +const ChatScrollToBottom: FC = () => ( + <ThreadPrimitive.ScrollToBottom asChild> + <TooltipIconButton + tooltip="Scroll to bottom" + variant="outline" + className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent" + > + <ArrowDownIcon /> + </TooltipIconButton> + </ThreadPrimitive.ScrollToBottom> +); + +export interface ChatViewportProps { + children: ReactNode; + footer?: ReactNode; +} + +export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( + <ThreadPrimitive.Viewport + scrollToBottomOnRunStart + scrollToBottomOnInitialize + scrollToBottomOnThreadSwitch + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" + style={{ scrollbarGutter: "stable" }} + > + {children} + {footer ? 
( + <ThreadPrimitive.ViewportFooter + className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6" + style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} + > + <ChatScrollToBottom /> + {footer} + </ThreadPrimitive.ViewportFooter> + ) : null} + </ThreadPrimitive.Viewport> +); diff --git a/surfsense_web/components/assistant-ui/nested-scroll.tsx b/surfsense_web/components/assistant-ui/nested-scroll.tsx new file mode 100644 index 000000000..5a4f8d36e --- /dev/null +++ b/surfsense_web/components/assistant-ui/nested-scroll.tsx @@ -0,0 +1,24 @@ +"use client"; + +import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react"; + +export type NestedScrollProps = ComponentPropsWithoutRef<"div">; + +export const NestedScroll = forwardRef<HTMLDivElement, NestedScrollProps>( + ({ onWheel, ...props }, ref) => { + const handleWheel = (event: WheelEvent<HTMLDivElement>) => { + const el = event.currentTarget; + const canScrollUp = el.scrollTop > 0; + const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1; + const goingUp = event.deltaY < 0; + const goingDown = event.deltaY > 0; + if ((goingUp && canScrollUp) || (goingDown && canScrollDown)) { + event.stopPropagation(); + } + onWheel?.(event); + }; + return <div ref={ref} onWheel={handleWheel} {...props} />; + } +); + +NestedScroll.displayName = "NestedScroll"; diff --git a/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx b/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx deleted file mode 100644 index 394ba5d79..000000000 --- a/surfsense_web/components/assistant-ui/thread-scroll-to-bottom.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import { ThreadPrimitive } from "@assistant-ui/react"; -import { ArrowDownIcon } from "lucide-react"; -import type { FC } from "react"; -import { TooltipIconButton } from 
"@/components/assistant-ui/tooltip-icon-button"; - -export const ThreadScrollToBottom: FC = () => { - return ( - <ThreadPrimitive.ScrollToBottom asChild> - <TooltipIconButton - tooltip="Scroll to bottom" - variant="outline" - className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent" - > - <ArrowDownIcon /> - </TooltipIconButton> - </ThreadPrimitive.ScrollToBottom> - ); -}; diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 3e27e7adb..1d24a2a39 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -5,12 +5,10 @@ import { ThreadPrimitive, useAui, useAuiState, - useThreadViewportStore, } from "@assistant-ui/react"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; import { AlertCircle, - ArrowDownIcon, ArrowUpIcon, Camera, ChevronDown, @@ -55,6 +53,7 @@ import { import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status"; +import { ChatViewport } from "@/components/assistant-ui/chat-viewport"; import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup"; import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup"; import { @@ -112,10 +111,17 @@ const ThreadContent: FC = () => { ["--thread-max-width" as string]: "44rem", }} > - <ThreadPrimitive.Viewport - turnAnchor="top" - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" - style={{ scrollbarGutter: "stable" }} + <ChatViewport + footer={ + <> + <AuiIf condition={({ thread }) => !thread.isEmpty}> + <PremiumQuotaPinnedAlert /> + </AuiIf> + <AuiIf condition={({ thread }) => !thread.isEmpty}> + <Composer /> + </AuiIf> + </> + } > 
<AuiIf condition={({ thread }) => thread.isEmpty}> <ThreadWelcome /> @@ -128,24 +134,7 @@ const ThreadContent: FC = () => { AssistantMessage, }} /> - - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <div className="grow" /> - </AuiIf> - - <ThreadPrimitive.ViewportFooter - className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6" - style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} - > - <ThreadScrollToBottom /> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <PremiumQuotaPinnedAlert /> - </AuiIf> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <Composer /> - </AuiIf> - </ThreadPrimitive.ViewportFooter> - </ThreadPrimitive.Viewport> + </ChatViewport> </ThreadPrimitive.Root> ); }; @@ -181,20 +170,6 @@ const PremiumQuotaPinnedAlert: FC = () => { ); }; -const ThreadScrollToBottom: FC = () => { - return ( - <ThreadPrimitive.ScrollToBottom asChild> - <TooltipIconButton - tooltip="Scroll to bottom" - variant="outline" - className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent" - > - <ArrowDownIcon /> - </TooltipIconButton> - </ThreadPrimitive.ScrollToBottom> - ); -}; - const getTimeBasedGreeting = (user?: { display_name?: string | null; email?: string }): string => { const hour = new Date().getHours(); @@ -411,23 +386,9 @@ const Composer: FC = () => { >(new Map()); const documentPickerRef = useRef<DocumentMentionPickerRef>(null); const promptPickerRef = useRef<PromptPickerRef>(null); - const viewportRef = useRef<Element | null>(null); const { search_space_id, chat_id } = useParams(); const aui = useAui(); - const threadViewportStore = useThreadViewportStore(); const hasAutoFocusedRef = useRef(false); - const submitCleanupRef = useRef<(() => void) | null>(null); - - useEffect(() => { - return () => { - 
submitCleanupRef.current?.(); - }; - }, []); - - // Store viewport element reference on mount - useEffect(() => { - viewportRef.current = document.querySelector(".aui-thread-viewport"); - }, []); const electronAPI = useElectronAPI(); const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>(); @@ -626,7 +587,6 @@ const Composer: FC = () => { [showDocumentPopover, showPromptPicker] ); - // Submit message (blocked during streaming, document picker open, or AI responding to another user) const handleSubmit = useCallback(() => { if (isThreadRunning || isBlockedByOtherUser) return; if (showDocumentPopover || showPromptPicker) return; @@ -638,50 +598,9 @@ const Composer: FC = () => { setClipboardInitialText(undefined); } - const viewportEl = viewportRef.current; - const heightBefore = viewportEl?.scrollHeight ?? 0; - aui.composer().send(); editorRef.current?.clear(); setMentionedDocuments([]); - - // With turnAnchor="top", ViewportSlack adds min-height to the last - // assistant message so that scrolling-to-bottom actually positions the - // user message at the TOP of the viewport. That slack height is - // calculated asynchronously (ResizeObserver → style → layout). - // Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes. 
- const scrollToBottom = () => - threadViewportStore.getState().scrollToBottom({ behavior: "instant" }); - - let lastHeight = heightBefore; - let frames = 0; - let cancelled = false; - const POLL_FRAMES = 30; - - const pollAndScroll = () => { - if (cancelled) return; - const el = viewportRef.current; - if (el) { - const h = el.scrollHeight; - if (h !== lastHeight) { - lastHeight = h; - scrollToBottom(); - } - } - if (++frames < POLL_FRAMES) { - requestAnimationFrame(pollAndScroll); - } - }; - requestAnimationFrame(pollAndScroll); - - const t1 = setTimeout(scrollToBottom, 100); - const t2 = setTimeout(scrollToBottom, 300); - - submitCleanupRef.current = () => { - cancelled = true; - clearTimeout(t1); - clearTimeout(t2); - }; }, [ showDocumentPopover, showPromptPicker, @@ -690,7 +609,6 @@ const Composer: FC = () => { clipboardInitialText, aui, setMentionedDocuments, - threadViewportStore, ]); const handleDocumentRemove = useCallback( diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 66e2ebd4a..cf42cf398 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -13,6 +13,7 @@ import { isDoomLoopInterrupt, } from "@/components/tool-ui/doom-loop-approval"; import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; +import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; import { AlertDialog, AlertDialogAction, @@ -475,7 +476,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { {(argsText || isRunning) && ( <div className="flex flex-col gap-1 min-w-0"> <p className="text-xs font-medium text-muted-foreground">Inputs</p> - <div className="max-h-48 overflow-auto rounded-md bg-muted/40"> + <NestedScroll className="max-h-48 overflow-auto rounded-md bg-muted/40"> {argsText ? 
( <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono"> {argsText} @@ -489,7 +490,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { Waiting for input… </p> )} - </div> + </NestedScroll> </div> )} {!isCancelled && result !== undefined && ( @@ -497,11 +498,11 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { <Separator /> <div className="flex flex-col gap-1 min-w-0"> <p className="text-xs font-medium text-muted-foreground">Result</p> - <div className="max-h-64 overflow-auto rounded-md bg-muted/40"> + <NestedScroll className="max-h-64 overflow-auto rounded-md bg-muted/40"> <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono"> {typeof result === "string" ? result : serializedResult} </pre> - </div> + </NestedScroll> </div> </> )} diff --git a/surfsense_web/components/free-chat/free-thread.tsx b/surfsense_web/components/free-chat/free-thread.tsx index bd237004a..933847b2b 100644 --- a/surfsense_web/components/free-chat/free-thread.tsx +++ b/surfsense_web/components/free-chat/free-thread.tsx @@ -1,11 +1,10 @@ "use client"; import { AuiIf, ThreadPrimitive } from "@assistant-ui/react"; -import { ArrowDownIcon } from "lucide-react"; import type { FC } from "react"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; +import { ChatViewport } from "@/components/assistant-ui/chat-viewport"; import { EditComposer } from "@/components/assistant-ui/edit-composer"; -import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { UserMessage } from "@/components/assistant-ui/user-message"; import { FreeComposer } from "./free-composer"; @@ -24,20 +23,6 @@ const FreeThreadWelcome: FC = () => { ); }; -const ThreadScrollToBottom: FC = () => { - return ( - <ThreadPrimitive.ScrollToBottom asChild> - <TooltipIconButton - tooltip="Scroll to bottom" - variant="outline" - 
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent" - > - <ArrowDownIcon /> - </TooltipIconButton> - </ThreadPrimitive.ScrollToBottom> - ); -}; - export const FreeThread: FC = () => { return ( <ThreadPrimitive.Root @@ -46,10 +31,12 @@ export const FreeThread: FC = () => { ["--thread-max-width" as string]: "44rem", }} > - <ThreadPrimitive.Viewport - turnAnchor="top" - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" - style={{ scrollbarGutter: "stable" }} + <ChatViewport + footer={ + <AuiIf condition={({ thread }) => !thread.isEmpty}> + <FreeComposer /> + </AuiIf> + } > <AuiIf condition={({ thread }) => thread.isEmpty}> <FreeThreadWelcome /> @@ -62,21 +49,7 @@ export const FreeThread: FC = () => { AssistantMessage, }} /> - - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <div className="grow" /> - </AuiIf> - - <ThreadPrimitive.ViewportFooter - className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6" - style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} - > - <ThreadScrollToBottom /> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <FreeComposer /> - </AuiIf> - </ThreadPrimitive.ViewportFooter> - </ThreadPrimitive.Viewport> + </ChatViewport> </ThreadPrimitive.Root> ); }; diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index 22e914988..de91b4451 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -45,16 +45,17 @@ export const PublicThread: FC<PublicThreadProps> = ({ footer }) => { ["--thread-max-width" as string]: "44rem", }} > - <ThreadPrimitive.Viewport className="aui-thread-viewport relative flex flex-1 min-h-0 
flex-col overflow-y-auto px-4 pt-4"> + <ThreadPrimitive.Viewport + scrollToBottomOnInitialize + scrollToBottomOnThreadSwitch + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4 pb-6" + > <ThreadPrimitive.Messages components={{ UserMessage: PublicUserMessage, AssistantMessage: PublicAssistantMessage, }} /> - - {/* Spacer to ensure footer doesn't overlap last message */} - <div className="h-24" /> </ThreadPrimitive.Viewport> {footer && ( From 833b4dd441d0e8053bd2399076fedcf067917617 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:10:21 +0530 Subject: [PATCH 29/68] refactor(chat): simplify ChatViewport and footer structure for improved readability and maintainability --- .../components/assistant-ui/chat-viewport.tsx | 26 ++++++++++--------- .../components/assistant-ui/thread.tsx | 12 +++------ .../components/public-chat/public-thread.tsx | 2 +- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index f91a8916a..a1534df01 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -23,22 +23,24 @@ export interface ChatViewportProps { } export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( - <ThreadPrimitive.Viewport - scrollToBottomOnRunStart - scrollToBottomOnInitialize - scrollToBottomOnThreadSwitch - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" - style={{ scrollbarGutter: "stable" }} - > - {children} + <> + <ThreadPrimitive.Viewport + scrollToBottomOnRunStart + scrollToBottomOnInitialize + scrollToBottomOnThreadSwitch + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" + style={{ scrollbarGutter: "stable" }} + > + {children} + 
</ThreadPrimitive.Viewport> {footer ? ( - <ThreadPrimitive.ViewportFooter - className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-main-panel pb-4 md:pb-6" + <div + className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible bg-main-panel px-4 pt-2 pb-4 md:pb-6" style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} > <ChatScrollToBottom /> {footer} - </ThreadPrimitive.ViewportFooter> + </div> ) : null} - </ThreadPrimitive.Viewport> + </> ); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 1d24a2a39..6c02a1efa 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -113,14 +113,10 @@ const ThreadContent: FC = () => { > <ChatViewport footer={ - <> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <PremiumQuotaPinnedAlert /> - </AuiIf> - <AuiIf condition={({ thread }) => !thread.isEmpty}> - <Composer /> - </AuiIf> - </> + <AuiIf condition={({ thread }) => !thread.isEmpty}> + <PremiumQuotaPinnedAlert /> + <Composer /> + </AuiIf> } > <AuiIf condition={({ thread }) => thread.isEmpty}> diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index de91b4451..750b7410e 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -59,7 +59,7 @@ export const PublicThread: FC<PublicThreadProps> = ({ footer }) => { </ThreadPrimitive.Viewport> {footer && ( - <div className="sticky bottom-0 z-20 border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60"> + <div className="border-t bg-main-panel/95 backdrop-blur supports-backdrop-filter:bg-main-panel/60"> {footer} </div> )} From 
b2f487bf36829b3ceb61c654ab6557561a3483ac Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 15:03:10 -0700 Subject: [PATCH 30/68] feat: added mac signing & notarization for desktop app --- .github/workflows/desktop-release.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/desktop-release.yml b/.github/workflows/desktop-release.yml index b955e5014..e356bd3e5 100644 --- a/.github/workflows/desktop-release.yml +++ b/.github/workflows/desktop-release.yml @@ -136,6 +136,14 @@ jobs: AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }} AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }} AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }} + # macOS Developer ID signing + notarization. Only the macos-latest runner + # consumes these; Windows/Linux runners ignore them. CSC_LINK accepts either + # a file path or a base64-encoded .p12 blob — electron-builder auto-detects. + CSC_LINK: ${{ secrets.MAC_CERT_P12_BASE64 }} + CSC_KEY_PASSWORD: ${{ secrets.MAC_CERT_PASSWORD }} + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} # Service principal credentials for Azure.Identity EnvironmentCredential used by the # TrustedSigning PowerShell module. Only populated when signing is enabled. 
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing, From 7b549f84445ef158b97d0270143a88c623d89ab7 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:38:21 +0530 Subject: [PATCH 31/68] refactor(chat): enhance ChatViewport with auto-scroll and top fade effect for improved user experience --- .../components/assistant-ui/chat-viewport.tsx | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index a1534df01..d3d664ace 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -23,24 +23,30 @@ export interface ChatViewportProps { } export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( - <> - <ThreadPrimitive.Viewport - scrollToBottomOnRunStart - scrollToBottomOnInitialize - scrollToBottomOnThreadSwitch - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4" - style={{ scrollbarGutter: "stable" }} - > - {children} - </ThreadPrimitive.Viewport> + <ThreadPrimitive.Viewport + turnAnchor="top" + autoScroll + scrollToBottomOnRunStart + scrollToBottomOnInitialize + scrollToBottomOnThreadSwitch + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 [scroll-behavior:smooth]" + style={{ scrollbarGutter: "stable" }} + > + <div + aria-hidden + className="aui-chat-viewport-top-fade pointer-events-none sticky top-0 z-10 -mx-4 h-2 shrink-0 bg-gradient-to-b from-main-panel from-20% to-transparent" + /> + {children} {footer ? 
( - <div - className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible bg-main-panel px-4 pt-2 pb-4 md:pb-6" - style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }} + <ThreadPrimitive.ViewportFooter + className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6" + style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }} > - <ChatScrollToBottom /> - {footer} - </div> + <div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible"> + <ChatScrollToBottom /> + {footer} + </div> + </ThreadPrimitive.ViewportFooter> ) : null} - </> + </ThreadPrimitive.Viewport> ); From 511f4fde6440378a111fb7bdc3f84cbf4b9c85c1 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 03:40:14 +0530 Subject: [PATCH 32/68] refactor(chat): update ChatViewport className for improved scroll behavior consistency --- surfsense_web/components/assistant-ui/chat-viewport.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index d3d664ace..f7f1ac188 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -29,7 +29,7 @@ export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( scrollToBottomOnRunStart scrollToBottomOnInitialize scrollToBottomOnThreadSwitch - className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 [scroll-behavior:smooth]" + className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth" style={{ scrollbarGutter: "stable" }} > <div From 8b4f1366684e69cfb403ec9942a7ac0d2cc677d9 Mon 
Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:02:24 +0530 Subject: [PATCH 33/68] refactor(chat): enhance UserMessage component with mention parsing and segment rendering for improved message display --- .../components/assistant-ui/chat-viewport.tsx | 2 +- .../components/assistant-ui/user-message.tsx | 121 ++++++------------ .../lib/chat/parse-mention-segments.ts | 54 ++++++++ 3 files changed, 97 insertions(+), 80 deletions(-) create mode 100644 surfsense_web/lib/chat/parse-mention-segments.ts diff --git a/surfsense_web/components/assistant-ui/chat-viewport.tsx b/surfsense_web/components/assistant-ui/chat-viewport.tsx index f7f1ac188..c0684407e 100644 --- a/surfsense_web/components/assistant-ui/chat-viewport.tsx +++ b/surfsense_web/components/assistant-ui/chat-viewport.tsx @@ -39,7 +39,7 @@ export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => ( {children} {footer ? ( <ThreadPrimitive.ViewportFooter - className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6" + className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 mt-auto flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6" style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }} > <div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible"> diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index fb7212119..145ac2d7e 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -1,4 +1,10 @@ -import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react"; +import { + ActionBarPrimitive, + AuiIf, + MessagePrimitive, + 
useAuiState, + useMessagePartText, +} from "@assistant-ui/react"; import { useAtomValue } from "jotai"; import { CheckIcon, CopyIcon, Pencil } from "lucide-react"; import Image from "next/image"; @@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; +import { parseMentionSegments } from "@/lib/chat/parse-mention-segments"; interface AuthorMetadata { displayName: string | null; @@ -47,23 +55,40 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => { ); }; -export const UserMessage: FC = () => { +const UserTextPart: FC = () => { const messageId = useAuiState(({ message }) => message?.id); - const messageText = useAuiState(({ message }) => - (message?.content ?? []) - .map((part) => - typeof part === "object" && - part !== null && - "type" in part && - (part as { type?: string }).type === "text" && - "text" in part - ? String((part as { text?: string }).text ?? "") - : "" - ) - .join("") - ); + const part = useMessagePartText(); + const text = (part as { text?: string }).text ?? ""; const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); - const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; + const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? []; + + const segments = parseMentionSegments(text, mentionedDocs); + + return ( + <p style={{ whiteSpace: "pre-line" }} className="break-words"> + {segments.map((segment) => + segment.type === "text" ? 
( + <span key={`txt-${segment.start}`}>{segment.value}</span> + ) : ( + <span + key={`mention-${getMentionDocKey(segment.doc)}-${segment.start}`} + className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-middle leading-none" + title={segment.doc.title} + > + <span className="flex items-center text-muted-foreground"> + {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} + </span> + <span className="max-w-[120px] truncate">{segment.doc.title}</span> + </span> + ) + )} + </p> + ); +}; + +const userMessageParts = { Text: UserTextPart }; + +export const UserMessage: FC = () => { const metadata = useAuiState(({ message }) => message?.metadata); const author = metadata?.custom?.author as AuthorMetadata | undefined; const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE"; @@ -78,11 +103,7 @@ export const UserMessage: FC = () => { <div className="aui-user-message-content-wrapper flex items-end gap-2"> <div className="relative flex-1 min-w-0"> <div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground"> - {mentionedDocs && mentionedDocs.length > 0 ? 
( - <UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} /> - ) : ( - <MessagePrimitive.Parts /> - )} + <MessagePrimitive.Parts components={userMessageParts} /> </div> <div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto"> <UserActionBar /> @@ -99,64 +120,6 @@ export const UserMessage: FC = () => { ); }; -const UserMessageWithMentionChips: FC<{ - text: string; - mentionedDocs: { id: number; title: string; document_type: string }[]; -}> = ({ text, mentionedDocs }) => { - type Segment = - | { type: "text"; value: string; start: number } - | { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number }; - - const tokens = mentionedDocs - .map((doc) => ({ doc, token: `@${doc.title}` })) - .sort((a, b) => b.token.length - a.token.length); - - const segments: Segment[] = []; - let i = 0; - let buffer = ""; - let bufferStart = 0; - while (i < text.length) { - const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i)); - if (tokenMatch) { - if (buffer) { - segments.push({ type: "text", value: buffer, start: bufferStart }); - buffer = ""; - } - segments.push({ type: "mention", doc: tokenMatch.doc, start: i }); - i += tokenMatch.token.length; - bufferStart = i; - continue; - } - if (!buffer) bufferStart = i; - buffer += text[i]; - i += 1; - } - if (buffer) { - segments.push({ type: "text", value: buffer, start: bufferStart }); - } - - return ( - <span className="whitespace-pre-wrap break-words"> - {segments.map((segment) => - segment.type === "text" ? 
( - <span key={`txt-${segment.start}`}>{segment.value}</span> - ) : ( - <span - key={`mention-${segment.doc.document_type}:${segment.doc.id}-${segment.start}`} - className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline" - title={segment.doc.title} - > - <span className="flex items-center text-muted-foreground"> - {getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")} - </span> - <span className="max-w-[120px] truncate">{segment.doc.title}</span> - </span> - ) - )} - </span> - ); -}; - const UserActionBar: FC = () => { const isThreadRunning = useAuiState(({ thread }) => thread.isRunning); diff --git a/surfsense_web/lib/chat/parse-mention-segments.ts b/surfsense_web/lib/chat/parse-mention-segments.ts new file mode 100644 index 000000000..b9cf59792 --- /dev/null +++ b/surfsense_web/lib/chat/parse-mention-segments.ts @@ -0,0 +1,54 @@ +import type { MentionedDocumentInfo } from "@/atoms/chat/mentioned-documents.atom"; + +export type MentionSegment = + | { type: "text"; value: string; start: number } + | { type: "mention"; doc: MentionedDocumentInfo; start: number }; + +/** + * Tokenizes a user message into text and `@mention` segments. + * + * Pure: no React, no DOM, no side effects. Safe to unit-test and reuse. + * + * Mentions are matched greedily by longest title first so that a longer title + * (e.g. `@Project Roadmap`) is never shadowed by a shorter prefix + * (e.g. `@Project`). 
+ */ +export function parseMentionSegments( + text: string, + docs: ReadonlyArray<MentionedDocumentInfo> +): MentionSegment[] { + if (text.length === 0) return []; + if (docs.length === 0) return [{ type: "text", value: text, start: 0 }]; + + const tokens = docs + .map((doc) => ({ doc, token: `@${doc.title}` })) + .sort((a, b) => b.token.length - a.token.length); + + const segments: MentionSegment[] = []; + let i = 0; + let buffer = ""; + let bufferStart = 0; + + while (i < text.length) { + const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i)); + if (tokenMatch) { + if (buffer) { + segments.push({ type: "text", value: buffer, start: bufferStart }); + buffer = ""; + } + segments.push({ type: "mention", doc: tokenMatch.doc, start: i }); + i += tokenMatch.token.length; + bufferStart = i; + continue; + } + if (!buffer) bufferStart = i; + buffer += text[i]; + i += 1; + } + + if (buffer) { + segments.push({ type: "text", value: buffer, start: bufferStart }); + } + + return segments; +} From 3a73912a86f9c7fac85cc1bb1fd228563874d385 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 15:39:12 -0700 Subject: [PATCH 34/68] feat(desktop): enable hardened runtime and entitlements for mac signing Made-with: Cursor --- .../build/entitlements.mac.plist | 35 +++++++++++++++++++ surfsense_desktop/electron-builder.yml | 5 ++- 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 surfsense_desktop/build/entitlements.mac.plist diff --git a/surfsense_desktop/build/entitlements.mac.plist b/surfsense_desktop/build/entitlements.mac.plist new file mode 100644 index 000000000..5647e7759 --- /dev/null +++ b/surfsense_desktop/build/entitlements.mac.plist @@ -0,0 +1,35 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> +<plist version="1.0"> +<dict> + <!-- Required for Electron's V8 JIT under hardened runtime --> 
+ <key>com.apple.security.cs.allow-jit</key> + <true/> + <key>com.apple.security.cs.allow-unsigned-executable-memory</key> + <true/> + + <!-- node-mac-permissions and other native deps load dylibs at runtime --> + <key>com.apple.security.cs.allow-dyld-environment-variables</key> + <true/> + <key>com.apple.security.cs.disable-library-validation</key> + <true/> + + <!-- Networking (OAuth, API calls, auto-updater, deep links) --> + <key>com.apple.security.network.client</key> + <true/> + <key>com.apple.security.network.server</key> + <true/> + + <!-- Screen Capture / Screenshot Assist --> + <key>com.apple.security.device.camera</key> + <true/> + + <!-- Accessibility / Apple Events used by general-assist --> + <key>com.apple.security.automation.apple-events</key> + <true/> + + <!-- File access for folder watcher / agent filesystem features --> + <key>com.apple.security.files.user-selected.read-write</key> + <true/> +</dict> +</plist> diff --git a/surfsense_desktop/electron-builder.yml b/surfsense_desktop/electron-builder.yml index b0014a57b..e4e7670ec 100644 --- a/surfsense_desktop/electron-builder.yml +++ b/surfsense_desktop/electron-builder.yml @@ -46,8 +46,11 @@ mac: icon: assets/icon.icns category: public.app-category.productivity artifactName: "${productName}-${version}-${arch}.${ext}" - hardenedRuntime: false + hardenedRuntime: true gatekeeperAssess: false + entitlements: build/entitlements.mac.plist + entitlementsInherit: build/entitlements.mac.plist + notarize: true extendInfo: NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists." NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer." 
From 0883ac88fb54653223ff2477724d372531fa1301 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:23:59 +0530 Subject: [PATCH 35/68] refactor(chat): enhance InlineMentionEditor with improved mention handling and text processing for better user interaction --- .../assistant-ui/inline-mention-editor.tsx | 1078 ++++++----------- 1 file changed, 391 insertions(+), 687 deletions(-) diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index 05277f508..d92348080 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -1,26 +1,13 @@ "use client"; -import { X } from "lucide-react"; -import type { ReactElement } from "react"; -import { - createElement, - forwardRef, - useCallback, - useEffect, - useImperativeHandle, - useRef, - useState, -} from "react"; -import { renderToStaticMarkup } from "react-dom/server"; +import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; +import { Plate, PlateContent, ParagraphPlugin, createPlatePlugin, usePlateEditor } from "platejs/react"; +import type { PlateElementProps } from "platejs/react"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { Document } from "@/contracts/types/document.types"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { cn } from "@/lib/utils"; -function renderElementToHTML(element: ReactElement): string { - return renderToStaticMarkup(element); -} - export interface MentionedDocument { id: number; title: string; @@ -61,38 +48,174 @@ interface InlineMentionEditorProps { initialText?: string; } -// Unique data attribute to identify chip elements -const CHIP_DATA_ATTR = "data-mention-chip"; -const CHIP_ID_ATTR = "data-mention-id"; -const CHIP_DOCTYPE_ATTR = "data-mention-doctype"; 
-const CHIP_STATUS_ATTR = "data-mention-status"; +type MentionStatusKind = "pending" | "processing" | "ready" | "failed"; +type ComposerTextNode = { text: string }; +type MentionElementNode = { + type: "mention"; + id: number; + title: string; + document_type?: string; + statusLabel?: string | null; + statusKind?: MentionStatusKind; + children: [{ text: "" }]; +}; +type ComposerNode = ComposerTextNode | MentionElementNode; +type ComposerParagraph = { type: "p"; children: ComposerNode[] }; +type ComposerValue = ComposerParagraph[]; + +const MENTION_TYPE = "mention"; +const MENTION_CHIP_CLASSNAME = + "inline-flex h-5 items-center gap-1 mx-0.5 rounded bg-primary/10 px-1 text-xs font-bold text-primary/60 select-none align-middle leading-none"; +const MENTION_CHIP_ICON_CLASSNAME = "flex items-center text-muted-foreground leading-none"; +const MENTION_CHIP_TITLE_CLASSNAME = "max-w-[120px] truncate leading-none"; +const COMPOSER_TEXT_METRICS_CLASSNAME = "text-sm leading-6"; + +const EMPTY_VALUE: ComposerValue = [{ type: "p", children: [{ text: "" }] }]; + +const MentionElement: FC<PlateElementProps<MentionElementNode>> = ({ attributes, children, element }) => { + const statusClass = + element.statusKind === "failed" + ? "text-destructive" + : element.statusKind === "ready" + ? "text-emerald-700" + : "text-amber-700"; -/** - * Type guard to check if a node is a chip element - */ -function isChipElement(node: Node | null): node is HTMLSpanElement { return ( - node !== null && - node.nodeType === Node.ELEMENT_NODE && - (node as Element).hasAttribute(CHIP_DATA_ATTR) + <span {...attributes} className="inline-flex align-middle"> + <span contentEditable={false} className={`${MENTION_CHIP_CLASSNAME} cursor-default`}> + <span className={MENTION_CHIP_ICON_CLASSNAME}> + {getConnectorIcon(element.document_type ?? "UNKNOWN", "h-3 w-3")} + </span> + <span className={MENTION_CHIP_TITLE_CLASSNAME} title={element.title}> + {element.title} + </span> + {element.statusLabel ? 
( + <span className={cn("text-[10px] font-semibold opacity-80", statusClass)}> + {element.statusLabel} + </span> + ) : null} + </span> + {children} + </span> ); +}; + +const MentionPlugin = createPlatePlugin({ + key: MENTION_TYPE, + node: { + isElement: true, + isInline: true, + isVoid: true, + type: MENTION_TYPE, + component: MentionElement, + }, +}); + +function isMentionNode(node: ComposerNode): node is MentionElementNode { + return typeof node === "object" && "type" in node && node.type === MENTION_TYPE; } -/** - * Safely parse chip ID from element attribute - */ -function getChipId(element: Element): number | null { - const idStr = element.getAttribute(CHIP_ID_ATTR); - if (!idStr) return null; - const id = parseInt(idStr, 10); - return Number.isNaN(id) ? null : id; +function getTextNode(node: ComposerNode): ComposerTextNode | null { + if (typeof node === "object" && "text" in node && typeof node.text === "string") return node; + return null; } -/** - * Get chip document type from element attribute - */ -function getChipDocType(element: Element): string { - return element.getAttribute(CHIP_DOCTYPE_ATTR) ?? "UNKNOWN"; +function toValueFromText(text: string): ComposerValue { + const lines = text.split("\n"); + if (lines.length === 0) return EMPTY_VALUE; + return lines.map((line) => ({ type: "p", children: [{ text: line }] })) as ComposerValue; +} + +function getPlainText(value: ComposerValue): string { + const lines = value.map((block) => + block.children + .map((node) => { + if (isMentionNode(node)) return `@${node.title}`; + return getTextNode(node)?.text ?? 
""; + }) + .join("") + ); + return lines.join("\n").trim(); +} + +function getMentionedDocuments(value: ComposerValue): MentionedDocument[] { + const map = new Map<string, MentionedDocument>(); + for (const block of value) { + for (const node of block.children) { + if (!isMentionNode(node)) continue; + const doc: MentionedDocument = { + id: node.id, + title: node.title, + document_type: node.document_type, + }; + map.set(getMentionDocKey(doc), doc); + } + } + return Array.from(map.values()); +} + +type EditorSelection = { + anchor: { path: number[]; offset: number }; + focus: { path: number[]; offset: number }; +} | null; + +function getCursorTextContext(value: ComposerValue, selection: EditorSelection) { + if (!selection || !selection.anchor || !selection.focus) return null; + if ( + selection.anchor.path.length < 2 || + selection.focus.path.length < 2 || + selection.anchor.path[0] !== selection.focus.path[0] || + selection.anchor.path[1] !== selection.focus.path[1] + ) { + return null; + } + + const block = value[selection.anchor.path[0]]; + if (!block) return null; + const child = block.children[selection.anchor.path[1]]; + const textNode = getTextNode(child); + if (!textNode) return null; + + return { + blockIndex: selection.anchor.path[0], + childIndex: selection.anchor.path[1], + text: textNode.text, + cursor: selection.anchor.offset, + }; +} + +function scanActiveTrigger(text: string, cursor: number) { + let wordStart = 0; + for (let i = cursor - 1; i >= 0; i--) { + if (text[i] === " " || text[i] === "\n") { + wordStart = i + 1; + break; + } + } + + let triggerChar: "@" | "/" | null = null; + let triggerIndex = -1; + for (let i = wordStart; i < cursor; i++) { + if (text[i] === "@" || text[i] === "/") { + triggerChar = text[i] as "@" | "/"; + triggerIndex = i; + break; + } + } + if (!triggerChar || triggerIndex === -1) return null; + + const query = text.slice(triggerIndex + 1, cursor); + if (query.startsWith(" ")) return null; + if ( + triggerChar === "/" && 
+ triggerIndex > 0 && + text[triggerIndex - 1] !== " " && + text[triggerIndex - 1] !== "\n" + ) { + return null; + } + + return { triggerChar, query }; } export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMentionEditorProps>( @@ -113,393 +236,159 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent }, ref ) => { - const editorRef = useRef<HTMLDivElement>(null); - const [isEmpty, setIsEmpty] = useState(true); - const [mentionedDocs, setMentionedDocs] = useState<Map<string, MentionedDocument>>( - () => new Map() - ); - const isComposingRef = useRef(false); - const lastSelectionRangeRef = useRef<Range | null>(null); - const isRangeInsideEditor = useCallback((range: Range | null): range is Range => { - if (!range || !editorRef.current) return false; - return ( - editorRef.current.contains(range.startContainer) && - editorRef.current.contains(range.endContainer) - ); - }, []); - const isSelectionInsideEditor = useCallback( - (selection: Selection | null): selection is Selection => { - if (!selection || selection.rangeCount === 0 || !editorRef.current) return false; - const range = selection.getRangeAt(0); - return isRangeInsideEditor(range); - }, - [isRangeInsideEditor] - ); + const editableRef = useRef<HTMLDivElement | null>(null); + const editor = usePlateEditor({ + readOnly: disabled, + plugins: [ParagraphPlugin, MentionPlugin], + value: initialText ? 
toValueFromText(initialText) : EMPTY_VALUE, + }); - const rememberSelection = useCallback(() => { - const selection = window.getSelection(); - if (!isSelectionInsideEditor(selection)) return; - lastSelectionRangeRef.current = selection.getRangeAt(0).cloneRange(); - }, [isSelectionInsideEditor]); - - const restoreRememberedSelection = useCallback((): Selection | null => { - const selection = window.getSelection(); - if (!selection) return null; - if (!isRangeInsideEditor(lastSelectionRangeRef.current)) return null; - selection.removeAllRanges(); - selection.addRange(lastSelectionRangeRef.current.cloneRange()); - return selection; - }, [isRangeInsideEditor]); - - useEffect(() => { - const handleSelectionChange = () => { - if (document.activeElement !== editorRef.current) return; - rememberSelection(); - }; - document.addEventListener("selectionchange", handleSelectionChange); - return () => document.removeEventListener("selectionchange", handleSelectionChange); - }, [rememberSelection]); - - useEffect(() => { - if (!initialText || !editorRef.current) return; - editorRef.current.innerText = initialText; - editorRef.current.appendChild(document.createElement("br")); - editorRef.current.appendChild(document.createElement("br")); - setIsEmpty(false); - onChange?.(initialText, []); - editorRef.current.focus(); - const sel = window.getSelection(); - const range = document.createRange(); - range.selectNodeContents(editorRef.current); - range.collapse(false); - sel?.removeAllRanges(); - sel?.addRange(range); - const anchor = document.createElement("span"); - range.insertNode(anchor); - anchor.scrollIntoView({ block: "end" }); - anchor.remove(); - }, [initialText, onChange]); - - // Focus at the end of the editor const focusAtEnd = useCallback(() => { - if (!editorRef.current) return; - editorRef.current.focus(); + const el = editableRef.current; + if (!el) return; + el.focus(); const selection = window.getSelection(); const range = document.createRange(); - 
range.selectNodeContents(editorRef.current); + range.selectNodeContents(el); range.collapse(false); selection?.removeAllRanges(); selection?.addRange(range); }, []); - // Get plain text content with inline mention tokens for chips. - // This preserves the original query structure sent to the backend/LLM. - const getText = useCallback((): string => { - if (!editorRef.current) return ""; + const getCurrentValue = useCallback(() => (editor.children as ComposerValue) ?? EMPTY_VALUE, [editor]); - const extractText = (node: Node): string => { - if (node.nodeType === Node.TEXT_NODE) { - return node.textContent ?? ""; - } - - if (node.nodeType === Node.ELEMENT_NODE) { - const element = node as Element; - - // Preserve mention chips as inline @title tokens. - if (element.hasAttribute(CHIP_DATA_ATTR)) { - const title = element.querySelector("[data-mention-title='true']")?.textContent?.trim(); - if (title) { - return `@${title}`; - } - return ""; - } - - let result = ""; - for (const child of Array.from(element.childNodes)) { - result += extractText(child); - } - return result; - } - - return ""; - }; - - return extractText(editorRef.current).trim(); - }, []); - - // Get all mentioned documents - const getMentionedDocuments = useCallback((): MentionedDocument[] => { - return Array.from(mentionedDocs.values()); - }, [mentionedDocs]); - - const syncEditorState = useCallback( - (docsOverride?: Map<string, MentionedDocument>) => { - const docs = docsOverride - ? 
Array.from(docsOverride.values()) - : Array.from(mentionedDocs.values()); - const text = getText(); - const empty = text.length === 0 && docs.length === 0; - setIsEmpty(empty); + const emitState = useCallback( + (nextValue: ComposerValue) => { + const text = getPlainText(nextValue); + const docs = getMentionedDocuments(nextValue); onChange?.(text, docs); - }, - [getText, mentionedDocs, onChange] - ); - // Create a chip element for a document - const createChipElement = useCallback( - (doc: MentionedDocument): HTMLSpanElement => { - const chip = document.createElement("span"); - chip.setAttribute(CHIP_DATA_ATTR, "true"); - chip.setAttribute(CHIP_ID_ATTR, String(doc.id)); - chip.setAttribute(CHIP_DOCTYPE_ATTR, doc.document_type ?? "UNKNOWN"); - chip.contentEditable = "false"; - chip.className = - "inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none cursor-default"; - chip.style.userSelect = "none"; - chip.style.verticalAlign = "baseline"; - - // Container that swaps between icon and remove button on hover - const iconContainer = document.createElement("span"); - iconContainer.className = "shrink-0 flex items-center size-3 relative"; - - const iconSpan = document.createElement("span"); - iconSpan.className = "flex items-center text-muted-foreground"; - iconSpan.innerHTML = renderElementToHTML( - getConnectorIcon(doc.document_type ?? 
"UNKNOWN", "h-3 w-3") - ); - - const removeBtn = document.createElement("button"); - removeBtn.type = "button"; - removeBtn.className = - "size-3 items-center justify-center rounded-full text-muted-foreground transition-colors"; - removeBtn.style.display = "none"; - removeBtn.innerHTML = renderElementToHTML( - createElement(X, { className: "h-3 w-3", strokeWidth: 2.5 }) - ); - removeBtn.onclick = (e) => { - e.preventDefault(); - e.stopPropagation(); - chip.remove(); - const docKey = getMentionDocKey(doc); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(docKey); - syncEditorState(next); - return next; - }); - onDocumentRemove?.(doc.id, doc.document_type); - focusAtEnd(); - }; - - const titleSpan = document.createElement("span"); - titleSpan.className = "max-w-[120px] truncate"; - titleSpan.textContent = doc.title; - titleSpan.title = doc.title; - titleSpan.setAttribute("data-mention-title", "true"); - - const statusSpan = document.createElement("span"); - statusSpan.setAttribute(CHIP_STATUS_ATTR, "true"); - statusSpan.className = "text-[10px] font-semibold opacity-80 hidden"; - - const isTouchDevice = window.matchMedia("(hover: none)").matches; - if (isTouchDevice) { - // Mobile: icon on left, title, X on right - chip.appendChild(iconSpan); - chip.appendChild(titleSpan); - chip.appendChild(statusSpan); - removeBtn.style.display = "flex"; - removeBtn.className += " ml-0.5"; - chip.appendChild(removeBtn); - } else { - // Desktop: icon/X swap on hover in the same slot - iconContainer.appendChild(iconSpan); - iconContainer.appendChild(removeBtn); - chip.addEventListener("mouseenter", () => { - iconSpan.style.display = "none"; - removeBtn.style.display = "flex"; - }); - chip.addEventListener("mouseleave", () => { - iconSpan.style.display = ""; - removeBtn.style.display = "none"; - }); - chip.appendChild(iconContainer); - chip.appendChild(titleSpan); - chip.appendChild(statusSpan); + const cursorCtx = getCursorTextContext(nextValue, 
editor.selection); + if (!cursorCtx) { + onMentionClose?.(); + onActionClose?.(); + return; } - return chip; + const trigger = scanActiveTrigger(cursorCtx.text, cursorCtx.cursor); + if (!trigger) { + onMentionClose?.(); + onActionClose?.(); + return; + } + + if (trigger.triggerChar === "@") { + onMentionTrigger?.(trigger.query); + onActionClose?.(); + return; + } + + onActionTrigger?.(trigger.query); + onMentionClose?.(); }, - [focusAtEnd, onDocumentRemove, syncEditorState] + [editor.selection, onActionClose, onActionTrigger, onChange, onMentionClose, onMentionTrigger] + ); + + const setValue = useCallback( + (nextValue: ComposerValue) => { + const tf = editor.tf as { setValue: (value: ComposerValue) => void }; + tf.setValue(nextValue); + emitState(nextValue); + }, + [editor, emitState] ); - // Insert a document chip at the current cursor position const insertDocumentChip = useCallback( ( doc: Pick<Document, "id" | "title" | "document_type">, options?: { removeTriggerText?: boolean } ) => { - if (!editorRef.current) return; + if (typeof doc.id !== "number" || typeof doc.title !== "string") return; + const removeTriggerText = options?.removeTriggerText ?? 
true; - - // Validate required fields for type safety - if (typeof doc.id !== "number" || typeof doc.title !== "string") { - console.warn("[InlineMentionEditor] Invalid document passed to insertDocumentChip:", doc); - return; - } - - const mentionDoc: MentionedDocument = { + const current = getCurrentValue(); + const selection = editor.selection; + const mentionNode: MentionElementNode = { + type: MENTION_TYPE, id: doc.id, title: doc.title, document_type: doc.document_type, + children: [{ text: "" }], }; - // Add to mentioned docs map using unique key - const docKey = getMentionDocKey(doc); - setMentionedDocs((prev) => new Map(prev).set(docKey, mentionDoc)); - const nextDocs = new Map(mentionedDocs); - nextDocs.set(docKey, mentionDoc); - - // Find and remove the @query text - const selection = window.getSelection(); - const hasActiveSelection = isSelectionInsideEditor(selection); - const resolvedSelection = hasActiveSelection ? selection : restoreRememberedSelection(); - if ( - !resolvedSelection || - resolvedSelection.rangeCount === 0 || - !isSelectionInsideEditor(resolvedSelection) - ) { - // No valid in-editor selection: deterministically insert at end. - editorRef.current.focus(); - const endSelection = window.getSelection(); - if (!endSelection) return; - const endRange = document.createRange(); - endRange.selectNodeContents(editorRef.current); - endRange.collapse(false); - endSelection.removeAllRanges(); - endSelection.addRange(endRange); - - const chip = createChipElement(mentionDoc); - endRange.insertNode(chip); - endRange.setStartAfter(chip); - endRange.collapse(true); - const space = document.createTextNode(" "); - endRange.insertNode(space); - endRange.setStartAfter(space); - endRange.collapse(true); - endSelection.removeAllRanges(); - endSelection.addRange(endRange); - - syncEditorState(nextDocs); - rememberSelection(); + const cursorCtx = getCursorTextContext(current, selection); + if (!cursorCtx) { + const lastBlock = current[current.length - 1] ?? 
{ type: "p", children: [{ text: "" }] }; + const appended: ComposerValue = [ + ...current.slice(0, -1), + { + ...lastBlock, + children: [...lastBlock.children, mentionNode, { text: " " }], + }, + ]; + setValue(appended); + requestAnimationFrame(focusAtEnd); return; } - // Find the @ symbol before the cursor and remove it along with any query text - const range = resolvedSelection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE && removeTriggerText) { - const text = textNode.textContent || ""; - const cursorPos = range.startOffset; - - // Find the @ symbol before cursor - let atIndex = -1; - for (let i = cursorPos - 1; i >= 0; i--) { - if (text[i] === "@") { - atIndex = i; - break; - } - } - - if (atIndex !== -1) { - // Remove @query and insert chip - const beforeAt = text.slice(0, atIndex); - const afterCursor = text.slice(cursorPos); - - // Create chip - const chip = createChipElement(mentionDoc); - - // Replace text node content - const parent = textNode.parentNode; - if (parent) { - const beforeNode = document.createTextNode(beforeAt); - const afterNode = document.createTextNode(` ${afterCursor}`); - - parent.insertBefore(beforeNode, textNode); - parent.insertBefore(chip, textNode); - parent.insertBefore(afterNode, textNode); - parent.removeChild(textNode); - - // Set cursor after the chip - const newRange = document.createRange(); - newRange.setStart(afterNode, 1); - newRange.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(newRange); - rememberSelection(); - } - } else { - // No @ found, just insert at cursor - const chip = createChipElement(mentionDoc); - range.insertNode(chip); - range.setStartAfter(chip); - range.collapse(true); - - // Add space after chip - const space = document.createTextNode(" "); - range.insertNode(space); - range.setStartAfter(space); - range.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(range); - 
rememberSelection(); - } - } else { - // Either explicit non-trigger insertion or no @query present. - const chip = createChipElement(mentionDoc); - range.insertNode(chip); - range.setStartAfter(chip); - range.collapse(true); - const space = document.createTextNode(" "); - range.insertNode(space); - range.setStartAfter(space); - range.collapse(true); - resolvedSelection.removeAllRanges(); - resolvedSelection.addRange(range); - rememberSelection(); + const block = current[cursorCtx.blockIndex]; + const currentChild = getTextNode(block.children[cursorCtx.childIndex]); + if (!currentChild) { + const children = [...block.children]; + children.splice(cursorCtx.childIndex + 1, 0, mentionNode, { text: " " }); + const next = [...current]; + next[cursorCtx.blockIndex] = { ...block, children }; + setValue(next as ComposerValue); + requestAnimationFrame(focusAtEnd); + return; } - syncEditorState(nextDocs); + const text = currentChild.text; + let removeStart = cursorCtx.cursor; + if (removeTriggerText) { + for (let i = cursorCtx.cursor - 1; i >= 0; i--) { + if (text[i] === "@") { + removeStart = i; + break; + } + if (text[i] === " " || text[i] === "\n") break; + } + } + + const before = text.slice(0, removeStart); + const after = text.slice(cursorCtx.cursor); + const replacement: ComposerNode[] = []; + if (before.length > 0) replacement.push({ text: before }); + replacement.push(mentionNode); + replacement.push({ text: ` ${after}` }); + + const children = [...block.children]; + children.splice(cursorCtx.childIndex, 1, ...replacement); + const next = [...current]; + next[cursorCtx.blockIndex] = { ...block, children }; + setValue(next as ComposerValue); + requestAnimationFrame(focusAtEnd); }, - [ - createChipElement, - isSelectionInsideEditor, - mentionedDocs, - rememberSelection, - restoreRememberedSelection, - syncEditorState, - ] + [editor.selection, focusAtEnd, getCurrentValue, setValue] ); - // Clear the editor - const clear = useCallback(() => { - if (editorRef.current) { 
- editorRef.current.innerHTML = ""; - const emptyDocs = new Map<string, MentionedDocument>(); - setMentionedDocs(emptyDocs); - syncEditorState(emptyDocs); - } - }, [syncEditorState]); - - // Replace editor content with plain text and place cursor at end - const setText = useCallback( - (text: string) => { - if (!editorRef.current) return; - editorRef.current.innerText = text; - syncEditorState(); - focusAtEnd(); + const removeDocumentChip = useCallback( + (docId: number, docType?: string) => { + const current = getCurrentValue(); + let changed = false; + const next = current.map((block) => { + const children = block.children.filter((node) => { + if (!isMentionNode(node)) return true; + const match = node.id === docId && (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + if (match) changed = true; + return !match; + }); + return { ...block, children: children.length ? children : [{ text: "" }] }; + }); + if (!changed) return; + setValue(next as ComposerValue); }, - [focusAtEnd, syncEditorState] + [getCurrentValue, setValue] ); const setDocumentChipStatus = useCallback( @@ -507,327 +396,142 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent docId: number, docType: string | undefined, statusLabel: string | null, - statusKind: "pending" | "processing" | "ready" | "failed" = "pending" + statusKind: MentionStatusKind = "pending" ) => { - if (!editorRef.current) return; - - const chips = editorRef.current.querySelectorAll<HTMLSpanElement>( - `span[${CHIP_DATA_ATTR}="true"]` - ); - for (const chip of chips) { - const chipId = getChipId(chip); - const chipType = getChipDocType(chip); - if (chipId !== docId) continue; - if ((docType ?? 
"UNKNOWN") !== chipType) continue; - - const statusEl = chip.querySelector<HTMLSpanElement>(`span[${CHIP_STATUS_ATTR}="true"]`); - if (!statusEl) continue; - - if (!statusLabel) { - statusEl.textContent = ""; - statusEl.className = "text-[10px] font-semibold opacity-80 hidden"; - continue; - } - - const statusClass = - statusKind === "failed" - ? "text-destructive" - : statusKind === "processing" - ? "text-amber-700" - : statusKind === "ready" - ? "text-emerald-700" - : "text-amber-700"; - statusEl.textContent = statusLabel; - statusEl.className = `text-[10px] font-semibold opacity-80 ${statusClass}`; - } + const current = getCurrentValue(); + let changed = false; + const next = current.map((block) => ({ + ...block, + children: block.children.map((node) => { + if (!isMentionNode(node)) return node; + const sameType = (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + if (node.id !== docId || !sameType) return node; + changed = true; + return { + ...node, + statusLabel, + statusKind: statusLabel ? statusKind : undefined, + }; + }), + })); + if (!changed) return; + setValue(next as ComposerValue); }, - [] + [getCurrentValue, setValue] ); - const removeDocumentChip = useCallback( - (docId: number, docType?: string) => { - if (!editorRef.current) return; - const chipKey = getMentionDocKey({ id: docId, document_type: docType }); - const chips = editorRef.current.querySelectorAll<HTMLSpanElement>( - `span[${CHIP_DATA_ATTR}="true"]` - ); - for (const chip of chips) { - if (getChipId(chip) === docId && getChipDocType(chip) === (docType ?? 
"UNKNOWN")) { - chip.remove(); - break; - } - } - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); + const clear = useCallback(() => { + setValue(EMPTY_VALUE); + }, [setValue]); + + const setText = useCallback( + (text: string) => { + setValue(toValueFromText(text)); + requestAnimationFrame(focusAtEnd); }, - [syncEditorState] + [focusAtEnd, setValue] ); - // Expose methods via ref - useImperativeHandle(ref, () => ({ - focus: () => editorRef.current?.focus(), - clear, - setText, - getText, - getMentionedDocuments, - insertDocumentChip, - removeDocumentChip, - setDocumentChipStatus, - })); + const getText = useCallback(() => getPlainText(getCurrentValue()), [getCurrentValue]); + const getMentionedDocs = useCallback( + () => getMentionedDocuments(getCurrentValue()), + [getCurrentValue] + ); - // Handle input changes - const handleInput = useCallback(() => { - if (!editorRef.current) return; + useImperativeHandle( + ref, + () => ({ + focus: () => editableRef.current?.focus(), + clear, + setText, + getText, + getMentionedDocuments: getMentionedDocs, + insertDocumentChip, + removeDocumentChip, + setDocumentChipStatus, + }), + [clear, getMentionedDocs, getText, insertDocumentChip, removeDocumentChip, setDocumentChipStatus, setText] + ); - const text = getText(); - const empty = text.length === 0 && mentionedDocs.size === 0; - setIsEmpty(empty); - - // Unified trigger scan: find the leftmost @ or / in the current word. - // Whichever trigger was typed first owns the token — the other character - // is treated as part of the query, not as a separate trigger. 
- const selection = window.getSelection(); - let shouldTriggerMention = false; - let mentionQuery = ""; - let shouldTriggerAction = false; - let actionQuery = ""; - - if (selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - const textNode = range.startContainer; - - if (textNode.nodeType === Node.TEXT_NODE) { - const textContent = textNode.textContent || ""; - const cursorPos = range.startOffset; - - let wordStart = 0; - for (let i = cursorPos - 1; i >= 0; i--) { - if (textContent[i] === " " || textContent[i] === "\n") { - wordStart = i + 1; - break; - } - } - - let triggerChar: "@" | "/" | null = null; - let triggerIndex = -1; - for (let i = wordStart; i < cursorPos; i++) { - if (textContent[i] === "@" || textContent[i] === "/") { - triggerChar = textContent[i] as "@" | "/"; - triggerIndex = i; - break; - } - } - - if (triggerChar === "@" && triggerIndex !== -1) { - const query = textContent.slice(triggerIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerMention = true; - mentionQuery = query; - } - } else if (triggerChar === "/" && triggerIndex !== -1) { - if ( - triggerIndex === 0 || - textContent[triggerIndex - 1] === " " || - textContent[triggerIndex - 1] === "\n" - ) { - const query = textContent.slice(triggerIndex + 1, cursorPos); - if (!query.startsWith(" ")) { - shouldTriggerAction = true; - actionQuery = query; - } - } - } - } - } - - // If no @ found before cursor, check if text contains @ at all - // If text is empty or doesn't contain @, close the mention - if (!shouldTriggerMention) { - if (text.length === 0 || !text.includes("@")) { - onMentionClose?.(); - } else { - // Text contains @ but not before cursor, close mention - onMentionClose?.(); - } - } else { - onMentionTrigger?.(mentionQuery); - } - - if (!shouldTriggerAction) { - onActionClose?.(); - } else { - onActionTrigger?.(actionQuery); - } - - // Notify parent of change - onChange?.(text, Array.from(mentionedDocs.values())); - 
rememberSelection(); - }, [ - getText, - mentionedDocs, - onChange, - onMentionTrigger, - onMentionClose, - onActionTrigger, - onActionClose, - rememberSelection, - ]); - - // Handle keydown const handleKeyDown = useCallback( (e: React.KeyboardEvent<HTMLDivElement>) => { - // Let parent handle navigation keys when mention popover is open - if (onKeyDown) { - onKeyDown(e); - if (e.defaultPrevented) return; - } + onKeyDown?.(e); + if (e.defaultPrevented) return; - // Handle Enter for submit (without shift) if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); onSubmit?.(); return; } - // Handle backspace on chips - if (e.key === "Backspace") { - const selection = window.getSelection(); - if (selection && selection.rangeCount > 0) { - const range = selection.getRangeAt(0); - if (range.collapsed) { - // Check if cursor is right after a chip - const node = range.startContainer; - const offset = range.startOffset; - - if (node.nodeType === Node.TEXT_NODE && offset === 0) { - // Check previous sibling using type guard - const prevSibling = node.previousSibling; - if (isChipElement(prevSibling)) { - e.preventDefault(); - const chipId = getChipId(prevSibling); - const chipDocType = getChipDocType(prevSibling); - if (chipId !== null) { - prevSibling.remove(); - const chipKey = getMentionDocKey({ - id: chipId, - document_type: chipDocType, - }); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); - // Notify parent that a document was removed - onDocumentRemove?.(chipId, chipDocType); - } - return; - } - // Check if we're about to delete @ at the start - const textContent = node.textContent || ""; - if (textContent.length > 0 && textContent[0] === "@") { - // Will delete @, close mention popover - setTimeout(() => { - onMentionClose?.(); - }, 0); - } - } else if (node.nodeType === Node.TEXT_NODE && offset > 0) { - // Check if we're about to delete @ - const textContent = node.textContent || 
""; - if (textContent[offset - 1] === "@") { - // Will delete @, close mention popover - setTimeout(() => { - onMentionClose?.(); - }, 0); - } - } else if (node.nodeType === Node.ELEMENT_NODE && offset > 0) { - // Check if previous child is a chip using type guard - const prevChild = (node as Element).childNodes[offset - 1]; - if (isChipElement(prevChild)) { - e.preventDefault(); - const chipId = getChipId(prevChild); - const chipDocType = getChipDocType(prevChild); - if (chipId !== null) { - prevChild.remove(); - const chipKey = getMentionDocKey({ - id: chipId, - document_type: chipDocType, - }); - setMentionedDocs((prev) => { - const next = new Map(prev); - next.delete(chipKey); - syncEditorState(next); - return next; - }); - // Notify parent that a document was removed - onDocumentRemove?.(chipId, chipDocType); - } - } - } - } - } + if (e.key !== "Backspace") return; + const selection = editor.selection; + if (!selection || !selection.anchor || !selection.focus) return; + if ( + selection.anchor.path.length < 2 || + selection.focus.path.length < 2 || + selection.anchor.path[0] !== selection.focus.path[0] + ) { + return; } + if (selection.anchor.offset !== 0 || selection.focus.offset !== 0) return; + + const value = getCurrentValue(); + const block = value[selection.anchor.path[0]]; + if (!block) return; + const childIndex = selection.anchor.path[1]; + if (childIndex <= 0) return; + const prev = block.children[childIndex - 1]; + if (!isMentionNode(prev)) return; + + e.preventDefault(); + removeDocumentChip(prev.id, prev.document_type); + onDocumentRemove?.(prev.id, prev.document_type); }, - [onKeyDown, onSubmit, onDocumentRemove, onMentionClose, syncEditorState] + [ + editor.selection, + getCurrentValue, + onDocumentRemove, + onKeyDown, + onSubmit, + removeDocumentChip, + ] ); - // Handle paste - strip formatting - const handlePaste = useCallback((e: React.ClipboardEvent) => { - e.preventDefault(); - const text = e.clipboardData.getData("text/plain"); - 
document.execCommand("insertText", false, text); - }, []); - - // Handle composition (for IME input) - const handleCompositionStart = useCallback(() => { - isComposingRef.current = true; - }, []); - - const handleCompositionEnd = useCallback(() => { - isComposingRef.current = false; - handleInput(); - }, [handleInput]); + const editableProps = useMemo( + () => ({ + placeholder, + onPaste: (e: React.ClipboardEvent<HTMLDivElement>) => { + e.preventDefault(); + const text = e.clipboardData.getData("text/plain"); + const tf = editor.tf as { insertText: (value: string) => void }; + tf.insertText(text); + }, + onKeyDown: handleKeyDown, + }), + [editor, handleKeyDown, placeholder] + ); return ( <div className="relative w-full"> - {/* biome-ignore lint/a11y/noStaticElementInteractions: contenteditable mention editor requires a div for inline chips */} - <div - ref={editorRef} - contentEditable={!disabled} - suppressContentEditableWarning - tabIndex={disabled ? -1 : 0} - onInput={handleInput} - onKeyDown={handleKeyDown} - onPaste={handlePaste} - onCompositionStart={handleCompositionStart} - onCompositionEnd={handleCompositionEnd} - onKeyUp={rememberSelection} - onMouseUp={rememberSelection} - onBlur={rememberSelection} - className={cn( - "min-h-[24px] max-h-32 overflow-y-auto", - "text-sm outline-none", - "whitespace-pre-wrap wrap-break-word", - disabled && "opacity-50 cursor-not-allowed", - className - )} - style={{ wordBreak: "break-word" }} - data-placeholder={placeholder} - /> - {/* Placeholder with fade animation on change */} - {isEmpty && ( - <div - key={placeholder} - className="absolute top-0 left-0 pointer-events-none text-muted-foreground text-sm animate-in fade-in duration-1000" - aria-hidden="true" - > - {placeholder} - </div> - )} + <Plate + editor={editor} + onChange={({ value }) => { + emitState(value as ComposerValue); + }} + > + <PlateContent + ref={editableRef} + readOnly={disabled} + {...editableProps} + className={cn( + "min-h-[24px] max-h-32 
overflow-y-auto outline-none whitespace-pre-wrap wrap-break-word", + COMPOSER_TEXT_METRICS_CLASSNAME, + disabled && "opacity-50 cursor-not-allowed", + className + )} + /> + </Plate> </div> ); } From 04da62a5541d446ccb2111dc4caed69f188806cc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:28:24 +0530 Subject: [PATCH 36/68] refactor(chat): improve AssistantMessage component with fixed comment trigger slot and enhanced visibility handling --- .../assistant-ui/assistant-message.tsx | 70 +++++++++++-------- 1 file changed, 39 insertions(+), 31 deletions(-) diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index bfe0434b4..711bb2fe2 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -548,8 +548,10 @@ const AssistantMessageInner: FC = () => { </div> )} - <div className="aui-assistant-message-footer mt-3 mb-5 ml-2 flex items-center gap-2"> - <AssistantActionBar /> + <div className="aui-assistant-message-footer mt-3 mb-5 ml-2 h-6"> + <div className="h-full opacity-100 transition-opacity"> + <AssistantActionBar /> + </div> </div> </CitationMetadataProvider> ); @@ -642,35 +644,41 @@ export const AssistantMessage: FC = () => { className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150" data-role="assistant" > - {/* Comment trigger — right-aligned, just below user query on all screen sizes */} - {showCommentTrigger && ( - <div className="mr-2 mb-1 flex justify-end"> - <button - ref={isDesktop ? commentTriggerRef : undefined} - type="button" - onClick={ - isDesktop ? 
() => setIsInlineOpen((prev) => !prev) : () => setIsSheetOpen(true) - } - className={cn( - "flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors", - isDesktop && isInlineOpen - ? "bg-primary/10 text-primary" - : hasComments - ? "text-primary hover:bg-primary/10" - : "text-muted-foreground hover:text-foreground hover:bg-muted" - )} - > - <MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} /> - {hasComments ? ( - <span> - {commentCount} {commentCount === 1 ? "comment" : "comments"} - </span> - ) : ( - <span>Add comment</span> - )} - </button> - </div> - )} + {/* Fixed trigger slot prevents any vertical reflow when visibility changes */} + <div className="mr-2 mb-1 flex h-7 justify-end"> + <button + ref={isDesktop ? commentTriggerRef : undefined} + type="button" + onClick={ + showCommentTrigger + ? isDesktop + ? () => setIsInlineOpen((prev) => !prev) + : () => setIsSheetOpen(true) + : undefined + } + aria-hidden={!showCommentTrigger} + tabIndex={showCommentTrigger ? 0 : -1} + className={cn( + "flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors", + "opacity-0 pointer-events-none", + showCommentTrigger && "opacity-100 pointer-events-auto", + isDesktop && isInlineOpen + ? "bg-primary/10 text-primary" + : hasComments + ? "text-primary hover:bg-primary/10" + : "text-muted-foreground hover:text-foreground hover:bg-muted" + )} + > + <MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} /> + {hasComments ? ( + <span> + {commentCount} {commentCount === 1 ? 
"comment" : "comments"} + </span> + ) : ( + <span>Add comment</span> + )} + </button> + </div> {/* Desktop floating comment panel — overlays on top of chat content */} {showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && ( From 5826e5264d68595fcf7b0e67c03739109ae05e50 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 04:39:33 +0530 Subject: [PATCH 37/68] refactor(chat): add TruncatedNameWithTooltip component in model selector --- .../components/new-chat/model-selector.tsx | 93 ++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 9fe9dd8da..1a0f8c5ba 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -236,6 +236,93 @@ interface DisplayItem { isAutoMode: boolean; } +const TruncatedNameWithTooltip: React.FC<{ + text: string; + className?: string; + enableTooltip: boolean; +}> = ({ text, className, enableTooltip }) => { + const textRef = useRef<HTMLSpanElement>(null); + const openTimerRef = useRef<number | undefined>(undefined); + const [isTruncated, setIsTruncated] = useState(false); + const [open, setOpen] = useState(false); + + const recalcTruncation = useCallback(() => { + const el = textRef.current; + if (!el) return; + setIsTruncated(el.scrollWidth > el.clientWidth + 1); + }, []); + + useEffect(() => { + if (!enableTooltip) return; + const el = textRef.current; + if (!el) return; + + const raf = requestAnimationFrame(recalcTruncation); + recalcTruncation(); + + const observer = new ResizeObserver(recalcTruncation); + observer.observe(el); + if (el.parentElement) observer.observe(el.parentElement); + window.addEventListener("resize", recalcTruncation); + + return () => { + cancelAnimationFrame(raf); + observer.disconnect(); + window.removeEventListener("resize", 
recalcTruncation); + }; + }, [enableTooltip, recalcTruncation]); + + useEffect(() => { + // Recompute when row text changes. + void text; + requestAnimationFrame(recalcTruncation); + }, [text, recalcTruncation]); + + useEffect( + () => () => { + if (openTimerRef.current) window.clearTimeout(openTimerRef.current); + }, + [] + ); + + if (!enableTooltip) { + return ( + <span ref={textRef} className={cn("block max-w-full", className)}> + {text} + </span> + ); + } + + const handleOpenChange = (nextOpen: boolean) => { + if (openTimerRef.current) { + window.clearTimeout(openTimerRef.current); + openTimerRef.current = undefined; + } + if (!nextOpen) { + setOpen(false); + return; + } + if (!isTruncated) return; + openTimerRef.current = window.setTimeout(() => { + setOpen(true); + openTimerRef.current = undefined; + }, 220); + }; + + return ( + <Tooltip open={open} onOpenChange={handleOpenChange}> + <TooltipTrigger asChild> + <span ref={textRef} className={cn("block max-w-full", className)}> + {text} + </span> + </TooltipTrigger> + <TooltipContent side="top" align="start"> + {text} + </TooltipContent> + </Tooltip> + ); +}; + // ─── Component ────────────────────────────────────────────────────── interface ModelSelectorProps { @@ -936,7 +1023,11 @@ export function ModelSelector({ {/* Model info */} <div className="flex-1 min-w-0"> <div className="flex items-center gap-1.5"> - <span className="font-medium text-sm truncate">{config.name}</span> + <TruncatedNameWithTooltip + text={config.name} + enableTooltip={!isMobile} + className="font-medium text-sm truncate" + /> {isAutoMode && ( <Badge variant="secondary" From 7aeb8bb0a88c84afb6f23cc438d75be266701031 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 18:40:55 -0700 Subject: [PATCH 38/68] feat(markdown): enable citation rendering in MarkdownViewer and related components - Added `enableCitations` prop to `MarkdownViewer` to support interactive citation badges. 
- Updated instances of `MarkdownViewer` across various components to utilize the new citation feature. - Enhanced citation processing in `PlateEditor` for read-only views, ensuring citations are rendered correctly without affecting markdown serialization. - Refactored citation handling in `InlineCitation` and `MarkdownText` to improve citation context management. --- .../assistant-ui/inline-citation.tsx | 18 +- .../components/assistant-ui/markdown-text.tsx | 428 ++++++++---------- .../citation-panel/citation-panel.tsx | 2 +- .../citations/citation-renderer.tsx | 79 ++++ surfsense_web/components/document-viewer.tsx | 2 +- .../components/editor-panel/editor-panel.tsx | 9 +- .../components/editor/plate-editor.tsx | 56 ++- .../editor/plugins/citation-kit.tsx | 222 +++++++++ .../components/editor/utils/escape-mdx.ts | 2 +- .../layout/ui/tabs/DocumentTabContent.tsx | 4 +- surfsense_web/components/markdown-viewer.tsx | 100 +++- .../components/report-panel/report-panel.tsx | 5 +- .../lib/citations/citation-parser.ts | 134 ++++++ surfsense_web/lib/markdown/code-regions.ts | 8 + 14 files changed, 809 insertions(+), 260 deletions(-) create mode 100644 surfsense_web/components/citations/citation-renderer.tsx create mode 100644 surfsense_web/components/editor/plugins/citation-kit.tsx create mode 100644 surfsense_web/lib/citations/citation-parser.ts create mode 100644 surfsense_web/lib/markdown/code-regions.ts diff --git a/surfsense_web/components/assistant-ui/inline-citation.tsx b/surfsense_web/components/assistant-ui/inline-citation.tsx index 2aeba89ca..e299f2373 100644 --- a/surfsense_web/components/assistant-ui/inline-citation.tsx +++ b/surfsense_web/components/assistant-ui/inline-citation.tsx @@ -3,11 +3,11 @@ import { useQuery } from "@tanstack/react-query"; import { useSetAtom } from "jotai"; import { ExternalLink, FileText } from "lucide-react"; +import dynamic from "next/dynamic"; import type { FC } from "react"; import { useCallback, useEffect, useRef, useState } from 
"react"; import { openCitationPanelAtom } from "@/atoms/citation/citation-panel.atom"; import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context"; -import { MarkdownViewer } from "@/components/markdown-viewer"; import { Citation } from "@/components/tool-ui/citation"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; @@ -15,6 +15,16 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip import { documentsApiService } from "@/lib/apis/documents-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; +// Lazily load MarkdownViewer here to break the static import cycle: +// `markdown-viewer.tsx` → `citation-renderer.tsx` → `inline-citation.tsx` +// would otherwise pull `markdown-viewer.tsx` back in at module-init time. +// Only `SurfsenseDocCitation` (popover body) ever renders this viewer, so +// the lazy boundary is invisible to most call paths. 
+const MarkdownViewer = dynamic( + () => import("@/components/markdown-viewer").then((m) => m.MarkdownViewer), + { ssr: false, loading: () => <Spinner size="xs" /> } +); + interface InlineCitationProps { chunkId: number; isDocsChunk?: boolean; @@ -172,7 +182,11 @@ const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => { </p> )} {!isLoading && !error && citedChunk?.content && ( - <MarkdownViewer content={citedChunk.content} maxLength={1500} /> + <MarkdownViewer + content={citedChunk.content} + maxLength={1500} + enableCitations + /> )} {!isLoading && !error && !citedChunk?.content && ( <p className="py-4 text-xs text-muted-foreground">No content available.</p> diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 7655e10cc..2b788e88b 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -12,15 +12,26 @@ import { ExternalLinkIcon } from "lucide-react"; import dynamic from "next/dynamic"; import { useParams } from "next/navigation"; import { useTheme } from "next-themes"; -import { memo, type ReactNode } from "react"; +import { + createContext, + memo, + type ReactNode, + useCallback, + useContext, + useRef, +} from "react"; import rehypeKatex from "rehype-katex"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image"; import "katex/dist/katex.min.css"; -import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; import { Skeleton } from "@/components/ui/skeleton"; +import { + type CitationUrlMap, + preprocessCitationMarkdown, +} from "@/lib/citations/citation-parser"; import { Table, TableBody, @@ -59,31 
+70,30 @@ const LazyMarkdownCodeBlock = dynamic( } ); -// Storage for URL citations replaced during preprocess to avoid GFM autolink interference. -// Populated in preprocessMarkdown, consumed in parseTextWithCitations. -let _pendingUrlCitations = new Map<string, string>(); -let _urlCiteIdx = 0; +// Per-render URL placeholder map propagated to component overrides via +// React Context. Replaces the previous module-level `_pendingUrlCitations` +// state, which was unsafe under concurrent renders / SSR. +type CitationUrlMapRef = { current: CitationUrlMap }; +const EMPTY_URL_MAP: CitationUrlMap = new Map(); +const CitationUrlMapContext = createContext<CitationUrlMapRef>({ current: EMPTY_URL_MAP }); + +function useCitationUrlMap(): CitationUrlMap { + return useContext(CitationUrlMapContext).current; +} /** * Preprocess raw markdown before it reaches the remark/rehype pipeline. * - Replaces URL-based citations with safe placeholders (prevents GFM autolinks) * - Normalises LaTeX delimiters to dollar-sign syntax for remark-math */ -function preprocessMarkdown(content: string): string { +function preprocessMarkdown(content: string, urlMapRef: CitationUrlMapRef): string { // Replace URL-based citations with safe placeholders BEFORE markdown parsing. // GFM autolinks would otherwise convert the https://... inside [citation:URL] // into an <a> element, splitting the text and preventing our citation regex // from matching the full pattern. 
- _pendingUrlCitations = new Map(); - _urlCiteIdx = 0; - content = content.replace( - /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g, - (_, url) => { - const key = `urlcite${_urlCiteIdx++}`; - _pendingUrlCitations.set(key, url.trim()); - return `[citation:${key}]`; - } - ); + const { content: rewritten, urlMap } = preprocessCitationMarkdown(content); + urlMapRef.current = urlMap; + content = rewritten; // All math forms are normalised to $$...$$ so we can disable single-dollar // inline math in remark-math (otherwise currency like "$3,120.00 and $0.00" @@ -116,113 +126,28 @@ function preprocessMarkdown(content: string): string { return content; } -// Matches [citation:...] with numeric IDs (incl. negative, doc- prefix, comma-separated), -// URL-based IDs from live web search, or urlciteN placeholders from preprocess. -// Also matches Chinese brackets 【】 and handles zero-width spaces that LLM sometimes inserts. -const CITATION_REGEX = - /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g; - -/** - * Parses text and replaces [citation:XXX] patterns with citation components. 
- * Supports: - * - Numeric chunk IDs: [citation:123] - * - Doc-prefixed IDs: [citation:doc-123] - * - Comma-separated IDs: [citation:4149, 4150, 4151] - * - URL-based citations from live search: [citation:https://example.com/page] - */ -function parseTextWithCitations(text: string): ReactNode[] { - const parts: ReactNode[] = []; - let lastIndex = 0; - let match: RegExpExecArray | null; - let instanceIndex = 0; - - CITATION_REGEX.lastIndex = 0; - - match = CITATION_REGEX.exec(text); - while (match !== null) { - if (match.index > lastIndex) { - parts.push(text.substring(lastIndex, match.index)); - } - - const captured = match[1]; - - if (captured.startsWith("http://") || captured.startsWith("https://")) { - parts.push(<UrlCitation key={`citation-url-${instanceIndex}`} url={captured.trim()} />); - instanceIndex++; - } else if (captured.startsWith("urlcite")) { - const url = _pendingUrlCitations.get(captured); - if (url) { - parts.push(<UrlCitation key={`citation-url-${instanceIndex}`} url={url} />); - } - instanceIndex++; - } else { - const rawIds = captured.split(",").map((s) => s.trim()); - for (const rawId of rawIds) { - const isDocsChunk = rawId.startsWith("doc-"); - const chunkId = Number.parseInt(isDocsChunk ? rawId.slice(4) : rawId, 10); - parts.push( - <InlineCitation - key={`citation-${isDocsChunk ? "doc-" : ""}${chunkId}-${instanceIndex}`} - chunkId={chunkId} - isDocsChunk={isDocsChunk} - /> - ); - instanceIndex++; - } - } - - lastIndex = match.index + match[0].length; - match = CITATION_REGEX.exec(text); - } - - if (lastIndex < text.length) { - parts.push(text.substring(lastIndex)); - } - - return parts.length > 0 ? 
parts : [text]; -} - const MarkdownTextImpl = () => { + const urlMapRef = useRef<CitationUrlMap>(EMPTY_URL_MAP); + const preprocess = useCallback( + (content: string) => preprocessMarkdown(content, urlMapRef), + [] + ); return ( - <MarkdownTextPrimitive - smooth={false} - remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]} - rehypePlugins={[rehypeKatex]} - className="aui-md" - components={defaultComponents} - preprocess={preprocessMarkdown} - /> + <CitationUrlMapContext.Provider value={urlMapRef}> + <MarkdownTextPrimitive + smooth={false} + remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]} + rehypePlugins={[rehypeKatex]} + className="aui-md" + components={defaultComponents} + preprocess={preprocess} + /> + </CitationUrlMapContext.Provider> ); }; export const MarkdownText = memo(MarkdownTextImpl); -/** - * Helper to process children and replace citation patterns with components - */ -function processChildrenWithCitations(children: ReactNode): ReactNode { - if (typeof children === "string") { - const parsed = parseTextWithCitations(children); - return parsed.length === 1 && typeof parsed[0] === "string" ? children : parsed; - } - - if (Array.isArray(children)) { - return children.map((child) => { - if (typeof child === "string") { - const parsed = parseTextWithCitations(child); - return parsed.length === 1 && typeof parsed[0] === "string" ? 
( - child - ) : ( - <span key={child}>{parsed}</span> - ); - } - return child; - }); - } - - return children; -} - function extractDomain(url: string): string { try { const parsed = new URL(url); @@ -322,92 +247,125 @@ function MarkdownImage({ src, alt }: { src?: string; alt?: string }) { } const defaultComponents = memoizeMarkdownComponents({ - h1: ({ className, children, ...props }) => ( - <h1 - className={cn( - "aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </h1> - ), - h2: ({ className, children, ...props }) => ( - <h2 - className={cn( - "aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </h2> - ), - h3: ({ className, children, ...props }) => ( - <h3 - className={cn( - "aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </h3> - ), - h4: ({ className, children, ...props }) => ( - <h4 - className={cn( - "aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </h4> - ), - h5: ({ className, children, ...props }) => ( - <h5 - className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)} - {...props} - > - {processChildrenWithCitations(children)} - </h5> - ), - h6: ({ className, children, ...props }) => ( - <h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}> - {processChildrenWithCitations(children)} - </h6> - ), - p: ({ className, children, ...props }) => ( - <p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}> - {processChildrenWithCitations(children)} - </p> - ), - a: ({ className, children, 
...props }) => ( - <a - className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)} - {...props} - > - {processChildrenWithCitations(children)} - </a> - ), - blockquote: ({ className, children, ...props }) => ( - <blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}> - {processChildrenWithCitations(children)} - </blockquote> - ), + h1: function H1({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h1 + className={cn( + "aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h1> + ); + }, + h2: function H2({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h2 + className={cn( + "aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h2> + ); + }, + h3: function H3({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h3 + className={cn( + "aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h3> + ); + }, + h4: function H4({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h4 + className={cn( + "aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h4> + ); + }, + h5: function H5({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h5 + className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </h5> + ); + }, + h6: function H6({ 
className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </h6> + ); + }, + p: function P({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </p> + ); + }, + a: function A({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <a + className={cn( + "aui-md-a font-medium text-primary underline underline-offset-4", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </a> + ); + }, + blockquote: function Blockquote({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </blockquote> + ); + }, ul: ({ className, ...props }) => ( <ul className={cn("aui-md-ul my-5 ml-6 list-disc [&>li]:mt-2", className)} {...props} /> ), ol: ({ className, ...props }) => ( <ol className={cn("aui-md-ol my-5 ml-6 list-decimal [&>li]:mt-2", className)} {...props} /> ), - li: ({ className, children, ...props }) => ( - <li className={cn("aui-md-li", className)} {...props}> - {processChildrenWithCitations(children)} - </li> - ), + li: function Li({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <li className={cn("aui-md-li", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </li> + ); + }, hr: ({ className, ...props }) => ( <hr className={cn("aui-md-hr my-5 border-b", className)} {...props} /> ), @@ -422,28 +380,34 @@ const defaultComponents = memoizeMarkdownComponents({ tbody: ({ className, ...props }) => ( <TableBody 
className={cn("aui-md-tbody", className)} {...props} /> ), - th: ({ className, children, ...props }) => ( - <TableHead - className={cn( - "aui-md-th bg-muted/50 whitespace-normal [[align=center]]:text-center [[align=right]]:text-right", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </TableHead> - ), - td: ({ className, children, ...props }) => ( - <TableCell - className={cn( - "aui-md-td whitespace-normal [[align=center]]:text-center [[align=right]]:text-right", - className - )} - {...props} - > - {processChildrenWithCitations(children)} - </TableCell> - ), + th: function Th({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <TableHead + className={cn( + "aui-md-th bg-muted/50 whitespace-normal [[align=center]]:text-center [[align=right]]:text-right", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </TableHead> + ); + }, + td: function Td({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <TableCell + className={cn( + "aui-md-td whitespace-normal [[align=center]]:text-center [[align=right]]:text-right", + className + )} + {...props} + > + {processChildrenWithCitations(children, urlMap)} + </TableCell> + ); + }, tr: ({ className, ...props }) => <TableRow className={cn("aui-md-tr", className)} {...props} />, sup: ({ className, ...props }) => ( <sup className={cn("aui-md-sup [&>a]:text-xs [&>a]:no-underline", className)} {...props} /> @@ -552,16 +516,22 @@ const defaultComponents = memoizeMarkdownComponents({ /> ); }, - strong: ({ className, children, ...props }) => ( - <strong className={cn("aui-md-strong font-semibold", className)} {...props}> - {processChildrenWithCitations(children)} - </strong> - ), - em: ({ className, children, ...props }) => ( - <em className={cn("aui-md-em", className)} {...props}> - {processChildrenWithCitations(children)} - </em> - ), + strong: function Strong({ className, children, ...props }) { + 
const urlMap = useCitationUrlMap(); + return ( + <strong className={cn("aui-md-strong font-semibold", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </strong> + ); + }, + em: function Em({ className, children, ...props }) { + const urlMap = useCitationUrlMap(); + return ( + <em className={cn("aui-md-em", className)} {...props}> + {processChildrenWithCitations(children, urlMap)} + </em> + ); + }, img: ({ src, alt }) => ( <MarkdownImage src={typeof src === "string" ? src : undefined} alt={alt} /> ), diff --git a/surfsense_web/components/citation-panel/citation-panel.tsx b/surfsense_web/components/citation-panel/citation-panel.tsx index cec07b9cf..ed8acd656 100644 --- a/surfsense_web/components/citation-panel/citation-panel.tsx +++ b/surfsense_web/components/citation-panel/citation-panel.tsx @@ -169,7 +169,7 @@ export const CitationPanelContent: FC<CitationPanelContentProps> = ({ chunkId, o )} </div> <div className="text-sm"> - <MarkdownViewer content={chunk.content} /> + <MarkdownViewer content={chunk.content} enableCitations /> </div> </div> ); diff --git a/surfsense_web/components/citations/citation-renderer.tsx b/surfsense_web/components/citations/citation-renderer.tsx new file mode 100644 index 000000000..bf877f03f --- /dev/null +++ b/surfsense_web/components/citations/citation-renderer.tsx @@ -0,0 +1,79 @@ +"use client"; + +import type { ReactNode } from "react"; +import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { + type CitationToken, + type CitationUrlMap, + parseTextWithCitations, +} from "@/lib/citations/citation-parser"; + +/** + * Render a single parsed citation token as JSX. + * + * `ordinalKey` should be a stable per-render counter so duplicate identical + * citations within the same parent don't collide on `key`. 
The previous + * implementation in `markdown-text.tsx` used the source string itself as + * the key, which produced React warnings when two segments rendered the + * same `[citation:N]` text. + */ +export function renderCitationToken(token: CitationToken, ordinalKey: number): ReactNode { + if (token.kind === "url") { + return <UrlCitation key={`citation-url-${ordinalKey}`} url={token.url} />; + } + return ( + <InlineCitation + key={`citation-${token.isDocsChunk ? "doc-" : ""}${token.chunkId}-${ordinalKey}`} + chunkId={token.chunkId} + isDocsChunk={token.isDocsChunk} + /> + ); +} + +/** + * Walk a `ReactNode` (string, array, or arbitrary node) and replace any + * `[citation:...]` tokens inside string children with citation badges. + * + * Designed for use inside `Streamdown`/`react-markdown` `components` + * overrides where the renderer hands you `children`. Non-string children + * are returned untouched so block/phrasing structure is preserved. + */ +export function processChildrenWithCitations( + children: ReactNode, + urlMap: CitationUrlMap +): ReactNode { + if (typeof children === "string") { + const segments = parseTextWithCitations(children, urlMap); + if (segments.length === 1 && typeof segments[0] === "string") { + return children; + } + let ordinal = 0; + return segments.map((segment) => + typeof segment === "string" ? segment : renderCitationToken(segment, ordinal++) + ); + } + + if (Array.isArray(children)) { + let ordinal = 0; + return children.map((child, childIndex) => { + if (typeof child === "string") { + const segments = parseTextWithCitations(child, urlMap); + if (segments.length === 1 && typeof segments[0] === "string") { + return child; + } + return ( + <span key={`citation-seg-${childIndex}`}> + {segments.map((segment) => + typeof segment === "string" + ? 
segment + : renderCitationToken(segment, ordinal++) + )} + </span> + ); + } + return child; + }); + } + + return children; +} diff --git a/surfsense_web/components/document-viewer.tsx b/surfsense_web/components/document-viewer.tsx index 0f283e567..710a04ba3 100644 --- a/surfsense_web/components/document-viewer.tsx +++ b/surfsense_web/components/document-viewer.tsx @@ -32,7 +32,7 @@ export function DocumentViewer({ title, content, trigger }: DocumentViewerProps) <DialogTitle>{title}</DialogTitle> </DialogHeader> <div className="mt-4"> - <MarkdownViewer content={content} /> + <MarkdownViewer content={content} enableCitations /> </div> </DialogContent> </Dialog> diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index df138e97e..eab07a91b 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -652,7 +652,7 @@ export function EditorPanelContent({ // Plate is heavy on multi-MB docs. <div className="h-full overflow-y-auto px-5 py-4"> {largeDocAlert} - <MarkdownViewer content={editorDoc.source_markdown} /> + <MarkdownViewer content={editorDoc.source_markdown} enableCitations /> </div> ) : renderInPlateEditor ? ( // Editable doc (FILE/NOTE) — Plate editing UX. @@ -670,12 +670,17 @@ export function EditorPanelContent({ reserveToolbarSpace defaultEditing={isEditing} className="**:[[role=toolbar]]:bg-sidebar!" + // Render `[citation:N]` badges in view mode only. + // Edit mode keeps raw text so the user can edit/delete + // tokens directly. `local_file` never reaches this branch + // (handled by the source_code editor above). 
+ enableCitations={!isEditing && !isLocalFileMode} /> </div> </div> ) : ( <div className="h-full overflow-y-auto px-5 py-4"> - <MarkdownViewer content={editorDoc.source_markdown} /> + <MarkdownViewer content={editorDoc.source_markdown} enableCitations /> </div> )} </div> diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index 7f12d3cae..c42cb991e 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -8,9 +8,11 @@ import { useEffect, useMemo, useRef } from "react"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; import { EditorSaveContext } from "@/components/editor/editor-save-context"; +import { CitationKit, injectCitationNodes } from "@/components/editor/plugins/citation-kit"; import { type EditorPreset, presetMap } from "@/components/editor/presets"; import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx"; import { Editor, EditorContainer } from "@/components/ui/editor"; +import { preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; /** Live editor instance returned by `usePlateEditor`. */ export type PlateEditorInstance = ReturnType<typeof usePlateEditor>; @@ -65,6 +67,14 @@ export interface PlateEditorProps { * without modifying the core editor component. */ extraPlugins?: AnyPluginConfig[]; + /** + * Render `[citation:N]` and `[citation:URL]` tokens in the deserialized + * markdown as interactive citation badges/popovers (mirrors chat). Only + * meant for read-only views — when true, `onMarkdownChange` is suppressed + * because the in-memory tree contains custom inline-void elements that + * have no markdown serialize rule. 
+ */ + enableCitations?: boolean; } function PlateEditorContent({ @@ -103,6 +113,7 @@ export function PlateEditor({ defaultEditing = false, preset = "full", extraPlugins = [], + enableCitations = false, }: PlateEditorProps) { const lastMarkdownRef = useRef(markdown); const lastHtmlRef = useRef(html); @@ -145,6 +156,8 @@ export function PlateEditor({ ...(onSave ? [SaveShortcutPlugin] : []), // Consumer-provided extra plugins ...extraPlugins, + // Citation void inline element (read-only document viewer). + ...(enableCitations ? CitationKit : []), MarkdownPlugin.configure({ options: { remarkPlugins: [remarkGfm, remarkMath, remarkMdx], @@ -154,8 +167,18 @@ export function PlateEditor({ value: html ? (editor) => editor.api.html.deserialize({ element: html }) as Value : markdown - ? (editor) => - editor.getApi(MarkdownPlugin).markdown.deserialize(escapeMdxExpressions(markdown)) + ? (editor) => { + if (!enableCitations) { + return editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(markdown)); + } + const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); + const value = editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(rewritten)); + return injectCitationNodes(value as Descendant[], urlMap) as Value; + } : undefined, }); @@ -174,13 +197,22 @@ export function PlateEditor({ useEffect(() => { if (!html && markdown !== undefined && markdown !== lastMarkdownRef.current) { lastMarkdownRef.current = markdown; - const newValue = editor - .getApi(MarkdownPlugin) - .markdown.deserialize(escapeMdxExpressions(markdown)); + let newValue: Descendant[]; + if (enableCitations) { + const { content: rewritten, urlMap } = preprocessCitationMarkdown(markdown); + const deserialized = editor + .getApi(MarkdownPlugin) + .markdown.deserialize(escapeMdxExpressions(rewritten)) as Descendant[]; + newValue = injectCitationNodes(deserialized, urlMap); + } else { + newValue = editor + .getApi(MarkdownPlugin) + 
.markdown.deserialize(escapeMdxExpressions(markdown)) as Descendant[]; + } editor.tf.reset(); - editor.tf.setValue(newValue); + editor.tf.setValue(newValue as Value); } - }, [html, markdown, editor]); + }, [html, markdown, editor, enableCitations]); // When not forced read-only, the user can toggle between editing/viewing. const canToggleMode = !readOnly && allowModeToggle; @@ -205,6 +237,16 @@ export function PlateEditor({ // (initialized to true via usePlateEditor, toggled via ModeToolbarButton). {...(readOnly ? { readOnly: true } : {})} onChange={({ value }) => { + // View-only citation mode: skip serialization. The custom + // `citation` inline-void element has no markdown serialize + // rule, so emitting changes here would overwrite + // `lastMarkdownRef.current` (and downstream copy-to-clipboard + // state in EditorPanelContent) with a tree that loses every + // citation token. `enableCitations` is only ever set in + // read-only paths, so user input cannot reach this branch + // in practice — the guard exists for the initial Plate + // normalize emit. + if (enableCitations) return; if (onHtmlChange && html) { const serialized = slateToHtml(value as Descendant[]); onHtmlChange(serialized); diff --git a/surfsense_web/components/editor/plugins/citation-kit.tsx b/surfsense_web/components/editor/plugins/citation-kit.tsx new file mode 100644 index 000000000..c90cb5e28 --- /dev/null +++ b/surfsense_web/components/editor/plugins/citation-kit.tsx @@ -0,0 +1,222 @@ +"use client"; + +import { type FC } from "react"; +import { KEYS, type Descendant } from "platejs"; +import { createPlatePlugin, type PlateElementProps } from "platejs/react"; +import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; +import { + CITATION_REGEX, + type CitationUrlMap, + parseTextWithCitations, +} from "@/lib/citations/citation-parser"; + +/** + * Plate inline-void node modeling a single `[citation:...]` reference. 
+ * + * Modeled after the existing `MentionPlugin` pattern in + * `inline-mention-editor.tsx` — the only confirmed pattern in this repo + * for non-text inline UI. Inline-void elements satisfy Slate's invariant + * that the editor renders both atomic widgets and surrounding text + * cleanly without breaking selection / caret semantics. + */ +export type CitationElementNode = { + type: "citation"; + kind: "chunk" | "doc" | "url"; + chunkId?: number; + url?: string; + /** Original `[citation:...]` substring for traceability/debugging. */ + rawText: string; + children: [{ text: "" }]; +}; + +const CITATION_TYPE = "citation"; + +const CitationElement: FC<PlateElementProps<CitationElementNode>> = ({ + attributes, + children, + element, +}) => { + const isUrl = element.kind === "url"; + return ( + <span {...attributes} className="inline-flex align-baseline"> + <span contentEditable={false}> + {isUrl && element.url ? ( + <UrlCitation url={element.url} /> + ) : element.chunkId !== undefined ? ( + <InlineCitation chunkId={element.chunkId} isDocsChunk={element.kind === "doc"} /> + ) : null} + </span> + {children} + </span> + ); +}; + +const CitationPlugin = createPlatePlugin({ + key: CITATION_TYPE, + node: { + isElement: true, + isInline: true, + isVoid: true, + type: CITATION_TYPE, + component: CitationElement, + }, +}); + +/** Plugin kit shape used elsewhere in the editor. */ +export const CitationKit = [CitationPlugin]; + +// --------------------------------------------------------------------------- +// Slate value transform — runs after MarkdownPlugin.deserialize +// --------------------------------------------------------------------------- + +// Structural shapes used by the value transform. We cannot use Plate's +// generic Element / Text type predicates directly because `Descendant` is a +// constrained union and our predicates would over-narrow. Casting through +// these row types keeps the walker readable without fighting the types. 
+type SlateText = { text: string } & Record<string, unknown>; +type SlateElement = { type?: string; children: Descendant[] } & Record<string, unknown>; + +function isText(node: Descendant): boolean { + return typeof (node as { text?: unknown }).text === "string"; +} + +function asText(node: Descendant): SlateText { + return node as unknown as SlateText; +} + +function asElement(node: Descendant): SlateElement { + return node as unknown as SlateElement; +} + +/** + * Element types whose subtrees we MUST NOT inject citation void elements + * into. Each rationale documented in the citation plan: + * - `KEYS.codeBlock` / `code_line` — Plate's schema rejects inline elements + * inside code containers; the user expects literal text inside code. + * - `KEYS.link` — `<button>` inside `<a>` is invalid HTML and the link + * swallows the citation click. Mirrors the `<a>` skip in + * `MarkdownViewer`. + */ +const SKIP_SUBTREE_TYPES = new Set<string>([ + KEYS.codeBlock, + "code_line", + KEYS.link, +]); + +/** + * Build the marks portion of a Slate text node so we can preserve formatting + * (bold/italic/etc.) on the surrounding text fragments after we split. + */ +function copyMarks(textNode: SlateText): Record<string, unknown> { + const { text: _text, ...marks } = textNode; + return marks; +} + +function makeCitationElement( + rawText: string, + segment: { kind: "url"; url: string } | { kind: "chunk"; chunkId: number; isDocsChunk: boolean } +): CitationElementNode { + if (segment.kind === "url") { + return { + type: CITATION_TYPE, + kind: "url", + url: segment.url, + rawText, + children: [{ text: "" }], + }; + } + return { + type: CITATION_TYPE, + kind: segment.isDocsChunk ? "doc" : "chunk", + chunkId: segment.chunkId, + rawText, + children: [{ text: "" }], + }; +} + +/** + * Re-extract the raw `[citation:...]` substrings that produced each parsed + * segment, in source order. Lets us preserve the original literal for + * `rawText` on the inline-void element. 
+ */ +function extractRawCitationMatches(text: string): string[] { + const matches: string[] = []; + CITATION_REGEX.lastIndex = 0; + let m: RegExpExecArray | null = CITATION_REGEX.exec(text); + while (m !== null) { + matches.push(m[0]); + m = CITATION_REGEX.exec(text); + } + return matches; +} + +function transformTextNode(node: SlateText, urlMap: CitationUrlMap): Descendant[] { + const segments = parseTextWithCitations(node.text, urlMap); + if (segments.length === 1 && typeof segments[0] === "string") { + return [node as unknown as Descendant]; + } + + const marks = copyMarks(node); + const rawMatches = extractRawCitationMatches(node.text); + const out: Descendant[] = []; + let citationIdx = 0; + let pendingText: string | null = null; + + const flushText = () => { + // Slate inline-void adjacency: emit an empty text node (with copied + // marks) when the citation appears at the very start/end of the text + // node so neighbours of the void always have a text sibling. + out.push({ ...marks, text: pendingText ?? "" } as unknown as Descendant); + pendingText = null; + }; + + for (const segment of segments) { + if (typeof segment === "string") { + pendingText = (pendingText ?? "") + segment; + } else { + flushText(); + const raw = rawMatches[citationIdx] ?? ""; + out.push(makeCitationElement(raw, segment) as unknown as Descendant); + citationIdx += 1; + // Always reset pendingText so the next loop iteration emits a + // trailing empty text node if no further plain text follows. + pendingText = ""; + } + } + flushText(); + + return out; +} + +function transformChildren(children: Descendant[], urlMap: CitationUrlMap): Descendant[] { + const out: Descendant[] = []; + for (const child of children) { + if (isText(child)) { + out.push(...transformTextNode(asText(child), urlMap)); + continue; + } + const elementChild = asElement(child); + const elementType = (elementChild.type ?? 
"") as string; + if (elementType && SKIP_SUBTREE_TYPES.has(elementType)) { + out.push(child); + continue; + } + out.push({ + ...elementChild, + children: transformChildren(elementChild.children, urlMap), + } as unknown as Descendant); + } + return out; +} + +/** + * Walk a deserialized Slate value and replace every `[citation:...]` + * substring with a `citation` inline-void element. URL placeholders + * created by `preprocessCitationMarkdown` are resolved through `urlMap`. + * + * Subtrees of `code_block`, `code_line`, and `link` are returned as-is — + * see `SKIP_SUBTREE_TYPES` above. + */ +export function injectCitationNodes(value: Descendant[], urlMap: CitationUrlMap): Descendant[] { + return transformChildren(value, urlMap); +} diff --git a/surfsense_web/components/editor/utils/escape-mdx.ts b/surfsense_web/components/editor/utils/escape-mdx.ts index cd5294b11..14839b9fc 100644 --- a/surfsense_web/components/editor/utils/escape-mdx.ts +++ b/surfsense_web/components/editor/utils/escape-mdx.ts @@ -7,7 +7,7 @@ // break the MDX parser. This module sanitises them before deserialization. // --------------------------------------------------------------------------- -const FENCED_OR_INLINE_CODE = /(```[\s\S]*?```|`[^`\n]+`)/g; +import { FENCED_OR_INLINE_CODE } from "@/lib/markdown/code-regions"; // Strip HTML comments that MDX cannot parse. // PDF converters emit <!-- PageHeader="..." -->, <!-- PageBreak -->, etc. 
diff --git a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx index ac5463873..7ad78be41 100644 --- a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx +++ b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx @@ -316,10 +316,10 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen </Button> </AlertDescription> </Alert> - <MarkdownViewer content={doc.source_markdown} /> + <MarkdownViewer content={doc.source_markdown} enableCitations /> </> ) : ( - <MarkdownViewer content={doc.source_markdown} /> + <MarkdownViewer content={doc.source_markdown} enableCitations /> )} </div> </div> diff --git a/surfsense_web/components/markdown-viewer.tsx b/surfsense_web/components/markdown-viewer.tsx index c4d73e30b..b2420711a 100644 --- a/surfsense_web/components/markdown-viewer.tsx +++ b/surfsense_web/components/markdown-viewer.tsx @@ -3,6 +3,12 @@ import { createMathPlugin } from "@streamdown/math"; import { Streamdown, type StreamdownProps } from "streamdown"; import "katex/dist/katex.min.css"; import Image from "next/image"; +import { useMemo } from "react"; +import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; +import { + type CitationUrlMap, + preprocessCitationMarkdown, +} from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; const code = createCodePlugin({ @@ -21,8 +27,21 @@ interface MarkdownViewerProps { content: string; className?: string; maxLength?: number; + /** + * When true, render `[citation:N]` / `[citation:URL]` tokens as the + * interactive citation badges/popovers used in chat. Default `false` + * so callers that don't need citations are unchanged. + * + * Note: we deliberately do NOT override `<a>` to inject citations into + * link text — that would produce `<button>` inside `<a>` (invalid + * HTML). 
A `[citation:N]` token literally placed inside markdown link + * text stays as raw text. + */ + enableCitations?: boolean; } +const EMPTY_URL_MAP: CitationUrlMap = new Map(); + /** * If the entire content is wrapped in a single ```markdown or ```md * code fence, strip the fence so the inner markdown renders properly. @@ -85,14 +104,45 @@ function convertLatexDelimiters(content: string): string { return content; } -export function MarkdownViewer({ content, className, maxLength }: MarkdownViewerProps) { +export function MarkdownViewer({ + content, + className, + maxLength, + enableCitations = false, +}: MarkdownViewerProps) { const isTruncated = maxLength != null && content.length > maxLength; const displayContent = isTruncated ? content.slice(0, maxLength) : content; - const processedContent = convertLatexDelimiters(stripOuterMarkdownFence(displayContent)); + + // Preprocess for URL placeholders BEFORE LaTeX so GFM autolinks don't + // split `[citation:https://…]` apart. The preprocess is code-fence + // aware so citations inside fenced code stay literal. + const { processedContent, urlMap } = useMemo(() => { + const stripped = stripOuterMarkdownFence(displayContent); + if (!enableCitations) { + return { + processedContent: convertLatexDelimiters(stripped), + urlMap: EMPTY_URL_MAP, + }; + } + const { content: rewritten, urlMap: map } = preprocessCitationMarkdown(stripped); + return { + processedContent: convertLatexDelimiters(rewritten), + urlMap: map, + }; + }, [displayContent, enableCitations]); + + // Phrasing/block renderers wrap their string children through the + // citation renderer when `enableCitations` is on. We deliberately do + // NOT override `<a>` (would produce <button> inside <a>) and we do + // NOT touch the inline/fenced `code` paths (citations stay literal + // inside code, matching markdown-text.tsx behavior). + const wrap = (children: React.ReactNode): React.ReactNode => + enableCitations ? 
processChildrenWithCitations(children, urlMap) : children; + const components: StreamdownProps["components"] = { p: ({ children, ...props }) => ( <p className="my-2" {...props}> - {children} + {wrap(children)} </p> ), a: ({ children, ...props }) => ( @@ -105,31 +155,49 @@ export function MarkdownViewer({ content, className, maxLength }: MarkdownViewer {children} </a> ), - li: ({ children, ...props }) => <li {...props}>{children}</li>, + li: ({ children, ...props }) => <li {...props}>{wrap(children)}</li>, ul: ({ ...props }) => <ul className="list-disc pl-5 my-2" {...props} />, ol: ({ ...props }) => <ol className="list-decimal pl-5 my-2" {...props} />, h1: ({ children, ...props }) => ( <h1 className="text-2xl font-bold mt-6 mb-2" {...props}> - {children} + {wrap(children)} </h1> ), h2: ({ children, ...props }) => ( <h2 className="text-xl font-bold mt-5 mb-2" {...props}> - {children} + {wrap(children)} </h2> ), h3: ({ children, ...props }) => ( <h3 className="text-lg font-bold mt-4 mb-2" {...props}> - {children} + {wrap(children)} </h3> ), h4: ({ children, ...props }) => ( <h4 className="text-base font-bold mt-3 mb-1" {...props}> - {children} + {wrap(children)} </h4> ), - blockquote: ({ ...props }) => ( - <blockquote className="border-l-4 border-muted pl-4 italic my-2" {...props} /> + h5: ({ children, ...props }) => ( + <h5 className="text-sm font-bold mt-3 mb-1" {...props}> + {wrap(children)} + </h5> + ), + h6: ({ children, ...props }) => ( + <h6 className="text-xs font-bold mt-3 mb-1" {...props}> + {wrap(children)} + </h6> + ), + strong: ({ children, ...props }) => ( + <strong className="font-semibold" {...props}> + {wrap(children)} + </strong> + ), + em: ({ children, ...props }) => <em {...props}>{wrap(children)}</em>, + blockquote: ({ children, ...props }) => ( + <blockquote className="border-l-4 border-muted pl-4 italic my-2" {...props}> + {wrap(children)} + </blockquote> ), hr: ({ ...props }) => <hr className="my-4 border-muted" {...props} />, img: ({ src, alt, 
width: _w, height: _h, ...props }) => { @@ -163,17 +231,21 @@ export function MarkdownViewer({ content, className, maxLength }: MarkdownViewer <table className="w-full divide-y divide-border" {...props} /> </div> ), - th: ({ ...props }) => ( + th: ({ children, ...props }) => ( <th className="px-4 py-2.5 text-left text-sm font-semibold text-muted-foreground/80 bg-muted/30 border-r border-border/40 last:border-r-0" {...props} - /> + > + {wrap(children)} + </th> ), - td: ({ ...props }) => ( + td: ({ children, ...props }) => ( <td className="px-4 py-2.5 text-sm border-t border-r border-border/40 last:border-r-0" {...props} - /> + > + {wrap(children)} + </td> ), }; diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index 621cf13ce..7fafc9c3b 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -516,7 +516,7 @@ export function ReportPanelContent({ ) : reportContent.content ? ( isReadOnly ? ( <div className="h-full overflow-y-auto px-5 py-4"> - <MarkdownViewer content={reportContent.content} /> + <MarkdownViewer content={reportContent.content} enableCitations /> </div> ) : ( <PlateEditor @@ -531,6 +531,9 @@ export function ReportPanelContent({ reserveToolbarSpace defaultEditing={isEditing} className="[&_[role=toolbar]]:!bg-sidebar" + // Show citation badges in view mode; raw `[citation:N]` + // text in edit mode so users can edit/delete tokens. + enableCitations={!isEditing} /> ) ) : ( diff --git a/surfsense_web/lib/citations/citation-parser.ts b/surfsense_web/lib/citations/citation-parser.ts new file mode 100644 index 000000000..6333b0f97 --- /dev/null +++ b/surfsense_web/lib/citations/citation-parser.ts @@ -0,0 +1,134 @@ +// Pure citation parsing for `[citation:...]` tokens emitted by SurfSense +// agents. 
No React imports — consumed by both the React renderer +// (markdown surfaces) and the Plate value transform (document viewer). +// +// The same logic previously lived inline in +// `components/assistant-ui/markdown-text.tsx` with module-level mutable +// state. This module exposes a per-call URL map so multiple concurrent +// renderers / SSR contexts can't race each other. + +import { FENCED_OR_INLINE_CODE } from "@/lib/markdown/code-regions"; + +/** + * Matches `[citation:...]` with numeric IDs (incl. negative, doc- prefix, + * comma-separated), URL-based IDs from live web search, or `urlciteN` + * placeholders produced by `preprocessCitationMarkdown`. + * + * Also matches Chinese brackets 【】 and zero-width spaces that LLMs + * sometimes emit. + */ +export const CITATION_REGEX = + /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g; + +/** A single parsed citation reference. */ +export type CitationToken = + | { kind: "url"; url: string } + | { kind: "chunk"; chunkId: number; isDocsChunk: boolean }; + +/** Output of `parseTextWithCitations` — interleaved text + citation tokens. */ +export type ParsedSegment = string | CitationToken; + +/** Per-call URL placeholder map; key is `urlciteN`, value is the original URL. */ +export type CitationUrlMap = Map<string, string>; + +/** Result of preprocessing raw markdown for downstream parsing. */ +export interface PreprocessedCitations { + /** Markdown with `[citation:URL]` tokens rewritten to `[citation:urlciteN]`. */ + content: string; + /** Lookup table to recover the original URL from each placeholder. */ + urlMap: CitationUrlMap; +} + +/** Pattern matching only URL-form citations (used during preprocessing). 
*/ +const URL_CITATION_REGEX = + /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g; + +/** + * Replace `[citation:URL]` tokens with `[citation:urlciteN]` placeholders so + * GFM autolinks don't split the URL out of the brackets during markdown + * parsing. Returns both the rewritten content and a map for later lookup. + * + * Code-fence aware: skips fenced (``` ``` ```) and inline (`` ` ``) code + * regions so citation-shaped strings inside example code remain literal. + * + * Known limitations: `~~~` fences, 4-space indented code, and LaTeX math + * blocks are not skipped. Citation tokens inside those regions are rare in + * practice; documented in the plan. + */ +export function preprocessCitationMarkdown(content: string): PreprocessedCitations { + const urlMap: CitationUrlMap = new Map(); + let counter = 0; + + // Splitting on a regex with one capture group puts code regions at odd + // indexes (matched delimiters) and the surrounding text at even indexes. + // Only transform the even-indexed parts. + const parts = content.split(FENCED_OR_INLINE_CODE); + const transformed = parts.map((part, index) => { + if (index % 2 === 1) return part; + return part.replace(URL_CITATION_REGEX, (_match, url: string) => { + const key = `urlcite${counter++}`; + urlMap.set(key, url.trim()); + return `[citation:${key}]`; + }); + }); + + return { content: transformed.join(""), urlMap }; +} + +/** + * Parse a string into an array of plain text segments and citation tokens. + * + * Pure data — no React. The renderer module is responsible for mapping + * tokens to JSX. Negative chunk IDs are forwarded as-is so the consumer + * can decide how to render anonymous documents. 
+ */ +export function parseTextWithCitations( + text: string, + urlMap: CitationUrlMap +): ParsedSegment[] { + const segments: ParsedSegment[] = []; + let lastIndex = 0; + let match: RegExpExecArray | null; + + CITATION_REGEX.lastIndex = 0; + match = CITATION_REGEX.exec(text); + while (match !== null) { + if (match.index > lastIndex) { + segments.push(text.substring(lastIndex, match.index)); + } + + const captured = match[1]; + + if (captured.startsWith("http://") || captured.startsWith("https://")) { + segments.push({ kind: "url", url: captured.trim() }); + } else if (captured.startsWith("urlcite")) { + const url = urlMap.get(captured); + if (url) { + segments.push({ kind: "url", url }); + } + } else { + const rawIds = captured.split(",").map((s) => s.trim()); + for (const rawId of rawIds) { + const isDocsChunk = rawId.startsWith("doc-"); + const chunkId = Number.parseInt(isDocsChunk ? rawId.slice(4) : rawId, 10); + if (!Number.isNaN(chunkId)) { + segments.push({ kind: "chunk", chunkId, isDocsChunk }); + } + } + } + + lastIndex = match.index + match[0].length; + match = CITATION_REGEX.exec(text); + } + + if (lastIndex < text.length) { + segments.push(text.substring(lastIndex)); + } + + return segments.length > 0 ? segments : [text]; +} + +/** Type guard for the citation branch of `ParsedSegment`. */ +export function isCitationToken(segment: ParsedSegment): segment is CitationToken { + return typeof segment !== "string"; +} diff --git a/surfsense_web/lib/markdown/code-regions.ts b/surfsense_web/lib/markdown/code-regions.ts new file mode 100644 index 000000000..336a87acb --- /dev/null +++ b/surfsense_web/lib/markdown/code-regions.ts @@ -0,0 +1,8 @@ +// Matches fenced (```...```) and inline (`...`) code regions. Used by MDX +// escaping and citation preprocessing — single source of truth so future +// edits stay in sync. 
+// +// String.split() with this capturing pattern places non-code parts at even +// indexes and matched code regions at odd indexes — preserve odd-indexed +// segments verbatim when transforming markdown. +export const FENCED_OR_INLINE_CODE = /(```[\s\S]*?```|`[^`\n]+`)/g; From c644f02d0575473118076aa55b762aaa638a56b3 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 18:42:38 -0700 Subject: [PATCH 39/68] chore: linting --- ...38_add_thread_auto_model_pinning_fields.py | 16 +- .../app/services/auto_model_pin_service.py | 15 +- .../app/tasks/chat/stream_new_chat.py | 16 +- .../services/test_auto_model_pin_service.py | 76 +++++- .../unit/test_stream_new_chat_contract.py | 26 +- .../new-chat/[[...chat_id]]/page.tsx | 120 ++++----- .../app/desktop/permissions/page.tsx | 4 +- .../agent-action-log/action-log-sheet.tsx | 5 +- .../assistant-ui/inline-citation.tsx | 6 +- .../assistant-ui/inline-mention-editor.tsx | 43 +++- .../components/assistant-ui/markdown-text.tsx | 24 +- .../components/assistant-ui/nested-scroll.tsx | 2 +- .../components/assistant-ui/thread.tsx | 2 +- .../components/assistant-ui/tool-fallback.tsx | 22 +- .../citations/citation-renderer.tsx | 4 +- .../editor/plugins/citation-kit.tsx | 10 +- .../layout/providers/LayoutDataProvider.tsx | 2 +- .../layout/ui/sidebar/DocumentsSidebar.tsx | 6 +- surfsense_web/components/markdown-viewer.tsx | 5 +- .../hooks/use-agent-actions-query.ts | 243 ++++++++---------- .../lib/chat/chat-error-classifier.ts | 45 ++-- surfsense_web/lib/chat/chat-request-errors.ts | 8 +- surfsense_web/lib/chat/stream-pipeline.ts | 8 +- surfsense_web/lib/chat/stream-side-effects.ts | 8 +- .../lib/citations/citation-parser.ts | 8 +- surfsense_web/lib/posthog/events.ts | 2 +- 26 files changed, 346 insertions(+), 380 deletions(-) diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py 
b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index 1ea549975..3972b84b9 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -47,19 +47,11 @@ def upgrade() -> None: def downgrade() -> None: - op.execute( - "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode" - ) - op.execute( - "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id" - ) + op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode") + op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id") - op.execute( - "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at" - ) - op.execute( - "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode" - ) + op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at") + op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode") op.execute( "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id" ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 6bdb60f57..6b69c91ea 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -44,7 +44,9 @@ def _is_usable_global_config(cfg: dict) -> bool: def _global_candidates() -> list[dict]: - candidates = [cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg)] + candidates = [ + cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) + ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) @@ -69,7 +71,9 @@ def _to_uuid(user_id: str | UUID | None) -> UUID | None: return None -async def _is_premium_eligible(session: AsyncSession, user_id: str | UUID | None) -> bool: +async def _is_premium_eligible( + session: AsyncSession, user_id: str | UUID | None 
+) -> bool: parsed = _to_uuid(user_id) if parsed is None: return False @@ -136,8 +140,7 @@ async def resolve_or_get_pinned_llm_config_id( pinned_id = thread.pinned_llm_config_id if ( not force_repin_free - and - thread.pinned_auto_mode == AUTO_FASTEST_MODE + and thread.pinned_auto_mode == AUTO_FASTEST_MODE and pinned_id is not None and int(pinned_id) in candidate_by_id ): @@ -163,7 +166,9 @@ async def resolve_or_get_pinned_llm_config_id( thread.pinned_auto_mode, ) - premium_eligible = False if force_repin_free else await _is_premium_eligible(session, user_id) + premium_eligible = ( + False if force_repin_free else await _is_premium_eligible(session, user_id) + ) if premium_eligible: eligible = candidates else: diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 63c149771..5abcb63eb 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2225,9 +2225,7 @@ async def stream_new_chat( # Premium quota reservation for pinned premium model only. 
_needs_premium_quota = ( - agent_config is not None - and user_id - and agent_config.is_premium + agent_config is not None and user_id and agent_config.is_premium ) if _needs_premium_quota: import uuid as _uuid @@ -2271,7 +2269,9 @@ async def stream_new_chat( yield streaming_service.format_done() return - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) if llm_load_error: yield _emit_stream_error( message=llm_load_error, @@ -3086,9 +3086,7 @@ async def stream_resume_chat( _resume_premium_reserved = 0 _resume_premium_request_id: str | None = None _resume_needs_premium = ( - agent_config is not None - and user_id - and agent_config.is_premium + agent_config is not None and user_id and agent_config.is_premium ) if _resume_needs_premium: import uuid as _uuid @@ -3132,7 +3130,9 @@ async def stream_resume_chat( yield streaming_service.format_done() return - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) if llm_load_error: yield _emit_stream_error( message=llm_load_error, diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index f08e50ba2..0a2342e05 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -66,7 +66,13 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): "GLOBAL_LLM_CONFIGS", [ {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) @@ -103,12 +109,20 @@ async def 
test_next_turn_reuses_existing_pin(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) async def _must_not_call(*_args, **_kwargs): - raise AssertionError("premium_get_usage should not be called for valid pin reuse") + raise AssertionError( + "premium_get_usage should not be called for valid pin reuse" + ) monkeypatch.setattr( "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", @@ -136,7 +150,13 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) @@ -168,8 +188,20 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) @@ -203,8 +235,20 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + 
"billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) @@ -238,8 +282,20 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free"}, - {"id": -1, "provider": "OPENAI", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium"}, + { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-free", + "api_key": "k1", + "billing_tier": "free", + }, + { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-prem", + "api_key": "k2", + "billing_tier": "premium", + }, ], ) diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index a1345c15c..5e6ad6abd 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -203,7 +203,10 @@ def test_stream_exception_classifies_turn_cancelling_when_cancel_requested(): def test_premium_classification_is_error_code_driven(): - classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + classifier_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-error-classifier.ts" + ) source = classifier_path.read_text(encoding="utf-8") assert "PREMIUM_KEYWORDS" not in source @@ -229,7 +232,8 @@ def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook(): def test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): user_message_path = ( - Path(__file__).resolve().parents[3] / "surfsense_web/components/assistant-ui/user-message.tsx" + Path(__file__).resolve().parents[3] + / "surfsense_web/components/assistant-ui/user-message.tsx" ) source = user_message_path.read_text(encoding="utf-8") @@ -238,10 +242,14 @@ def 
test_toast_only_pre_accept_policy_has_no_inline_failed_marker(): def test_network_send_failures_use_unified_retry_toast_message(): - classifier_path = Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-error-classifier.ts" + classifier_path = ( + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-error-classifier.ts" + ) classifier_source = classifier_path.read_text(encoding="utf-8") request_errors_path = ( - Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/chat-request-errors.ts" + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/chat-request-errors.ts" ) request_errors_source = request_errors_path.read_text(encoding="utf-8") @@ -350,15 +358,17 @@ def test_turn_status_sse_contract_exists(): / "surfsense_backend/app/tasks/chat/stream_new_chat.py" ).read_text(encoding="utf-8") state_source = ( - Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/streaming-state.ts" + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/streaming-state.ts" ).read_text(encoding="utf-8") pipeline_source = ( - Path(__file__).resolve().parents[3] / "surfsense_web/lib/chat/stream-pipeline.ts" + Path(__file__).resolve().parents[3] + / "surfsense_web/lib/chat/stream-pipeline.ts" ).read_text(encoding="utf-8") assert '"turn-status"' in stream_source assert '"status": "busy"' in stream_source assert '"status": "idle"' in stream_source - assert "type: \"data-turn-status\"" in state_source - assert "case \"data-turn-status\":" in pipeline_source + assert 'type: "data-turn-status"' in state_source + assert 'case "data-turn-status":' in pipeline_source assert "end_turn(str(chat_id))" in stream_source diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 1b25ca431..39201e5cc 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ 
b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -19,7 +19,6 @@ import { currentThreadAtom, setTargetCommentIdAtom, } from "@/atoms/chat/current-thread.atom"; -import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom"; import { type MentionedDocumentInfo, mentionedDocumentIdsAtom, @@ -31,6 +30,7 @@ import { clearPlanOwnerRegistry, // extractWriteTodosFromContent, } from "@/atoms/chat/plan-state.atom"; +import { setPremiumAlertForThreadAtom } from "@/atoms/chat/premium-alert.atom"; import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom"; import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms"; import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; @@ -60,20 +60,28 @@ import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { getBearerToken } from "@/lib/auth-utils"; -import { - classifyChatError, - type ChatFlow, -} from "@/lib/chat/chat-error-classifier"; -import { - tagPreAcceptSendFailure, - toHttpResponseError, -} from "@/lib/chat/chat-request-errors"; +import { type ChatFlow, classifyChatError } from "@/lib/chat/chat-error-classifier"; +import { tagPreAcceptSendFailure, toHttpResponseError } from "@/lib/chat/chat-request-errors"; import { convertToThreadMessage } from "@/lib/chat/message-utils"; import { isPodcastGenerating, looksLikePodcastRequest, setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; +import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; +import { + consumeSseEvents, + hasPersistableContent, + processSharedStreamEvent, +} from "@/lib/chat/stream-pipeline"; +import { + applyInterruptRequestToContentParts, + applyTurnIdToAssistantMessageList, + markInterruptDecisionOnContentParts, + mergeChatTurnIdIntoMessage, + mergeEditedInterruptAction, + 
readStreamedChatTurnId, +} from "@/lib/chat/stream-side-effects"; import { buildContentForPersistence, buildContentForUI, @@ -82,20 +90,6 @@ import { type ThinkingStepData, type ToolUIGate, } from "@/lib/chat/streaming-state"; -import { createStreamFlushHelpers } from "@/lib/chat/stream-flush"; -import { - consumeSseEvents, - hasPersistableContent, - processSharedStreamEvent, -} from "@/lib/chat/stream-pipeline"; -import { - applyTurnIdToAssistantMessageList, - applyInterruptRequestToContentParts, - mergeChatTurnIdIntoMessage, - mergeEditedInterruptAction, - markInterruptDecisionOnContentParts, - readStreamedChatTurnId, -} from "@/lib/chat/stream-side-effects"; import { appendMessage, createThread, @@ -112,8 +106,8 @@ import { } from "@/lib/chat/user-turn-api-parts"; import { NotFoundError } from "@/lib/error"; import { - trackChatCreated, trackChatBlocked, + trackChatCreated, trackChatErrorDetailed, trackChatMessageSent, trackChatResponseReceived, @@ -193,7 +187,8 @@ function sleep(ms: number): Promise<void> { function computeFallbackTurnCancellingRetryDelay(attempt: number): number { const safeAttempt = Math.max(1, attempt); - const raw = TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); + const raw = + TURN_CANCELLING_INITIAL_DELAY_MS * TURN_CANCELLING_BACKOFF_FACTOR ** (safeAttempt - 1); return Math.min(raw, TURN_CANCELLING_MAX_DELAY_MS); } @@ -278,11 +273,9 @@ export default function NewChatPage() { }) => { if (!threadId) return null; try { - const normalizedContent = Array.isArray(content) - ? ([...content] as unknown[]) - : [content]; - const hasMentionedDocumentsPart = normalizedContent.some((part) => - MentionedDocumentsPartSchema.safeParse(part).success + const normalizedContent = Array.isArray(content) ? 
([...content] as unknown[]) : [content]; + const hasMentionedDocumentsPart = normalizedContent.some( + (part) => MentionedDocumentsPartSchema.safeParse(part).success ); if (mentionedDocs && mentionedDocs.length > 0 && !hasMentionedDocumentsPart) { normalizedContent.push({ @@ -300,10 +293,7 @@ export default function NewChatPage() { setMessages((prev) => prev.map((m) => m.id === userMsgId - ? mergeChatTurnIdIntoMessage( - { ...m, id: newUserMsgId }, - savedUserMessage.turn_id - ) + ? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id) : m ) ); @@ -356,10 +346,7 @@ export default function NewChatPage() { setMessages((prev) => prev.map((m) => m.id === assistantMsgId - ? mergeChatTurnIdIntoMessage( - { ...m, id: newMsgId }, - savedMessage.turn_id - ) + ? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id) : m ) ); @@ -564,12 +551,7 @@ export default function NewChatPage() { toast.error(normalized.userMessage); }, - [ - currentUser?.id, - persistAssistantErrorMessage, - searchSpaceId, - setPremiumAlertForThread, - ] + [currentUser?.id, persistAssistantErrorMessage, searchSpaceId, setPremiumAlertForThread] ); const handleStreamTerminalError = useCallback( @@ -613,35 +595,31 @@ export default function NewChatPage() { [handleChatFailure] ); - const fetchWithTurnCancellingRetry = useCallback( - async (runFetch: () => Promise<Response>) => { - const maxAttempts = 4; - for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { - const response = await runFetch(); - if (response.ok) { - return response; - } - const error = await toHttpResponseError(response); - const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; - const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; - const isRecentThreadBusyAfterCancel = - withMeta.errorCode === "THREAD_BUSY" && - Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; - if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && 
attempt < maxAttempts) { - const waitMs = - withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt); - await sleep(waitMs); - continue; - } - throw error; + const fetchWithTurnCancellingRetry = useCallback(async (runFetch: () => Promise<Response>) => { + const maxAttempts = 4; + for (let attempt = 1; attempt <= maxAttempts; attempt += 1) { + const response = await runFetch(); + if (response.ok) { + return response; } + const error = await toHttpResponseError(response); + const withMeta = error as Error & { errorCode?: string; retryAfterMs?: number }; + const isTurnCancelling = withMeta.errorCode === "TURN_CANCELLING"; + const isRecentThreadBusyAfterCancel = + withMeta.errorCode === "THREAD_BUSY" && + Date.now() - recentCancelRequestedAtRef.current <= RECENT_CANCEL_WINDOW_MS; + if ((isTurnCancelling || isRecentThreadBusyAfterCancel) && attempt < maxAttempts) { + const waitMs = withMeta.retryAfterMs ?? computeFallbackTurnCancellingRetryDelay(attempt); + await sleep(waitMs); + continue; + } + throw error; + } - throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { - errorCode: "TURN_CANCELLING", - }); - }, - [] - ); + throw Object.assign(new Error("Turn cancellation retry limit exceeded"), { + errorCode: "TURN_CANCELLING", + }); + }, []); // Initialize thread and load messages // For new chats (no urlChatId), we use lazy creation - thread is created on first message diff --git a/surfsense_web/app/desktop/permissions/page.tsx b/surfsense_web/app/desktop/permissions/page.tsx index e30a76f83..ca9228272 100644 --- a/surfsense_web/app/desktop/permissions/page.tsx +++ b/surfsense_web/app/desktop/permissions/page.tsx @@ -132,8 +132,8 @@ export default function DesktopPermissionsPage() { <div className="space-y-1"> <h1 className="text-2xl font-semibold tracking-tight">System Permissions</h1> <p className="text-sm text-muted-foreground"> - SurfSense needs two macOS permissions for Screenshot Assist and for desktop features that - require 
focusing the app or the active application. + SurfSense needs two macOS permissions for Screenshot Assist and for desktop features + that require focusing the app or the active application. </p> </div> </div> diff --git a/surfsense_web/components/agent-action-log/action-log-sheet.tsx b/surfsense_web/components/agent-action-log/action-log-sheet.tsx index 32c25771a..7d27b4019 100644 --- a/surfsense_web/components/agent-action-log/action-log-sheet.tsx +++ b/surfsense_web/components/agent-action-log/action-log-sheet.tsx @@ -17,10 +17,7 @@ import { SheetTitle, } from "@/components/ui/sheet"; import { Skeleton } from "@/components/ui/skeleton"; -import { - agentActionsQueryKey, - useAgentActionsQuery, -} from "@/hooks/use-agent-actions-query"; +import { agentActionsQueryKey, useAgentActionsQuery } from "@/hooks/use-agent-actions-query"; import { ActionLogItem } from "./action-log-item"; function EmptyState() { diff --git a/surfsense_web/components/assistant-ui/inline-citation.tsx b/surfsense_web/components/assistant-ui/inline-citation.tsx index e299f2373..32a29cfc9 100644 --- a/surfsense_web/components/assistant-ui/inline-citation.tsx +++ b/surfsense_web/components/assistant-ui/inline-citation.tsx @@ -182,11 +182,7 @@ const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => { </p> )} {!isLoading && !error && citedChunk?.content && ( - <MarkdownViewer - content={citedChunk.content} - maxLength={1500} - enableCitations - /> + <MarkdownViewer content={citedChunk.content} maxLength={1500} enableCitations /> )} {!isLoading && !error && !citedChunk?.content && ( <p className="py-4 text-xs text-muted-foreground">No content available.</p> diff --git a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx index d92348080..c585dc80f 100644 --- a/surfsense_web/components/assistant-ui/inline-mention-editor.tsx +++ b/surfsense_web/components/assistant-ui/inline-mention-editor.tsx @@ -1,8 +1,14 
@@ "use client"; -import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; -import { Plate, PlateContent, ParagraphPlugin, createPlatePlugin, usePlateEditor } from "platejs/react"; import type { PlateElementProps } from "platejs/react"; +import { + createPlatePlugin, + ParagraphPlugin, + Plate, + PlateContent, + usePlateEditor, +} from "platejs/react"; +import { type FC, forwardRef, useCallback, useImperativeHandle, useMemo, useRef } from "react"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { Document } from "@/contracts/types/document.types"; import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; @@ -72,7 +78,11 @@ const COMPOSER_TEXT_METRICS_CLASSNAME = "text-sm leading-6"; const EMPTY_VALUE: ComposerValue = [{ type: "p", children: [{ text: "" }] }]; -const MentionElement: FC<PlateElementProps<MentionElementNode>> = ({ attributes, children, element }) => { +const MentionElement: FC<PlateElementProps<MentionElementNode>> = ({ + attributes, + children, + element, +}) => { const statusClass = element.statusKind === "failed" ? "text-destructive" @@ -255,7 +265,10 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent selection?.addRange(range); }, []); - const getCurrentValue = useCallback(() => (editor.children as ComposerValue) ?? EMPTY_VALUE, [editor]); + const getCurrentValue = useCallback( + () => (editor.children as ComposerValue) ?? EMPTY_VALUE, + [editor] + ); const emitState = useCallback( (nextValue: ComposerValue) => { @@ -379,7 +392,8 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent const next = current.map((block) => { const children = block.children.filter((node) => { if (!isMentionNode(node)) return true; - const match = node.id === docId && (node.document_type ?? "UNKNOWN") === (docType ?? "UNKNOWN"); + const match = + node.id === docId && (node.document_type ?? "UNKNOWN") === (docType ?? 
"UNKNOWN"); if (match) changed = true; return !match; }); @@ -450,7 +464,15 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent removeDocumentChip, setDocumentChipStatus, }), - [clear, getMentionedDocs, getText, insertDocumentChip, removeDocumentChip, setDocumentChipStatus, setText] + [ + clear, + getMentionedDocs, + getText, + insertDocumentChip, + removeDocumentChip, + setDocumentChipStatus, + setText, + ] ); const handleKeyDown = useCallback( @@ -488,14 +510,7 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent removeDocumentChip(prev.id, prev.document_type); onDocumentRemove?.(prev.id, prev.document_type); }, - [ - editor.selection, - getCurrentValue, - onDocumentRemove, - onKeyDown, - onSubmit, - removeDocumentChip, - ] + [editor.selection, getCurrentValue, onDocumentRemove, onKeyDown, onSubmit, removeDocumentChip] ); const editableProps = useMemo( diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 2b788e88b..4842e5979 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -12,14 +12,7 @@ import { ExternalLinkIcon } from "lucide-react"; import dynamic from "next/dynamic"; import { useParams } from "next/navigation"; import { useTheme } from "next-themes"; -import { - createContext, - memo, - type ReactNode, - useCallback, - useContext, - useRef, -} from "react"; +import { createContext, memo, type ReactNode, useCallback, useContext, useRef } from "react"; import rehypeKatex from "rehype-katex"; import remarkGfm from "remark-gfm"; import remarkMath from "remark-math"; @@ -28,10 +21,6 @@ import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/im import "katex/dist/katex.min.css"; import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; import { Skeleton } from "@/components/ui/skeleton"; 
-import { - type CitationUrlMap, - preprocessCitationMarkdown, -} from "@/lib/citations/citation-parser"; import { Table, TableBody, @@ -41,6 +30,7 @@ import { TableRow, } from "@/components/ui/table"; import { useElectronAPI } from "@/hooks/use-platform"; +import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; function MarkdownCodeBlockSkeleton() { @@ -128,10 +118,7 @@ function preprocessMarkdown(content: string, urlMapRef: CitationUrlMapRef): stri const MarkdownTextImpl = () => { const urlMapRef = useRef<CitationUrlMap>(EMPTY_URL_MAP); - const preprocess = useCallback( - (content: string) => preprocessMarkdown(content, urlMapRef), - [] - ); + const preprocess = useCallback((content: string) => preprocessMarkdown(content, urlMapRef), []); return ( <CitationUrlMapContext.Provider value={urlMapRef}> <MarkdownTextPrimitive @@ -334,10 +321,7 @@ const defaultComponents = memoizeMarkdownComponents({ const urlMap = useCitationUrlMap(); return ( <a - className={cn( - "aui-md-a font-medium text-primary underline underline-offset-4", - className - )} + className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)} {...props} > {processChildrenWithCitations(children, urlMap)} diff --git a/surfsense_web/components/assistant-ui/nested-scroll.tsx b/surfsense_web/components/assistant-ui/nested-scroll.tsx index 5a4f8d36e..37c4790df 100644 --- a/surfsense_web/components/assistant-ui/nested-scroll.tsx +++ b/surfsense_web/components/assistant-ui/nested-scroll.tsx @@ -1,6 +1,6 @@ "use client"; -import { forwardRef, type ComponentPropsWithoutRef, type WheelEvent } from "react"; +import { type ComponentPropsWithoutRef, forwardRef, type WheelEvent } from "react"; export type NestedScrollProps = ComponentPropsWithoutRef<"div">; diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 6c02a1efa..b4a3b58c6 100644 --- 
a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -92,8 +92,8 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments"; import { useCommentsSync } from "@/hooks/use-comments-sync"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI } from "@/hooks/use-platform"; -import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events"; import { cn } from "@/lib/utils"; diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index cf42cf398..06082c9c7 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,19 +1,16 @@ -import { - type ToolCallMessagePartComponent, - useAuiState, -} from "@assistant-ui/react"; +import { type ToolCallMessagePartComponent, useAuiState } from "@assistant-ui/react"; import { useQueryClient } from "@tanstack/react-query"; import { useAtomValue } from "jotai"; import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; +import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; import { DoomLoopApprovalToolUI, isDoomLoopInterrupt, } from "@/components/tool-ui/doom-loop-approval"; import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; -import { NestedScroll } from "@/components/assistant-ui/nested-scroll"; import { AlertDialog, AlertDialogAction, @@ -32,10 +29,7 @@ import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/component import { Separator } from 
"@/components/ui/separator"; import { Spinner } from "@/components/ui/spinner"; import { getToolDisplayName } from "@/contracts/enums/toolIcons"; -import { - markActionRevertedInCache, - useAgentActionsQuery, -} from "@/hooks/use-agent-actions-query"; +import { markActionRevertedInCache, useAgentActionsQuery } from "@/hooks/use-agent-actions-query"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; @@ -124,8 +118,7 @@ function ToolCardRevertButton({ // Tier 1 + 2: O(1) Map-backed direct id match. Covers // ~all parity_v2 streams and any legacy stream that backfilled // ``langchainToolCallId`` via ``tool-output-available``. - const direct = - findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); + const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); if (direct) return direct; // Tier 3: position-within-turn fallback. Only kicks in when the // card has a synthetic ``call_<run_id>`` id AND no @@ -160,12 +153,7 @@ function ToolCardRevertButton({ setIsReverting(true); try { const response = await agentActionsApiService.revert(threadId, action.id); - markActionRevertedInCache( - queryClient, - threadId, - action.id, - response.new_action_id ?? null - ); + markActionRevertedInCache(queryClient, threadId, action.id, response.new_action_id ?? 
null); toast.success(response.message || "Action reverted."); } catch (err) { // 503 means revert is gated off on this deployment — hide the diff --git a/surfsense_web/components/citations/citation-renderer.tsx b/surfsense_web/components/citations/citation-renderer.tsx index bf877f03f..f2de4b27d 100644 --- a/surfsense_web/components/citations/citation-renderer.tsx +++ b/surfsense_web/components/citations/citation-renderer.tsx @@ -64,9 +64,7 @@ export function processChildrenWithCitations( return ( <span key={`citation-seg-${childIndex}`}> {segments.map((segment) => - typeof segment === "string" - ? segment - : renderCitationToken(segment, ordinal++) + typeof segment === "string" ? segment : renderCitationToken(segment, ordinal++) )} </span> ); diff --git a/surfsense_web/components/editor/plugins/citation-kit.tsx b/surfsense_web/components/editor/plugins/citation-kit.tsx index c90cb5e28..1908de209 100644 --- a/surfsense_web/components/editor/plugins/citation-kit.tsx +++ b/surfsense_web/components/editor/plugins/citation-kit.tsx @@ -1,8 +1,8 @@ "use client"; -import { type FC } from "react"; -import { KEYS, type Descendant } from "platejs"; +import { type Descendant, KEYS } from "platejs"; import { createPlatePlugin, type PlateElementProps } from "platejs/react"; +import type { FC } from "react"; import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation"; import { CITATION_REGEX, @@ -97,11 +97,7 @@ function asElement(node: Descendant): SlateElement { * swallows the citation click. Mirrors the `<a>` skip in * `MarkdownViewer`. 
*/ -const SKIP_SUBTREE_TYPES = new Set<string>([ - KEYS.codeBlock, - "code_line", - KEYS.link, -]); +const SKIP_SUBTREE_TYPES = new Set<string>([KEYS.codeBlock, "code_line", KEYS.link]); /** * Build the marks portion of a Slate text node so we can preserve formatting diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 3efdab03b..afd888f48 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -26,9 +26,9 @@ import { type Tab, } from "@/atoms/tabs/tabs.atom"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; +import { ActionLogSheet } from "@/components/agent-action-log/action-log-sheet"; import { SearchSpaceSettingsDialog } from "@/components/settings/search-space-settings-dialog"; import { TeamDialog } from "@/components/settings/team-dialog"; -import { ActionLogSheet } from "@/components/agent-action-log/action-log-sheet"; import { UserSettingsDialog } from "@/components/settings/user-settings-dialog"; import { AlertDialog, diff --git a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx index d20aea2cd..bf4de6454 100644 --- a/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DocumentsSidebar.tsx @@ -23,9 +23,7 @@ import { useTranslations } from "next-intl"; import type React from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; -import { - mentionedDocumentsAtom, -} from "@/atoms/chat/mentioned-documents.atom"; +import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom"; import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms"; import { connectorsAtom } from 
"@/atoms/connectors/connector-query.atoms"; import { deleteDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms"; @@ -74,12 +72,12 @@ import type { DocumentTypeEnum } from "@/contracts/types/document.types"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { useMediaQuery } from "@/hooks/use-media-query"; import { useElectronAPI, usePlatform } from "@/hooks/use-platform"; -import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { anonymousChatApiService } from "@/lib/apis/anonymous-chat-api.service"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { foldersApiService } from "@/lib/apis/folders-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { authenticatedFetch } from "@/lib/auth-utils"; +import { getMentionDocKey } from "@/lib/chat/mention-doc-key"; import { uploadFolderScan } from "@/lib/folder-sync-upload"; import { getSupportedExtensionsSet } from "@/lib/supported-extensions"; import { queries } from "@/zero/queries/index"; diff --git a/surfsense_web/components/markdown-viewer.tsx b/surfsense_web/components/markdown-viewer.tsx index b2420711a..6caf01917 100644 --- a/surfsense_web/components/markdown-viewer.tsx +++ b/surfsense_web/components/markdown-viewer.tsx @@ -5,10 +5,7 @@ import "katex/dist/katex.min.css"; import Image from "next/image"; import { useMemo } from "react"; import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; -import { - type CitationUrlMap, - preprocessCitationMarkdown, -} from "@/lib/citations/citation-parser"; +import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; const code = createCodePlugin({ diff --git a/surfsense_web/hooks/use-agent-actions-query.ts b/surfsense_web/hooks/use-agent-actions-query.ts index 9a722fb2e..114c79567 100644 --- a/surfsense_web/hooks/use-agent-actions-query.ts +++ 
b/surfsense_web/hooks/use-agent-actions-query.ts @@ -88,71 +88,68 @@ export function applyActionLogSse( searchSpaceId, event, }); - queryClient.setQueryData<AgentActionListResponse>( - agentActionsQueryKey(threadId), - (prev) => { - const placeholder: AgentAction = { - id: event.id, - thread_id: threadId, - user_id: null, - search_space_id: searchSpaceId, - tool_name: event.tool_name, - args: null, - result_id: null, - reversible: event.reversible, - reverse_descriptor: event.reverse_descriptor_present ? {} : null, - error: event.error ? {} : null, - reverse_of: null, - reverted_by_action_id: null, - is_revert_action: false, - tool_call_id: event.lc_tool_call_id, - chat_turn_id: event.chat_turn_id, - created_at: event.created_at ?? new Date().toISOString(), - }; - if (!prev) { - return { - items: [placeholder], - total: 1, - page: 0, - page_size: ACTION_LOG_PAGE_SIZE, - has_more: false, - }; - } - const existingIdx = prev.items.findIndex((a) => a.id === event.id); - if (existingIdx >= 0) { - const merged = [...prev.items]; - const existing = merged[existingIdx]; - if (existing) { - merged[existingIdx] = { - ...existing, - reversible: event.reversible, - tool_call_id: event.lc_tool_call_id ?? existing.tool_call_id, - chat_turn_id: event.chat_turn_id ?? existing.chat_turn_id, - }; - } - dbg("applyActionLogSse: merged into existing entry", { - id: event.id, - tool_call_id: merged[existingIdx]?.tool_call_id, - reversible: merged[existingIdx]?.reversible, - }); - return { ...prev, items: merged }; - } - dbg("applyActionLogSse: appended new placeholder", { - id: event.id, - tool_call_id: placeholder.tool_call_id, - tool_name: placeholder.tool_name, - reversible: placeholder.reversible, - cacheSizeAfter: prev.items.length + 1, - }); - // REST returns newest-first — keep that ordering when - // the server eventually refetches by prepending. 
+ queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (prev) => { + const placeholder: AgentAction = { + id: event.id, + thread_id: threadId, + user_id: null, + search_space_id: searchSpaceId, + tool_name: event.tool_name, + args: null, + result_id: null, + reversible: event.reversible, + reverse_descriptor: event.reverse_descriptor_present ? {} : null, + error: event.error ? {} : null, + reverse_of: null, + reverted_by_action_id: null, + is_revert_action: false, + tool_call_id: event.lc_tool_call_id, + chat_turn_id: event.chat_turn_id, + created_at: event.created_at ?? new Date().toISOString(), + }; + if (!prev) { return { - ...prev, - items: [placeholder, ...prev.items], - total: prev.total + 1, + items: [placeholder], + total: 1, + page: 0, + page_size: ACTION_LOG_PAGE_SIZE, + has_more: false, }; } - ); + const existingIdx = prev.items.findIndex((a) => a.id === event.id); + if (existingIdx >= 0) { + const merged = [...prev.items]; + const existing = merged[existingIdx]; + if (existing) { + merged[existingIdx] = { + ...existing, + reversible: event.reversible, + tool_call_id: event.lc_tool_call_id ?? existing.tool_call_id, + chat_turn_id: event.chat_turn_id ?? existing.chat_turn_id, + }; + } + dbg("applyActionLogSse: merged into existing entry", { + id: event.id, + tool_call_id: merged[existingIdx]?.tool_call_id, + reversible: merged[existingIdx]?.reversible, + }); + return { ...prev, items: merged }; + } + dbg("applyActionLogSse: appended new placeholder", { + id: event.id, + tool_call_id: placeholder.tool_call_id, + tool_name: placeholder.tool_name, + reversible: placeholder.reversible, + cacheSizeAfter: prev.items.length + 1, + }); + // REST returns newest-first — keep that ordering when + // the server eventually refetches by prepending. 
+ return { + ...prev, + items: [placeholder, ...prev.items], + total: prev.total + 1, + }; + }); } /** @@ -170,33 +167,30 @@ export function applyActionLogUpdatedSse( id, reversible, }); - queryClient.setQueryData<AgentActionListResponse>( - agentActionsQueryKey(threadId), - (prev) => { - if (!prev) { - dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", { - threadId, - id, - }); - return prev; - } - let mutated = false; - const items = prev.items.map((a) => { - if (a.id !== id) return a; - mutated = true; - return { ...a, reversible }; + queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (prev) => { + if (!prev) { + dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", { + threadId, + id, }); - if (!mutated) { - dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", { - threadId, - id, - cacheSize: prev.items.length, - cacheIds: prev.items.map((a) => a.id), - }); - } - return mutated ? { ...prev, items } : prev; + return prev; } - ); + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + return { ...a, reversible }; + }); + if (!mutated) { + dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", { + threadId, + id, + cacheSize: prev.items.length, + cacheIds: prev.items.map((a) => a.id), + }); + } + return mutated ? { ...prev, items } : prev; + }); } /** @@ -214,24 +208,21 @@ export function markActionRevertedInCache( id: number, newActionId: number | null ): void { - queryClient.setQueryData<AgentActionListResponse>( - agentActionsQueryKey(threadId), - (prev) => { - if (!prev) return prev; - let mutated = false; - const items = prev.items.map((a) => { - if (a.id !== id) return a; - mutated = true; - // ``-1`` is a sentinel meaning "we know it was reverted - // but the server didn't tell us the new row's id". - return { - ...a, - reverted_by_action_id: newActionId ?? -1, - }; - }); - return mutated ? 
{ ...prev, items } : prev; - } - ); + queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (prev) => { + if (!prev) return prev; + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + // ``-1`` is a sentinel meaning "we know it was reverted + // but the server didn't tell us the new row's id". + return { + ...a, + reverted_by_action_id: newActionId ?? -1, + }; + }); + return mutated ? { ...prev, items } : prev; + }); } /** @@ -245,21 +236,18 @@ export function applyRevertTurnResultsToCache( entries: Array<{ id: number; newActionId: number | null }> ): void { if (entries.length === 0) return; - queryClient.setQueryData<AgentActionListResponse>( - agentActionsQueryKey(threadId), - (prev) => { - if (!prev) return prev; - const lookup = new Map(entries.map((e) => [e.id, e.newActionId])); - let mutated = false; - const items = prev.items.map((a) => { - if (!lookup.has(a.id)) return a; - mutated = true; - const newActionId = lookup.get(a.id) ?? null; - return { ...a, reverted_by_action_id: newActionId ?? -1 }; - }); - return mutated ? { ...prev, items } : prev; - } - ); + queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), (prev) => { + if (!prev) return prev; + const lookup = new Map(entries.map((e) => [e.id, e.newActionId])); + let mutated = false; + const items = prev.items.map((a) => { + if (!lookup.has(a.id)) return a; + mutated = true; + const newActionId = lookup.get(a.id) ?? null; + return { ...a, reverted_by_action_id: newActionId ?? -1 }; + }); + return mutated ? { ...prev, items } : prev; + }); } /** @@ -271,10 +259,7 @@ export function applyRevertTurnResultsToCache( * knob — pass ``false`` to keep the query dormant when the consumer * doesn't yet have a thread id. 
*/ -export function useAgentActionsQuery( - threadId: number | null, - options: { enabled?: boolean } = {} -) { +export function useAgentActionsQuery(threadId: number | null, options: { enabled?: boolean } = {}) { const enabled = (options.enabled ?? true) && threadId !== null; const query = useQuery({ queryKey: agentActionsQueryKey(threadId), @@ -336,10 +321,7 @@ export function useAgentActionsQuery( else m.set(key, [a]); } for (const bucket of m.values()) { - bucket.sort( - (a, b) => - new Date(a.created_at).getTime() - new Date(b.created_at).getTime() - ); + bucket.sort((a, b) => new Date(a.created_at).getTime() - new Date(b.created_at).getTime()); } return m; }, [items]); @@ -396,10 +378,7 @@ export function useAgentActionsQuery( ); const findByChatTurnAndTool = useCallback( - ( - chatTurnId: string | null | undefined, - toolName: string | null | undefined - ): AgentAction[] => { + (chatTurnId: string | null | undefined, toolName: string | null | undefined): AgentAction[] => { if (!chatTurnId || !toolName) return []; return byTurnAndTool.get(`${chatTurnId}::${toolName}`) ?? 
[]; }, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 7dfbfc1a1..95d9848f2 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -53,7 +53,10 @@ function getErrorMessage(error: unknown): string { } } -function getErrorCode(error: unknown, parsedJson: Record<string, unknown> | null): string | undefined { +function getErrorCode( + error: unknown, + parsedJson: Record<string, unknown> | null +): string | undefined { if (error instanceof Error) { const withCode = error as Error & { errorCode?: string; code?: string }; if (withCode.errorCode) return withCode.errorCode; @@ -138,8 +141,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError severity: "info", telemetryEvent: "chat_blocked", isExpected: true, - userMessage: - "Buy more tokens to continue with this model, or switch to a free model.", + userMessage: "Buy more tokens to continue with this model, or switch to a free model.", assistantMessage: PREMIUM_QUOTA_ASSISTANT_MESSAGE, rawMessage, errorCode: errorCode ?? "PREMIUM_QUOTA_EXHAUSTED", @@ -147,9 +149,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "TURN_CANCELLING" - ) { + if (errorCode === "TURN_CANCELLING") { return { kind: "thread_busy", channel: "toast", @@ -163,16 +163,15 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "THREAD_BUSY" - ) { + if (errorCode === "THREAD_BUSY") { return { kind: "thread_busy", channel: "toast", severity: "warn", telemetryEvent: "chat_blocked", isExpected: true, - userMessage: "Another response is still finishing for this thread. Please try again in a moment.", + userMessage: + "Another response is still finishing for this thread. Please try again in a moment.", rawMessage, errorCode: errorCode ?? 
"THREAD_BUSY", details: { flow: input.flow }, @@ -193,10 +192,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "AUTH_EXPIRED" || - errorCode === "UNAUTHORIZED" - ) { + if (errorCode === "AUTH_EXPIRED" || errorCode === "UNAUTHORIZED") { return { kind: "auth_expired", channel: "toast", @@ -210,10 +206,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "RATE_LIMITED" || - providerTypeNormalized === "rate_limit_error" - ) { + if (errorCode === "RATE_LIMITED" || providerTypeNormalized === "rate_limit_error") { return { kind: "rate_limited", channel: "toast", @@ -242,9 +235,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "STREAM_PARSE_ERROR" - ) { + if (errorCode === "STREAM_PARSE_ERROR") { return { kind: "stream_parse_error", channel: "toast", @@ -258,9 +249,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "TOOL_EXECUTION_ERROR" - ) { + if (errorCode === "TOOL_EXECUTION_ERROR") { return { kind: "tool_execution_error", channel: "toast", @@ -274,9 +263,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "PERSIST_MESSAGE_FAILED" - ) { + if (errorCode === "PERSIST_MESSAGE_FAILED") { return { kind: "persist_message_failed", channel: "toast", @@ -290,9 +277,7 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } - if ( - errorCode === "SERVER_ERROR" - ) { + if (errorCode === "SERVER_ERROR") { return { kind: "server_error", channel: "toast", diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index 708831354..e0dfb3cc4 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -74,13 +74,9 @@ export async function 
toHttpResponseError( : Number.isFinite(retryAfterSeconds) ? Math.max(0, Math.round(retryAfterSeconds * 1000)) : undefined; - const retryAfterMs = - detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined; + const retryAfterMs = detailRetryAfterMs ?? topRetryAfterMs ?? retryAfterMsFromHeader ?? undefined; const message = - detailNestedMessage ?? - detailMessage ?? - topLevelMessage ?? - `Backend error: ${response.status}`; + detailNestedMessage ?? detailMessage ?? topLevelMessage ?? `Backend error: ${response.status}`; return Object.assign(new Error(message), { errorCode, retryAfterMs }); } diff --git a/surfsense_web/lib/chat/stream-pipeline.ts b/surfsense_web/lib/chat/stream-pipeline.ts index c9118f949..c76781083 100644 --- a/surfsense_web/lib/chat/stream-pipeline.ts +++ b/surfsense_web/lib/chat/stream-pipeline.ts @@ -72,8 +72,12 @@ function toStreamTerminalError( }); } -export function processSharedStreamEvent(parsed: SSEEvent, context: SharedStreamEventContext): boolean { - const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = context; +export function processSharedStreamEvent( + parsed: SSEEvent, + context: SharedStreamEventContext +): boolean { + const { contentPartsState, toolsWithUI, currentThinkingSteps, scheduleFlush, forceFlush } = + context; const { contentParts, toolCallIndices } = contentPartsState; switch (parsed.type) { diff --git a/surfsense_web/lib/chat/stream-side-effects.ts b/surfsense_web/lib/chat/stream-side-effects.ts index 9cb349458..5483ff14b 100644 --- a/surfsense_web/lib/chat/stream-side-effects.ts +++ b/surfsense_web/lib/chat/stream-side-effects.ts @@ -16,9 +16,7 @@ export type EditedInterruptAction = { args: Record<string, unknown>; }; -function readInterruptActions( - interruptData: Record<string, unknown> -): InterruptActionRequest[] { +function readInterruptActions(interruptData: Record<string, unknown>): InterruptActionRequest[] { return (interruptData.action_requests ?? 
[]) as InterruptActionRequest[]; } @@ -121,7 +119,5 @@ export function applyTurnIdToAssistantMessageList( assistantMsgId: string, turnId: string ): ThreadMessageLike[] { - return messages.map((m) => - m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m - ); + return messages.map((m) => (m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, turnId) : m)); } diff --git a/surfsense_web/lib/citations/citation-parser.ts b/surfsense_web/lib/citations/citation-parser.ts index 6333b0f97..533c644c2 100644 --- a/surfsense_web/lib/citations/citation-parser.ts +++ b/surfsense_web/lib/citations/citation-parser.ts @@ -40,8 +40,7 @@ export interface PreprocessedCitations { } /** Pattern matching only URL-form citations (used during preprocessing). */ -const URL_CITATION_REGEX = - /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g; +const URL_CITATION_REGEX = /[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g; /** * Replace `[citation:URL]` tokens with `[citation:urlciteN]` placeholders so @@ -82,10 +81,7 @@ export function preprocessCitationMarkdown(content: string): PreprocessedCitatio * tokens to JSX. Negative chunk IDs are forwarded as-is so the consumer * can decide how to render anonymous documents. 
*/ -export function parseTextWithCitations( - text: string, - urlMap: CitationUrlMap -): ParsedSegment[] { +export function parseTextWithCitations(text: string, urlMap: CitationUrlMap): ParsedSegment[] { const segments: ParsedSegment[] = []; let lastIndex = 0; let match: RegExpExecArray | null; diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 30e58215a..f9eb6b312 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -1,6 +1,6 @@ import posthog from "posthog-js"; import { getConnectorTelemetryMeta } from "@/components/assistant-ui/connector-popup/constants/connector-constants"; -import type { ChatErrorKind, ChatFlow, ChatErrorSeverity } from "@/lib/chat/chat-error-classifier"; +import type { ChatErrorKind, ChatErrorSeverity, ChatFlow } from "@/lib/chat/chat-error-classifier"; /** * PostHog Analytics Event Definitions From 1efed5e489763a655eba5fa3ed86c2d0dd4fa800 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 20:28:41 -0700 Subject: [PATCH 40/68] chore: add debug environment variables for macOS codesigning troubleshooting --- .github/workflows/desktop-release.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/desktop-release.yml b/.github/workflows/desktop-release.yml index e356bd3e5..ad1c128bc 100644 --- a/.github/workflows/desktop-release.yml +++ b/.github/workflows/desktop-release.yml @@ -144,6 +144,11 @@ jobs: APPLE_ID: ${{ secrets.APPLE_ID }} APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + # TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed. + # Surfaces the exact codesign / notarize commands electron-builder spawns, + # so we can see which subprocess hangs. 
+ DEBUG: electron-builder,electron-osx-sign*,@electron/notarize* + ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true" # Service principal credentials for Azure.Identity EnvironmentCredential used by the # TrustedSigning PowerShell module. Only populated when signing is enabled. # electron-builder 26 does not yet support OIDC federated tokens for Azure signing, From 360b5f8e3ad8056e6db171a6cb34fe5a7899dee4 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Thu, 30 Apr 2026 20:47:30 -0700 Subject: [PATCH 41/68] chore: update environment variables for improved macOS codesigning debugging --- .github/workflows/notary-status.yml | 60 +++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 .github/workflows/notary-status.yml diff --git a/.github/workflows/notary-status.yml b/.github/workflows/notary-status.yml new file mode 100644 index 000000000..5c7c42038 --- /dev/null +++ b/.github/workflows/notary-status.yml @@ -0,0 +1,60 @@ +name: Notary status check + +# One-off diagnostic workflow. Queries Apple's notary service to see if your +# submissions are queued, in progress, accepted, or rejected. Useful when a +# notarization seems "hung" — most often the queue itself, especially on a +# brand-new Apple Developer account. +# +# Run via: Actions tab -> "Notary status check" -> Run workflow. +# Inputs are optional; if you provide a submission ID, it also fetches that +# submission's full Apple log. +# +# Safe to delete after diagnosis. 
+ +on: + workflow_dispatch: + inputs: + submission_id: + description: 'Optional: submission UUID to fetch full Apple log for' + required: false + default: '' + +jobs: + status: + runs-on: macos-latest + steps: + - name: List recent notarization submissions + env: + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + run: | + set -euo pipefail + echo "::group::Submission history (most recent first)" + xcrun notarytool history \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" + echo "::endgroup::" + + - name: Inspect specific submission (if id provided) + if: ${{ inputs.submission_id != '' }} + env: + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + SUBMISSION_ID: ${{ inputs.submission_id }} + run: | + set -euo pipefail + echo "::group::Submission info" + xcrun notarytool info "$SUBMISSION_ID" \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" + echo "::endgroup::" + echo "::group::Apple's processing log for this submission" + xcrun notarytool log "$SUBMISSION_ID" \ + --apple-id "$APPLE_ID" \ + --password "$APPLE_APP_SPECIFIC_PASSWORD" \ + --team-id "$APPLE_TEAM_ID" || true + echo "::endgroup::" From e57c3a7d0c0f4f1fbe29382a97635fb01e5db44a Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" <vermarohanfinal@gmail.com> Date: Fri, 1 May 2026 05:10:53 -0700 Subject: [PATCH 42/68] feat: prompt caching - Updated `litellm` dependency version from `1.83.4` to `1.83.7`. - Adjusted `aiohttp` version from `3.13.5` to `3.13.4` in the lock file. - Implemented `apply_litellm_prompt_caching` in `chat_deepagent.py` to improve prompt caching. - Added model name resolution logic in `chat_deepagent.py` to ensure correct provider-variant dispatch. 
- Enhanced `llm_config.py` to configure prompt caching for various LLM providers. - Updated tests to verify correct model name forwarding and prompt caching behavior. --- .../app/agents/new_chat/chat_deepagent.py | 60 ++- .../app/agents/new_chat/llm_config.py | 22 +- .../app/agents/new_chat/prompt_caching.py | 166 +++++++++ .../app/services/llm_router_service.py | 35 +- surfsense_backend/pyproject.toml | 2 +- .../agents/new_chat/prompts/test_composer.py | 25 ++ .../agents/new_chat/test_prompt_caching.py | 350 ++++++++++++++++++ .../test_resolve_prompt_model_name.py | 117 ++++++ .../unit/test_stream_new_chat_contract.py | 36 ++ surfsense_backend/uv.lock | 160 ++++---- .../components/pricing/pricing-section.tsx | 1 - .../settings/more-pages-content.tsx | 59 +-- 12 files changed, 877 insertions(+), 156 deletions(-) create mode 100644 surfsense_backend/app/agents/new_chat/prompt_caching.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py create mode 100644 surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index fdd72ea92..c0e9a3b96 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent`` This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable subclass of the default ``FilesystemMiddleware`` — while preserving every other behaviour that ``create_deep_agent`` provides (todo-list, subagents, -summarisation, prompt-caching, etc.). +summarisation, etc.). Prompt caching is configured at LLM-build time via +``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather +than as a middleware. 
""" import asyncio @@ -33,7 +35,6 @@ from langchain.agents.middleware import ( TodoListMiddleware, ToolCallLimitMiddleware, ) -from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.types import Checkpointer @@ -74,6 +75,7 @@ from app.agents.new_chat.plugin_loader import ( load_allowed_plugin_names_from_env, load_plugin_middlewares, ) +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching from app.agents.new_chat.subagents import build_specialized_subagents from app.agents.new_chat.system_prompt import ( build_configurable_system_prompt, @@ -94,6 +96,39 @@ from app.utils.perf import get_perf_logger _perf_log = get_perf_logger() + +def _resolve_prompt_model_name( + agent_config: AgentConfig | None, + llm: BaseChatModel, +) -> str | None: + """Resolve the model id to feed to provider-variant detection. + + Preference order (matches the established idiom in + ``llm_router_service.py`` — see ``params.get("base_model") or + params.get("model", "")`` usages there): + + 1. ``agent_config.litellm_params["base_model"]`` — required for Azure + deployments where ``model_name`` is the deployment slug, not the + underlying family. Without this, a deployment named e.g. + ``"prod-chat-001"`` would silently miss every provider regex. + 2. ``agent_config.model_name`` — the user's configured model id. + 3. ``getattr(llm, "model", None)`` — fallback for direct callers that + don't supply an ``AgentConfig`` (currently a defensive path; all + production callers pass ``agent_config``). + + Returns ``None`` when nothing is available; ``compose_system_prompt`` + treats that as the ``"default"`` variant (no provider block emitted). 
+ """ + if agent_config is not None: + params = agent_config.litellm_params or {} + base_model = params.get("base_model") + if isinstance(base_model, str) and base_model.strip(): + return base_model + if agent_config.model_name: + return agent_config.model_name + return getattr(llm, "model", None) + + # ============================================================================= # Connector Type Mapping # ============================================================================= @@ -279,6 +314,14 @@ async def create_surfsense_deep_agent( ) """ _t_agent_total = time.perf_counter() + + # Layer thread-aware prompt caching onto the LLM. Idempotent with the + # build-time call in ``llm_config.py``; this run merely adds + # ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family + # configs now that ``thread_id`` is known. No-op when ``thread_id`` is + # None or the provider is non-OpenAI-family. + apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id) + filesystem_selection = filesystem_selection or FilesystemSelection() backend_resolver = build_backend_resolver( filesystem_selection, @@ -398,6 +441,7 @@ async def create_surfsense_deep_agent( enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, mcp_connector_tools=_mcp_connector_tools, + model_name=_resolve_prompt_model_name(agent_config, llm), ) else: system_prompt = build_surfsense_system_prompt( @@ -405,6 +449,7 @@ async def create_surfsense_deep_agent( enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, mcp_connector_tools=_mcp_connector_tools, + model_name=_resolve_prompt_model_name(agent_config, llm), ) _perf_log.info( "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 @@ -568,7 +613,6 @@ def _build_compiled_agent_blocking( ), create_surfsense_compaction_middleware(llm, StateBackend), PatchToolCallsMiddleware(), - 
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key] @@ -1006,12 +1050,12 @@ def _build_compiled_agent_blocking( action_log_mw, PatchToolCallsMiddleware(), DedupHITLToolCallsMiddleware(agent_tools=list(tools)), - # Plugin slot — sits just before AnthropicCache so plugin-side - # transforms see the final tool result and run before any - # caching heuristics. Multiple plugins in declared order; loader - # filtered by the admin allowlist already. + # Plugin slot — sits at the tail so plugin-side transforms see the + # final tool result. Prompt caching is now applied at LLM build time + # via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no + # caching middleware is needed here. Multiple plugins run in declared + # order; loader filtered by the admin allowlist already. *plugin_middlewares, - AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] deepagent_middleware = [m for m in deepagent_middleware if m is not None] diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py index 58d8f84d0..99bb719f6 100644 --- a/surfsense_backend/app/agents/new_chat/llm_config.py +++ b/surfsense_backend/app/agents/new_chat/llm_config.py @@ -27,6 +27,7 @@ from litellm import get_model_info from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, @@ -494,6 +495,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) + # Configure LiteLLM-native prompt caching (cache_control_injection_points + # for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.). 
+ # ``agent_config=None`` here — the YAML path doesn't have provider intent + # in a structured form, so we set only the universal injection points. + apply_litellm_prompt_caching(llm) return llm @@ -518,7 +524,16 @@ def create_chat_litellm_from_agent_config( print("Error: Auto mode requested but LLM Router not initialized") return None try: - return get_auto_mode_llm() + router_llm = get_auto_mode_llm() + if router_llm is not None: + # Universal cache_control_injection_points only — auto-mode + # fans out across providers, so OpenAI-only kwargs (e.g. + # ``prompt_cache_key``) are left off here. ``drop_params`` + # would strip them at the provider boundary anyway, but + # there's no point setting them when we don't know the + # destination. + apply_litellm_prompt_caching(router_llm, agent_config=agent_config) + return router_llm except Exception as e: print(f"Error creating ChatLiteLLMRouter: {e}") return None @@ -549,4 +564,9 @@ def create_chat_litellm_from_agent_config( llm = SanitizedChatLiteLLM(**litellm_kwargs) _attach_model_profile(llm, model_string) + # Build-time prompt caching: sets ``cache_control_injection_points`` for + # all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``. + # Per-thread ``prompt_cache_key`` is layered on later in + # ``create_surfsense_deep_agent`` once ``thread_id`` is known. + apply_litellm_prompt_caching(llm, agent_config=agent_config) return llm diff --git a/surfsense_backend/app/agents/new_chat/prompt_caching.py b/surfsense_backend/app/agents/new_chat/prompt_caching.py new file mode 100644 index 000000000..86bc57725 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompt_caching.py @@ -0,0 +1,166 @@ +"""LiteLLM-native prompt caching configuration for SurfSense agents. + +Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never +activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)`` +gate always failed) with LiteLLM's universal caching mechanism. 
+ +Coverage: + +- Marker-based providers (need ``cache_control`` injection, which LiteLLM + performs automatically when ``cache_control_injection_points`` is set): + ``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``, + ``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/`` + (Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM). +- Auto-cached (LiteLLM strips the marker silently): ``openai/``, + ``deepseek/``, ``xai/`` — these cache automatically for prompts ≥1024 + tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``. + +We inject **two** breakpoints per request: + +- ``role: system`` — pins the SurfSense system prompt (provider variant, + citation rules, tool catalog, KB tree, skills metadata) into the cache. +- ``index: -1`` — pins the latest message so multi-turn savings compound: + Anthropic-family providers use longest-matching-prefix lookup, so turn + N+1 still reads turn N's cache up to the shared prefix. + +For OpenAI-family configs we additionally pass: + +- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that + raises hit rate by sending requests with a shared prefix to the same + backend. +- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default + 5-10 min in-memory cache. + +Safety net: ``litellm.drop_params=True`` is set globally in +``app.services.llm_service`` at module-load time. Any kwarg the destination +provider doesn't recognise is auto-stripped at the provider transformer +layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on +``prompt_cache_key`` etc. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from langchain_core.language_models import BaseChatModel + +if TYPE_CHECKING: + from app.agents.new_chat.llm_config import AgentConfig + +logger = logging.getLogger(__name__) + + +# Two-breakpoint policy: system + latest message. See module docstring for +# rationale. 
Anthropic limits requests to 4 ``cache_control`` blocks; we +# use 2 here, leaving headroom for Phase-2 tool caching. +_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = ( + {"location": "message", "role": "system"}, + {"location": "message", "index": -1}, +) + +# Providers (uppercase ``AgentConfig.provider`` values) that natively expose +# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and +# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers +# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without +# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU, +# MINIMAX), so we can't infer family from the litellm prefix alone. +_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"}) + + +def _is_router_llm(llm: BaseChatModel) -> bool: + """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import. + + Importing ``app.services.llm_router_service`` at module-load time would + create a cycle via ``llm_config -> prompt_caching -> llm_router_service``. + Class-name comparison is sufficient since the class is defined in a + single place. + """ + return type(llm).__name__ == "ChatLiteLLMRouter" + + +def _is_openai_family_config(agent_config: AgentConfig | None) -> bool: + """Whether the config targets an OpenAI-style prompt-cache surface. + + Strict — only returns True when the user explicitly chose OPENAI, + DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` / + ``YAMLConfig``. Auto-mode and custom providers return False because + we can't statically know the destination. 
+ """ + if agent_config is None or not agent_config.provider: + return False + if agent_config.is_auto_mode: + return False + if agent_config.custom_provider: + return False + return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS + + +def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None: + """Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail. + + Initialises the field to ``{}`` when present-but-None on a Pydantic v2 + model. Returns ``None`` if the LLM type doesn't expose a writable + ``model_kwargs`` attribute (caller should treat as no-op). + """ + model_kwargs = getattr(llm, "model_kwargs", None) + if isinstance(model_kwargs, dict): + return model_kwargs + try: + llm.model_kwargs = {} # type: ignore[attr-defined] + except Exception: + return None + refreshed = getattr(llm, "model_kwargs", None) + return refreshed if isinstance(refreshed, dict) else None + + +def apply_litellm_prompt_caching( + llm: BaseChatModel, + *, + agent_config: AgentConfig | None = None, + thread_id: int | None = None, +) -> None: + """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter. + + Idempotent — values already present in ``llm.model_kwargs`` (e.g. from + ``agent_config.litellm_params`` overrides) are preserved. Mutates + ``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion`` + via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge + in our custom ``ChatLiteLLMRouter``. + + Args: + llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance. + agent_config: Optional ``AgentConfig`` driving provider-specific + behaviour. When omitted (or auto-mode), only the universal + ``cache_control_injection_points`` are set. + thread_id: Optional thread id used to construct a per-thread + ``prompt_cache_key`` for OpenAI-family providers. Caching still + works without it (server-side automatic), but the key improves + backend routing affinity and therefore hit rate. 
+ """ + model_kwargs = _get_or_init_model_kwargs(llm) + if model_kwargs is None: + logger.debug( + "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping", + type(llm).__name__, + ) + return + + if "cache_control_injection_points" not in model_kwargs: + model_kwargs["cache_control_injection_points"] = [ + dict(point) for point in _DEFAULT_INJECTION_POINTS + ] + + # OpenAI-family extras only when we statically know the destination is + # OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers + # so we can't safely set OpenAI-only kwargs there (drop_params would + # strip them but it's wasteful to set them in the first place). + if _is_router_llm(llm): + return + if not _is_openai_family_config(agent_config): + return + + if thread_id is not None and "prompt_cache_key" not in model_kwargs: + model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}" + if "prompt_cache_retention" not in model_kwargs: + model_kwargs["prompt_cache_retention"] = "24h" diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 4bce79a43..fbd42b458 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -28,6 +28,7 @@ from litellm.exceptions import ( BadRequestError as LiteLLMBadRequestError, ContextWindowExceededError, ) +from pydantic import Field from app.utils.perf import get_perf_logger @@ -573,6 +574,11 @@ class ChatLiteLLMRouter(BaseChatModel): # Public attributes that Pydantic will manage model: str = "auto" streaming: bool = True + # Static kwargs that flow through to ``litellm.completion(...)`` on every + # invocation (e.g. ``cache_control_injection_points`` set by + # ``apply_litellm_prompt_caching``). Per-call ``**kwargs`` from + # ``invoke()`` still take precedence — see ``_generate``/``_astream``. 
+ model_kwargs: dict[str, Any] = Field(default_factory=dict) # Bound tools and tool choice for tool calling _bound_tools: list[dict] | None = None @@ -898,13 +904,16 @@ class ChatLiteLLMRouter(BaseChatModel): logger.warning(f"Failed to convert tool {tool}: {e}") continue - # Create a new instance with tools bound + # Create a new instance with tools bound. Carry through ``model_kwargs`` + # so static settings (e.g. cache_control_injection_points) survive the + # bind_tools rebuild. return ChatLiteLLMRouter( router=self._router, bound_tools=formatted_tools if formatted_tools else None, tool_choice=tool_choice, model=self.model, streaming=self.streaming, + model_kwargs=dict(self.model_kwargs), **kwargs, ) @@ -929,8 +938,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -997,8 +1008,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. 
+ call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -1060,8 +1073,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. + call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: @@ -1110,8 +1125,10 @@ class ChatLiteLLMRouter(BaseChatModel): formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) - # Add tools if bound - call_kwargs = {**kwargs} + # Merge static model_kwargs (e.g. cache_control_injection_points) under + # per-call kwargs so callers can still override per invocation. Then add + # bound tools. 
+ call_kwargs = {**self.model_kwargs, **kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index 131627386..cd683e2e1 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "deepagents>=0.4.12", "stripe>=15.0.0", "azure-ai-documentintelligence>=1.0.2", - "litellm>=1.83.4", + "litellm>=1.83.7", "langchain-litellm>=0.6.4", ] diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py index 397b1c787..36fe04aa2 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -226,6 +226,31 @@ class TestCompose: # Default block should NOT be present assert "<knowledge_base_only_policy>" not in prompt + def test_provider_hints_render_with_custom_system_instructions( + self, fixed_today: datetime + ) -> None: + """Regression guard for the always-append decision: provider hints + append AFTER a custom system prompt. + + Provider hints are stylistic nudges (parallel tool-call rules, + formatting guidance, etc.) that help the model regardless of + what the system instructions say. Suppressing them when a + custom prompt is set would partially defeat the per-family + prompt machinery. + """ + prompt = compose_system_prompt( + today=fixed_today, + custom_system_instructions="You are a custom assistant.", + model_name="anthropic/claude-3-5-sonnet", + ) + assert "You are a custom assistant." in prompt + assert "<provider_hints>" in prompt + # The custom prompt must come BEFORE the provider hints so the + # user's framing isn't drowned out by the stylistic nudges. 
+ assert prompt.index("You are a custom assistant.") < prompt.index( + "<provider_hints>" + ) + def test_use_default_false_with_no_custom_yields_no_system_block( self, fixed_today: datetime ) -> None: diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py new file mode 100644 index 000000000..5b3a03581 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_prompt_caching.py @@ -0,0 +1,350 @@ +"""Tests for ``apply_litellm_prompt_caching`` in +:mod:`app.agents.new_chat.prompt_caching`. + +The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which +never activated for our LiteLLM stack) with LiteLLM-native multi-provider +prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to +``litellm.completion(...)``. The tests below pin its public contract: + +1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so + savings compound across multi-turn conversations on Anthropic-family + providers. +2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for + single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic + prompt-cache surface is available). +3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no + OpenAI-only kwargs because the router fans out across providers. +4. Idempotent: user-supplied values in ``model_kwargs`` are preserved. +5. Defensive: LLMs without a writable ``model_kwargs`` are silently + skipped rather than raising. 
+""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from app.agents.new_chat.llm_config import AgentConfig +from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + + +class _FakeLLM: + """Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``. + + The helper only inspects ``getattr(llm, "model_kwargs", None)``, + ``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple + object suffices — we don't need to spin up real LangChain/LiteLLM + machinery for unit tests of the helper's logic. + """ + + def __init__( + self, + model: str = "openai/gpt-4o", + model_kwargs: dict[str, Any] | None = None, + ) -> None: + self.model = model + self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {} + + +class ChatLiteLLMRouter: + """Class-name-only impostor of the real router. + + The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"`` + (a deliberate stringly-typed check to avoid an import cycle with + ``app.services.llm_router_service``). Reusing the same class name here + triggers the same code path without instantiating a real ``Router``. 
+ """ + + def __init__(self) -> None: + self.model = "auto" + self.model_kwargs: dict[str, Any] = {} + + +def _make_cfg(**overrides: Any) -> AgentConfig: + """Build an ``AgentConfig`` with sensible defaults for the helper test.""" + defaults: dict[str, Any] = { + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "k", + } + return AgentConfig(**{**defaults, **overrides}) + + +# --------------------------------------------------------------------------- +# (a) Universal injection points +# --------------------------------------------------------------------------- + + +def test_sets_both_cache_control_injection_points_with_no_config() -> None: + """Bare call (no agent_config, no thread_id) still sets the two + universal breakpoints — these cost nothing on providers that don't + consume them and unlock caching on every supported provider.""" + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm) + + points = llm.model_kwargs["cache_control_injection_points"] + assert {"location": "message", "role": "system"} in points + assert {"location": "message", "index": -1} in points + assert len(points) == 2 + + +def test_injection_points_set_for_anthropic_config() -> None: + """Anthropic-family configs need the marker — verify it lands.""" + cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet") + llm = _FakeLLM(model="anthropic/claude-3-5-sonnet") + + apply_litellm_prompt_caching(llm, agent_config=cfg) + + assert "cache_control_injection_points" in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (b) Idempotency / user override wins +# --------------------------------------------------------------------------- + + +def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None: + """Users who set their own injection points (e.g. 
with ``ttl: "1h"`` + via ``litellm_params``) keep them — the helper leaves them untouched, never + clobbers.""" + user_points = [ + {"location": "message", "role": "system", "ttl": "1h"}, + ] + llm = _FakeLLM( + model_kwargs={"cache_control_injection_points": user_points}, + ) + + apply_litellm_prompt_caching(llm) + + assert llm.model_kwargs["cache_control_injection_points"] is user_points + + +def test_idempotent_when_called_multiple_times() -> None: + """Build-time + thread-time double-call must be a no-op the second time.""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1) + snapshot = { + "cache_control_injection_points": list( + llm.model_kwargs["cache_control_injection_points"] + ), + "prompt_cache_key": llm.model_kwargs["prompt_cache_key"], + "prompt_cache_retention": llm.model_kwargs["prompt_cache_retention"], + } + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1) + + assert ( + llm.model_kwargs["cache_control_injection_points"] + == snapshot["cache_control_injection_points"] + ) + assert llm.model_kwargs["prompt_cache_key"] == snapshot["prompt_cache_key"] + assert ( + llm.model_kwargs["prompt_cache_retention"] == snapshot["prompt_cache_retention"] + ) + + +def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None: + """A pre-set ``prompt_cache_key`` (e.g. 
tenant-aware override via + ``litellm_params``) wins over our default per-thread key.""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"}) + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert llm.model_kwargs["prompt_cache_key"] == "tenant-abc" + + +# --------------------------------------------------------------------------- +# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"]) +def test_sets_openai_family_extras(provider: str) -> None: + """OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate + via routing affinity) and ``prompt_cache_retention="24h"`` (extends + cache TTL beyond the default 5-10 min).""" + cfg = _make_cfg(provider=provider) + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42" + assert llm.model_kwargs["prompt_cache_retention"] == "24h" + + +def test_skips_prompt_cache_key_when_no_thread_id() -> None: + """Without a thread id we can't construct a per-thread key. Retention + is still useful so we set it (it's free).""" + cfg = _make_cfg(provider="OPENAI") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None) + + assert "prompt_cache_key" not in llm.model_kwargs + assert llm.model_kwargs["prompt_cache_retention"] == "24h" + + +@pytest.mark.parametrize( + "provider", + ["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"], +) +def test_no_openai_extras_for_other_providers(provider: str) -> None: + """Non-OpenAI-family providers don't expose ``prompt_cache_key`` — + skip it. 
``cache_control_injection_points`` is still set (universal).""" + cfg = _make_cfg(provider=provider) + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_no_openai_extras_in_auto_mode() -> None: + """Auto-mode fans out across mixed providers — we can't statically + target OpenAI-only kwargs.""" + cfg = AgentConfig.from_auto_mode() + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_no_openai_extras_for_custom_provider() -> None: + """Custom providers route through arbitrary user-supplied prefixes — + we don't try to infer OpenAI-family compatibility.""" + cfg = _make_cfg(provider="OPENAI", custom_provider="my_proxy") + llm = _FakeLLM() + + apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42) + + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (d) ChatLiteLLMRouter — universal injection points only +# --------------------------------------------------------------------------- + + +def test_router_llm_gets_only_universal_injection_points() -> None: + """Even with an OpenAI-flavoured config, a ``ChatLiteLLMRouter`` must + receive only the universal injection points — its requests dispatch + across provider deployments and OpenAI-only kwargs would be wasted + (or stripped by ``drop_params``) on non-OpenAI legs.""" + router = ChatLiteLLMRouter() + cfg = _make_cfg(provider="OPENAI") + + apply_litellm_prompt_caching(router, agent_config=cfg, thread_id=42) + + assert 
"cache_control_injection_points" in router.model_kwargs + assert "prompt_cache_key" not in router.model_kwargs + assert "prompt_cache_retention" not in router.model_kwargs + + +# --------------------------------------------------------------------------- +# (e) Defensive paths +# --------------------------------------------------------------------------- + + +def test_handles_llm_with_no_writable_model_kwargs() -> None: + """Some LLM implementations (e.g. fakes / minimal subclasses) don't + expose a writable ``model_kwargs``. The helper must skip silently — + raising would crash the entire LLM build path on a non-critical + optimisation.""" + + class _ImmutableLLM: + # ``__slots__`` blocks attribute creation, so ``setattr`` raises. + __slots__ = ("model",) + + def __init__(self) -> None: + self.model = "openai/gpt-4o" + + llm = _ImmutableLLM() + + apply_litellm_prompt_caching(llm) + + +def test_initialises_missing_model_kwargs_dict() -> None: + """When ``model_kwargs`` is present-but-None (Pydantic v2 default + pattern when no factory is set), the helper initialises it to an + empty dict before mutating.""" + + class _LazyLLM: + def __init__(self) -> None: + self.model = "openai/gpt-4o" + self.model_kwargs: dict[str, Any] | None = None + + llm = _LazyLLM() + + apply_litellm_prompt_caching(llm) + + assert isinstance(llm.model_kwargs, dict) + assert "cache_control_injection_points" in llm.model_kwargs + + +def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None: + """Direct caller path (e.g. ``create_chat_litellm_from_config`` for + YAML configs without a structured ``AgentConfig``): without + ``agent_config`` the helper sets only the universal injection points + — no OpenAI-family extras even if the prefix says ``openai/``. 
+ Conservative: we'd rather miss the speedup than silently misroute.""" + llm = _FakeLLM(model="openai/gpt-4o") + + apply_litellm_prompt_caching(llm, agent_config=None, thread_id=99) + + assert "cache_control_injection_points" in llm.model_kwargs + assert "prompt_cache_key" not in llm.model_kwargs + assert "prompt_cache_retention" not in llm.model_kwargs + + +# --------------------------------------------------------------------------- +# (f) drop_params safety net (regression guard for #19346) +# --------------------------------------------------------------------------- + + +def test_litellm_drop_params_is_globally_enabled() -> None: + """``litellm.drop_params=True`` is set globally in + :mod:`app.services.llm_service` so any ``prompt_cache_key`` / + ``prompt_cache_retention`` we set on an OpenAI-family config is + auto-stripped if the request later routes to a non-supporting + provider (e.g. via auto-mode router fallback). This test pins that + invariant — losing it would mean Bedrock/Vertex 400s on ``prompt_cache_key``. + """ + import litellm + + import app.services.llm_service # noqa: F401 (side-effect: sets globals) + + assert litellm.drop_params is True + + +# --------------------------------------------------------------------------- +# Regression note: LiteLLM #15696 (multi-content-block last message) +# --------------------------------------------------------------------------- +# +# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]`` +# would get ``cache_control`` applied to *every* content block instead +# of only the last one — wasting cache breakpoints and triggering 400s +# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in +# https://github.com/BerriAI/litellm/pull/15699. +# +# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix). 
+# An end-to-end behavioural test would need to run ``litellm.completion`` +# through the Anthropic transformer, which is integration territory and +# better covered by LiteLLM's own test suite. The unit guard here is the +# version pin plus the build-time ``model_kwargs`` shape we verify above. diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py new file mode 100644 index 000000000..ffe3dbaa4 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_resolve_prompt_model_name.py @@ -0,0 +1,117 @@ +"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`. + +The helper picks the model id fed to ``detect_provider_variant`` so the +right ``<provider_hints>`` block lands in the system prompt. The tests +below pin its preference order: + +1. ``agent_config.litellm_params["base_model"]`` (Azure-correct). +2. ``agent_config.model_name``. +3. ``getattr(llm, "model", None)``. + +Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would +silently miss every provider regex. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name +from app.agents.new_chat.llm_config import AgentConfig + +pytestmark = pytest.mark.unit + + +def _make_cfg(**overrides) -> AgentConfig: + """Build an ``AgentConfig`` with sensible defaults for the helper test.""" + defaults = { + "provider": "OPENAI", + "model_name": "x", + "api_key": "k", + } + return AgentConfig(**{**defaults, **overrides}) + + +class _FakeLLM: + """Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance. + + The resolver only reads the ``.model`` attribute via ``getattr``, + matching the established idiom in ``knowledge_search.py`` / + ``stream_new_chat.py`` / ``document_summarizer.py``. 
+ """ + + def __init__(self, model: str | None) -> None: + self.model = model + + +def test_prefers_litellm_params_base_model_over_deployment_name() -> None: + """Azure deployment slug must NOT shadow the underlying model family. + + This is the failure mode the helper exists to prevent: a deployment + named ``"azure/prod-chat-001"`` would not match any provider regex + on its own, but the family ``"gpt-4o"`` lives in + ``litellm_params["base_model"]`` and routes to ``openai_classic``. + """ + cfg = _make_cfg( + model_name="azure/prod-chat-001", + litellm_params={"base_model": "gpt-4o"}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o" + + +def test_falls_back_to_model_name_when_litellm_params_is_none() -> None: + cfg = _make_cfg( + model_name="anthropic/claude-3-5-sonnet", + litellm_params=None, + ) + got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet")) + assert got == "anthropic/claude-3-5-sonnet" + + +def test_handles_litellm_params_without_base_model_key() -> None: + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"temperature": 0.5}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_ignores_blank_base_model() -> None: + """Whitespace-only ``base_model`` must not shadow ``model_name``.""" + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"base_model": " "}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_ignores_non_string_base_model() -> None: + """Defensive: a non-string ``base_model`` should not crash the resolver.""" + cfg = _make_cfg( + model_name="openai/gpt-4o", + litellm_params={"base_model": 42}, + ) + assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o" + + +def test_falls_back_to_llm_model_when_no_agent_config() -> None: + """No ``agent_config`` -> use ``llm.model`` directly. 
Defensive path + for direct callers; production callers always supply a config.""" + assert ( + _resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini")) + == "openai/gpt-4o-mini" + ) + + +def test_returns_none_when_nothing_available() -> None: + """``compose_system_prompt`` treats ``None`` as the ``"default"`` + variant and emits no provider block.""" + assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None + + +def test_auto_mode_resolves_to_auto_string() -> None: + """Auto mode -> ``"auto"``. ``detect_provider_variant("auto")`` + returns ``"default"``, which is correct: the child model isn't + known until the LiteLLM Router dispatches.""" + cfg = AgentConfig.from_auto_mode() + assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto" diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 5e6ad6abd..5935d73ae 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -372,3 +372,39 @@ def test_turn_status_sse_contract_exists(): assert 'type: "data-turn-status"' in state_source assert 'case "data-turn-status":' in pipeline_source assert "end_turn(str(chat_id))" in stream_source + + +def test_chat_deepagent_forwards_resolved_model_name_to_both_builders(): + """Regression guard: both system-prompt builders in chat_deepagent.py + must receive ``model_name=_resolve_prompt_model_name(...)`` so the + provider-variant dispatch can render the right ``<provider_hints>`` + block. Without this the prompt silently falls back to the empty + ``"default"`` variant — the original bug being fixed. 
+ + This test mirrors :func:`test_stream_error_emission_keeps_machine_error_codes` + in style: it inspects module source text + a regex to enforce the + call-site shape, not just the wrapper layer (the wrappers already + forward ``model_name`` correctly, so testing them would not catch + the actual missed plumbing). + """ + import app.agents.new_chat.chat_deepagent as chat_deepagent_module + + source = inspect.getsource(chat_deepagent_module) + + # Helper itself must be defined. + assert "def _resolve_prompt_model_name(" in source + + # Both builder calls must forward the resolved model name. Match + # across newlines + whitespace because the kwargs are split over + # multiple lines. + pattern = re.compile( + r"build_(?:surfsense|configurable)_system_prompt\([^)]*" + r"model_name=_resolve_prompt_model_name\(", + re.DOTALL, + ) + matches = pattern.findall(source) + assert len(matches) == 2, ( + "Expected both system-prompt builder call sites to forward " + "`model_name=_resolve_prompt_model_name(...)`, found " + f"{len(matches)}" + ) diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 209c42a9c..efe670d05 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -62,7 +62,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.13.5" +version = "3.13.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, @@ -73,76 +73,76 @@ dependencies = [ { name = "propcache" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/77/9a/152096d4808df8e4268befa55fba462f440f14beab85e8ad9bf990516918/aiohttp-3.13.5.tar.gz", hash = "sha256:9d98cc980ecc96be6eb4c1994ce35d28d8b1f5e5208a23b421187d1209dbb7d1", size = 7858271 } +sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748 } wheels = [ 
- { url = "https://files.pythonhosted.org/packages/be/6f/353954c29e7dcce7cf00280a02c75f30e133c00793c7a2ed3776d7b2f426/aiohttp-3.13.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:023ecba036ddd840b0b19bf195bfae970083fd7024ce1ac22e9bba90464620e9", size = 748876 }, - { url = "https://files.pythonhosted.org/packages/f5/1b/428a7c64687b3b2e9cd293186695affc0e1e54a445d0361743b231f11066/aiohttp-3.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15c933ad7920b7d9a20de151efcd05a6e38302cbf0e10c9b2acb9a42210a2416", size = 499557 }, - { url = "https://files.pythonhosted.org/packages/29/47/7be41556bfbb6917069d6a6634bb7dd5e163ba445b783a90d40f5ac7e3a7/aiohttp-3.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ab2899f9fa2f9f741896ebb6fa07c4c883bfa5c7f2ddd8cf2aafa86fa981b2d2", size = 500258 }, - { url = "https://files.pythonhosted.org/packages/67/84/c9ecc5828cb0b3695856c07c0a6817a99d51e2473400f705275a2b3d9239/aiohttp-3.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60eaa2d440cd4707696b52e40ed3e2b0f73f65be07fd0ef23b6b539c9c0b0b4", size = 1749199 }, - { url = "https://files.pythonhosted.org/packages/f0/d3/3c6d610e66b495657622edb6ae7c7fd31b2e9086b4ec50b47897ad6042a9/aiohttp-3.13.5-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55b3bdd3292283295774ab585160c4004f4f2f203946997f49aac032c84649e9", size = 1721013 }, - { url = "https://files.pythonhosted.org/packages/49/a0/24409c12217456df0bae7babe3b014e460b0b38a8e60753d6cb339f6556d/aiohttp-3.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2b2355dc094e5f7d45a7bb262fe7207aa0460b37a0d87027dcf21b5d890e7d5", size = 1781501 }, - { url = "https://files.pythonhosted.org/packages/98/9d/b65ec649adc5bccc008b0957a9a9c691070aeac4e41cea18559fef49958b/aiohttp-3.13.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:b38765950832f7d728297689ad78f5f2cf79ff82487131c4d26fe6ceecdc5f8e", size = 1878981 }, - { url = "https://files.pythonhosted.org/packages/57/d8/8d44036d7eb7b6a8ec4c5494ea0c8c8b94fbc0ed3991c1a7adf230df03bf/aiohttp-3.13.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b18f31b80d5a33661e08c89e202edabf1986e9b49c42b4504371daeaa11b47c1", size = 1767934 }, - { url = "https://files.pythonhosted.org/packages/31/04/d3f8211f273356f158e3464e9e45484d3fb8c4ce5eb2f6fe9405c3273983/aiohttp-3.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:33add2463dde55c4f2d9635c6ab33ce154e5ecf322bd26d09af95c5f81cfa286", size = 1566671 }, - { url = "https://files.pythonhosted.org/packages/41/db/073e4ebe00b78e2dfcacff734291651729a62953b48933d765dc513bf798/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:327cc432fdf1356fb4fbc6fe833ad4e9f6aacb71a8acaa5f1855e4b25910e4a9", size = 1705219 }, - { url = "https://files.pythonhosted.org/packages/48/45/7dfba71a2f9fd97b15c95c06819de7eb38113d2cdb6319669195a7d64270/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7c35b0bf0b48a70b4cb4fc5d7bed9b932532728e124874355de1a0af8ec4bc88", size = 1743049 }, - { url = "https://files.pythonhosted.org/packages/18/71/901db0061e0f717d226386a7f471bb59b19566f2cae5f0d93874b017271f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:df23d57718f24badef8656c49743e11a89fd6f5358fa8a7b96e728fda2abf7d3", size = 1749557 }, - { url = "https://files.pythonhosted.org/packages/08/d5/41eebd16066e59cd43728fe74bce953d7402f2b4ddfdfef2c0e9f17ca274/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:02e048037a6501a5ec1f6fc9736135aec6eb8a004ce48838cb951c515f32c80b", size = 1558931 }, - { url = "https://files.pythonhosted.org/packages/30/e6/4a799798bf05740e66c3a1161079bda7a3dd8e22ca392481d7a7f9af82a6/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = 
"sha256:31cebae8b26f8a615d2b546fee45d5ffb76852ae6450e2a03f42c9102260d6fe", size = 1774125 }, - { url = "https://files.pythonhosted.org/packages/84/63/7749337c90f92bc2cb18f9560d67aa6258c7060d1397d21529b8004fcf6f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:888e78eb5ca55a615d285c3c09a7a91b42e9dd6fc699b166ebd5dee87c9ccf14", size = 1732427 }, - { url = "https://files.pythonhosted.org/packages/98/de/cf2f44ff98d307e72fb97d5f5bbae3bfcb442f0ea9790c0bf5c5c2331404/aiohttp-3.13.5-cp312-cp312-win32.whl", hash = "sha256:8bd3ec6376e68a41f9f95f5ed170e2fcf22d4eb27a1f8cb361d0508f6e0557f3", size = 433534 }, - { url = "https://files.pythonhosted.org/packages/aa/ca/eadf6f9c8fa5e31d40993e3db153fb5ed0b11008ad5d9de98a95045bed84/aiohttp-3.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:110e448e02c729bcebb18c60b9214a87ba33bac4a9fa5e9a5f139938b56c6cb1", size = 460446 }, - { url = "https://files.pythonhosted.org/packages/78/e9/d76bf503005709e390122d34e15256b88f7008e246c4bdbe915cd4f1adce/aiohttp-3.13.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5029cc80718bbd545123cd8fe5d15025eccaaaace5d0eeec6bd556ad6163d61", size = 742930 }, - { url = "https://files.pythonhosted.org/packages/57/00/4b7b70223deaebd9bb85984d01a764b0d7bd6526fcdc73cca83bcbe7243e/aiohttp-3.13.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4bb6bf5811620003614076bdc807ef3b5e38244f9d25ca5fe888eaccea2a9832", size = 496927 }, - { url = "https://files.pythonhosted.org/packages/9c/f5/0fb20fb49f8efdcdce6cd8127604ad2c503e754a8f139f5e02b01626523f/aiohttp-3.13.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a84792f8631bf5a94e52d9cc881c0b824ab42717165a5579c760b830d9392ac9", size = 497141 }, - { url = "https://files.pythonhosted.org/packages/3b/86/b7c870053e36a94e8951b803cb5b909bfbc9b90ca941527f5fcafbf6b0fa/aiohttp-3.13.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:57653eac22c6a4c13eb22ecf4d673d64a12f266e72785ab1c8b8e5940d0e8090", size 
= 1732476 }, - { url = "https://files.pythonhosted.org/packages/b5/e5/4e161f84f98d80c03a238671b4136e6530453d65262867d989bbe78244d0/aiohttp-3.13.5-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5e5f7debc7a57af53fdf5c5009f9391d9f4c12867049d509bf7bb164a6e295b", size = 1706507 }, - { url = "https://files.pythonhosted.org/packages/d4/56/ea11a9f01518bd5a2a2fcee869d248c4b8a0cfa0bb13401574fa31adf4d4/aiohttp-3.13.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c719f65bebcdf6716f10e9eff80d27567f7892d8988c06de12bbbd39307c6e3a", size = 1773465 }, - { url = "https://files.pythonhosted.org/packages/eb/40/333ca27fb74b0383f17c90570c748f7582501507307350a79d9f9f3c6eb1/aiohttp-3.13.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d97f93fdae594d886c5a866636397e2bcab146fd7a132fd6bb9ce182224452f8", size = 1873523 }, - { url = "https://files.pythonhosted.org/packages/f0/d2/e2f77eef1acb7111405433c707dc735e63f67a56e176e72e9e7a2cd3f493/aiohttp-3.13.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3df334e39d4c2f899a914f1dba283c1aadc311790733f705182998c6f7cae665", size = 1754113 }, - { url = "https://files.pythonhosted.org/packages/fb/56/3f653d7f53c89669301ec9e42c95233e2a0c0a6dd051269e6e678db4fdb0/aiohttp-3.13.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe6970addfea9e5e081401bcbadf865d2b6da045472f58af08427e108d618540", size = 1562351 }, - { url = "https://files.pythonhosted.org/packages/ec/a6/9b3e91eb8ae791cce4ee736da02211c85c6f835f1bdfac0594a8a3b7018c/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7becdf835feff2f4f335d7477f121af787e3504b48b449ff737afb35869ba7bb", size = 1693205 }, - { url = 
"https://files.pythonhosted.org/packages/98/fc/bfb437a99a2fcebd6b6eaec609571954de2ed424f01c352f4b5504371dd3/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:676e5651705ad5d8a70aeb8eb6936c436d8ebbd56e63436cb7dd9bb36d2a9a46", size = 1730618 }, - { url = "https://files.pythonhosted.org/packages/e4/b6/c8534862126191a034f68153194c389addc285a0f1347d85096d349bbc15/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:9b16c653d38eb1a611cc898c41e76859ca27f119d25b53c12875fd0474ae31a8", size = 1745185 }, - { url = "https://files.pythonhosted.org/packages/0b/93/4ca8ee2ef5236e2707e0fd5fecb10ce214aee1ff4ab307af9c558bda3b37/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:999802d5fa0389f58decd24b537c54aa63c01c3219ce17d1214cbda3c2b22d2d", size = 1557311 }, - { url = "https://files.pythonhosted.org/packages/57/ae/76177b15f18c5f5d094f19901d284025db28eccc5ae374d1d254181d33f4/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:ec707059ee75732b1ba130ed5f9580fe10ff75180c812bc267ded039db5128c6", size = 1773147 }, - { url = "https://files.pythonhosted.org/packages/01/a4/62f05a0a98d88af59d93b7fcac564e5f18f513cb7471696ac286db970d6a/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d6d44a5b48132053c2f6cd5c8cb14bc67e99a63594e336b0f2af81e94d5530c", size = 1730356 }, - { url = "https://files.pythonhosted.org/packages/e4/85/fc8601f59dfa8c9523808281f2da571f8b4699685f9809a228adcc90838d/aiohttp-3.13.5-cp313-cp313-win32.whl", hash = "sha256:329f292ed14d38a6c4c435e465f48bebb47479fd676a0411936cc371643225cc", size = 432637 }, - { url = "https://files.pythonhosted.org/packages/c0/1b/ac685a8882896acf0f6b31d689e3792199cfe7aba37969fa91da63a7fa27/aiohttp-3.13.5-cp313-cp313-win_amd64.whl", hash = "sha256:69f571de7500e0557801c0b51f4780482c0ec5fe2ac851af5a92cfce1af1cb83", size = 458896 }, - { url = 
"https://files.pythonhosted.org/packages/5d/ce/46572759afc859e867a5bc8ec3487315869013f59281ce61764f76d879de/aiohttp-3.13.5-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:eb4639f32fd4a9904ab8fb45bf3383ba71137f3d9d4ba25b3b3f3109977c5b8c", size = 745721 }, - { url = "https://files.pythonhosted.org/packages/13/fe/8a2efd7626dbe6049b2ef8ace18ffda8a4dfcbe1bcff3ac30c0c7575c20b/aiohttp-3.13.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:7e5dc4311bd5ac493886c63cbf76ab579dbe4641268e7c74e48e774c74b6f2be", size = 497663 }, - { url = "https://files.pythonhosted.org/packages/9b/91/cc8cc78a111826c54743d88651e1687008133c37e5ee615fee9b57990fac/aiohttp-3.13.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:756c3c304d394977519824449600adaf2be0ccee76d206ee339c5e76b70ded25", size = 499094 }, - { url = "https://files.pythonhosted.org/packages/0a/33/a8362cb15cf16a3af7e86ed11962d5cd7d59b449202dc576cdc731310bde/aiohttp-3.13.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecc26751323224cf8186efcf7fbcbc30f4e1d8c7970659daf25ad995e4032a56", size = 1726701 }, - { url = "https://files.pythonhosted.org/packages/45/0c/c091ac5c3a17114bd76cbf85d674650969ddf93387876cf67f754204bd77/aiohttp-3.13.5-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10a75acfcf794edf9d8db50e5a7ec5fc818b2a8d3f591ce93bc7b1210df016d2", size = 1683360 }, - { url = "https://files.pythonhosted.org/packages/23/73/bcee1c2b79bc275e964d1446c55c54441a461938e70267c86afaae6fba27/aiohttp-3.13.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0f7a18f258d124cd678c5fe072fe4432a4d5232b0657fca7c1847f599233c83a", size = 1773023 }, - { url = "https://files.pythonhosted.org/packages/c7/ef/720e639df03004fee2d869f771799d8c23046dec47d5b81e396c7cda583a/aiohttp-3.13.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:df6104c009713d3a89621096f3e3e88cc323fd269dbd7c20afe18535094320be", size = 1853795 }, - { url = "https://files.pythonhosted.org/packages/bd/c9/989f4034fb46841208de7aeeac2c6d8300745ab4f28c42f629ba77c2d916/aiohttp-3.13.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:241a94f7de7c0c3b616627aaad530fe2cb620084a8b144d3be7b6ecfe95bae3b", size = 1730405 }, - { url = "https://files.pythonhosted.org/packages/ce/75/ee1fd286ca7dc599d824b5651dad7b3be7ff8d9a7e7b3fe9820d9180f7db/aiohttp-3.13.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c974fb66180e58709b6fc402846f13791240d180b74de81d23913abe48e96d94", size = 1558082 }, - { url = "https://files.pythonhosted.org/packages/c3/20/1e9e6650dfc436340116b7aa89ff8cb2bbdf0abc11dfaceaad8f74273a10/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:6e27ea05d184afac78aabbac667450c75e54e35f62238d44463131bd3f96753d", size = 1692346 }, - { url = "https://files.pythonhosted.org/packages/d8/40/8ebc6658d48ea630ac7903912fe0dd4e262f0e16825aa4c833c56c9f1f56/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:a79a6d399cef33a11b6f004c67bb07741d91f2be01b8d712d52c75711b1e07c7", size = 1698891 }, - { url = "https://files.pythonhosted.org/packages/d8/78/ea0ae5ec8ba7a5c10bdd6e318f1ba5e76fcde17db8275188772afc7917a4/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c632ce9c0b534fbe25b52c974515ed674937c5b99f549a92127c85f771a78772", size = 1742113 }, - { url = "https://files.pythonhosted.org/packages/8a/66/9d308ed71e3f2491be1acb8769d96c6f0c47d92099f3bc9119cada27b357/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:fceedde51fbd67ee2bcc8c0b33d0126cc8b51ef3bbde2f86662bd6d5a6f10ec5", size = 1553088 }, - { url = "https://files.pythonhosted.org/packages/da/a6/6cc25ed8dfc6e00c90f5c6d126a98e2cf28957ad06fa1036bd34b6f24a2c/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = 
"sha256:f92995dfec9420bb69ae629abf422e516923ba79ba4403bc750d94fb4a6c68c1", size = 1757976 }, - { url = "https://files.pythonhosted.org/packages/c1/2b/cce5b0ffe0de99c83e5e36d8f828e4161e415660a9f3e58339d07cce3006/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20ae0ff08b1f2c8788d6fb85afcb798654ae6ba0b747575f8562de738078457b", size = 1712444 }, - { url = "https://files.pythonhosted.org/packages/6c/cf/9e1795b4160c58d29421eafd1a69c6ce351e2f7c8d3c6b7e4ca44aea1a5b/aiohttp-3.13.5-cp314-cp314-win32.whl", hash = "sha256:b20df693de16f42b2472a9c485e1c948ee55524786a0a34345511afdd22246f3", size = 438128 }, - { url = "https://files.pythonhosted.org/packages/22/4d/eaedff67fc805aeba4ba746aec891b4b24cebb1a7d078084b6300f79d063/aiohttp-3.13.5-cp314-cp314-win_amd64.whl", hash = "sha256:f85c6f327bf0b8c29da7d93b1cabb6363fb5e4e160a32fa241ed2dce21b73162", size = 464029 }, - { url = "https://files.pythonhosted.org/packages/79/11/c27d9332ee20d68dd164dc12a6ecdef2e2e35ecc97ed6cf0d2442844624b/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:1efb06900858bb618ff5cee184ae2de5828896c448403d51fb633f09e109be0a", size = 778758 }, - { url = "https://files.pythonhosted.org/packages/04/fb/377aead2e0a3ba5f09b7624f702a964bdf4f08b5b6728a9799830c80041e/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:fee86b7c4bd29bdaf0d53d14739b08a106fdda809ca5fe032a15f52fae5fe254", size = 512883 }, - { url = "https://files.pythonhosted.org/packages/bb/a6/aa109a33671f7a5d3bd78b46da9d852797c5e665bfda7d6b373f56bff2ec/aiohttp-3.13.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:20058e23909b9e65f9da62b396b77dfa95965cbe840f8def6e572538b1d32e36", size = 516668 }, - { url = "https://files.pythonhosted.org/packages/79/b3/ca078f9f2fa9563c36fb8ef89053ea2bb146d6f792c5104574d49d8acb63/aiohttp-3.13.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cf20a8d6868cb15a73cab329ffc07291ba8c22b1b88176026106ae39aa6df0f", 
size = 1883461 }, - { url = "https://files.pythonhosted.org/packages/b7/e3/a7ad633ca1ca497b852233a3cce6906a56c3225fb6d9217b5e5e60b7419d/aiohttp-3.13.5-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:330f5da04c987f1d5bdb8ae189137c77139f36bd1cb23779ca1a354a4b027800", size = 1747661 }, - { url = "https://files.pythonhosted.org/packages/33/b9/cd6fe579bed34a906d3d783fe60f2fa297ef55b27bb4538438ee49d4dc41/aiohttp-3.13.5-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f1cbf0c7926d315c3c26c2da41fd2b5d2fe01ac0e157b78caefc51a782196cf", size = 1863800 }, - { url = "https://files.pythonhosted.org/packages/c0/3f/2c1e2f5144cefa889c8afd5cf431994c32f3b29da9961698ff4e3811b79a/aiohttp-3.13.5-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:53fc049ed6390d05423ba33103ded7281fe897cf97878f369a527070bd95795b", size = 1958382 }, - { url = "https://files.pythonhosted.org/packages/66/1d/f31ec3f1013723b3babe3609e7f119c2c2fb6ef33da90061a705ef3e1bc8/aiohttp-3.13.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:898703aa2667e3c5ca4c54ca36cd73f58b7a38ef87a5606414799ebce4d3fd3a", size = 1803724 }, - { url = "https://files.pythonhosted.org/packages/0e/b4/57712dfc6f1542f067daa81eb61da282fab3e6f1966fca25db06c4fc62d5/aiohttp-3.13.5-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0494a01ca9584eea1e5fbd6d748e61ecff218c51b576ee1999c23db7066417d8", size = 1640027 }, - { url = "https://files.pythonhosted.org/packages/25/3c/734c878fb43ec083d8e31bf029daae1beafeae582d1b35da234739e82ee7/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6cf81fe010b8c17b09495cbd15c1d35afbc8fb405c0c9cf4738e5ae3af1d65be", size = 1806644 }, - { url = 
"https://files.pythonhosted.org/packages/20/a5/f671e5cbec1c21d044ff3078223f949748f3a7f86b14e34a365d74a5d21f/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:c564dd5f09ddc9d8f2c2d0a301cd30a79a2cc1b46dd1a73bef8f0038863d016b", size = 1791630 }, - { url = "https://files.pythonhosted.org/packages/0b/63/fb8d0ad63a0b8a99be97deac8c04dacf0785721c158bdf23d679a87aa99e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:2994be9f6e51046c4f864598fd9abeb4fba6e88f0b2152422c9666dcd4aea9c6", size = 1809403 }, - { url = "https://files.pythonhosted.org/packages/59/0c/bfed7f30662fcf12206481c2aac57dedee43fe1c49275e85b3a1e1742294/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:157826e2fa245d2ef46c83ea8a5faf77ca19355d278d425c29fda0beb3318037", size = 1634924 }, - { url = "https://files.pythonhosted.org/packages/17/d6/fd518d668a09fd5a3319ae5e984d4d80b9a4b3df4e21c52f02251ef5a32e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:a8aca50daa9493e9e13c0f566201a9006f080e7c50e5e90d0b06f53146a54500", size = 1836119 }, - { url = "https://files.pythonhosted.org/packages/78/b7/15fb7a9d52e112a25b621c67b69c167805cb1f2ab8f1708a5c490d1b52fe/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3b13560160d07e047a93f23aaa30718606493036253d5430887514715b67c9d9", size = 1772072 }, - { url = "https://files.pythonhosted.org/packages/7e/df/57ba7f0c4a553fc2bd8b6321df236870ec6fd64a2a473a8a13d4f733214e/aiohttp-3.13.5-cp314-cp314t-win32.whl", hash = "sha256:9a0f4474b6ea6818b41f82172d799e4b3d29e22c2c520ce4357856fced9af2f8", size = 471819 }, - { url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441 }, + { url = 
"https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158 }, + { url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037 }, + { url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556 }, + { url = "https://files.pythonhosted.org/packages/d6/10/88ff67cd48a6ec36335b63a640abe86135791544863e0cfe1f065d6cef7a/aiohttp-3.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d738ebab9f71ee652d9dbd0211057690022201b11197f9a7324fd4dba128aa97", size = 1757314 }, + { url = "https://files.pythonhosted.org/packages/8b/15/fdb90a5cf5a1f52845c276e76298c75fbbcc0ac2b4a86551906d54529965/aiohttp-3.13.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0ce692c3468fa831af7dceed52edf51ac348cebfc8d3feb935927b63bd3e8576", size = 1731819 }, + { url = "https://files.pythonhosted.org/packages/ec/df/28146785a007f7820416be05d4f28cc207493efd1e8c6c1068e9bdc29198/aiohttp-3.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8e08abcfe752a454d2cb89ff0c08f2d1ecd057ae3e8cc6d84638de853530ebab", size = 1793279 }, + { url = "https://files.pythonhosted.org/packages/10/47/689c743abf62ea7a77774d5722f220e2c912a77d65d368b884d9779ef41b/aiohttp-3.13.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:5977f701b3fff36367a11087f30ea73c212e686d41cd363c50c022d48b011d8d", size = 1891082 }, + { url = "https://files.pythonhosted.org/packages/b0/b6/f7f4f318c7e58c23b761c9b13b9a3c9b394e0f9d5d76fbc6622fa98509f6/aiohttp-3.13.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54203e10405c06f8b6020bd1e076ae0fe6c194adcee12a5a78af3ffa3c57025e", size = 1773938 }, + { url = "https://files.pythonhosted.org/packages/aa/06/f207cb3121852c989586a6fc16ff854c4fcc8651b86c5d3bd1fc83057650/aiohttp-3.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:358a6af0145bc4dda037f13167bef3cce54b132087acc4c295c739d05d16b1c3", size = 1579548 }, + { url = "https://files.pythonhosted.org/packages/6c/58/e1289661a32161e24c1fe479711d783067210d266842523752869cc1d9c2/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:898ea1850656d7d61832ef06aa9846ab3ddb1621b74f46de78fbc5e1a586ba83", size = 1714669 }, + { url = "https://files.pythonhosted.org/packages/96/0a/3e86d039438a74a86e6a948a9119b22540bae037d6ba317a042ae3c22711/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7bc30cceb710cf6a44e9617e43eebb6e3e43ad855a34da7b4b6a73537d8a6763", size = 1754175 }, + { url = "https://files.pythonhosted.org/packages/f4/30/e717fc5df83133ba467a560b6d8ef20197037b4bb5d7075b90037de1018e/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4a31c0c587a8a038f19a4c7e60654a6c899c9de9174593a13e7cc6e15ff271f9", size = 1762049 }, + { url = "https://files.pythonhosted.org/packages/e4/28/8f7a2d4492e336e40005151bdd94baf344880a4707573378579f833a64c1/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:2062f675f3fe6e06d6113eb74a157fb9df58953ffed0cdb4182554b116545758", size = 1570861 }, + { url = "https://files.pythonhosted.org/packages/78/45/12e1a3d0645968b1c38de4b23fdf270b8637735ea057d4f84482ff918ad9/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = 
"sha256:3d1ba8afb847ff80626d5e408c1fdc99f942acc877d0702fe137015903a220a9", size = 1790003 }, + { url = "https://files.pythonhosted.org/packages/eb/0f/60374e18d590de16dcb39d6ff62f39c096c1b958e6f37727b5870026ea30/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b08149419994cdd4d5eecf7fd4bc5986b5a9380285bcd01ab4c0d6bfca47b79d", size = 1737289 }, + { url = "https://files.pythonhosted.org/packages/02/bf/535e58d886cfbc40a8b0013c974afad24ef7632d645bca0b678b70033a60/aiohttp-3.13.4-cp312-cp312-win32.whl", hash = "sha256:fc432f6a2c4f720180959bc19aa37259651c1a4ed8af8afc84dd41c60f15f791", size = 434185 }, + { url = "https://files.pythonhosted.org/packages/1e/1a/d92e3325134ebfff6f4069f270d3aac770d63320bd1fcd0eca023e74d9a8/aiohttp-3.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:6148c9ae97a3e8bff9a1fc9c757fa164116f86c100468339730e717590a3fb77", size = 461285 }, + { url = "https://files.pythonhosted.org/packages/e3/ac/892f4162df9b115b4758d615f32ec63d00f3084c705ff5526630887b9b42/aiohttp-3.13.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:63dd5e5b1e43b8fb1e91b79b7ceba1feba588b317d1edff385084fcc7a0a4538", size = 745744 }, + { url = "https://files.pythonhosted.org/packages/97/a9/c5b87e4443a2f0ea88cb3000c93a8fdad1ee63bffc9ded8d8c8e0d66efc6/aiohttp-3.13.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:746ac3cc00b5baea424dacddea3ec2c2702f9590de27d837aa67004db1eebc6e", size = 498178 }, + { url = "https://files.pythonhosted.org/packages/94/42/07e1b543a61250783650df13da8ddcdc0d0a5538b2bd15cef6e042aefc61/aiohttp-3.13.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bda8f16ea99d6a6705e5946732e48487a448be874e54a4f73d514660ff7c05d3", size = 498331 }, + { url = "https://files.pythonhosted.org/packages/20/d6/492f46bf0328534124772d0cf58570acae5b286ea25006900650f69dae0e/aiohttp-3.13.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b061e7b5f840391e3f64d0ddf672973e45c4cfff7a0feea425ea24e51530fc2", size 
= 1744414 }, + { url = "https://files.pythonhosted.org/packages/e2/4d/e02627b2683f68051246215d2d62b2d2f249ff7a285e7a858dc47d6b6a14/aiohttp-3.13.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b252e8d5cd66184b570d0d010de742736e8a4fab22c58299772b0c5a466d4b21", size = 1719226 }, + { url = "https://files.pythonhosted.org/packages/7b/6c/5d0a3394dd2b9f9aeba6e1b6065d0439e4b75d41f1fb09a3ec010b43552b/aiohttp-3.13.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20af8aad61d1803ff11152a26146d8d81c266aa8c5aa9b4504432abb965c36a0", size = 1782110 }, + { url = "https://files.pythonhosted.org/packages/0d/2d/c20791e3437700a7441a7edfb59731150322424f5aadf635602d1d326101/aiohttp-3.13.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:13a5cc924b59859ad2adb1478e31f410a7ed46e92a2a619d6d1dd1a63c1a855e", size = 1884809 }, + { url = "https://files.pythonhosted.org/packages/c8/94/d99dbfbd1924a87ef643833932eb2a3d9e5eee87656efea7d78058539eff/aiohttp-3.13.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:534913dfb0a644d537aebb4123e7d466d94e3be5549205e6a31f72368980a81a", size = 1764938 }, + { url = "https://files.pythonhosted.org/packages/49/61/3ce326a1538781deb89f6cf5e094e2029cd308ed1e21b2ba2278b08426f6/aiohttp-3.13.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:320e40192a2dcc1cf4b5576936e9652981ab596bf81eb309535db7e2f5b5672f", size = 1570697 }, + { url = "https://files.pythonhosted.org/packages/b6/77/4ab5a546857bb3028fbaf34d6eea180267bdab022ee8b1168b1fcde4bfdd/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9e587fcfce2bcf06526a43cb705bdee21ac089096f2e271d75de9c339db3100c", size = 1702258 }, + { url = 
"https://files.pythonhosted.org/packages/79/63/d8f29021e39bc5af8e5d5e9da1b07976fb9846487a784e11e4f4eeda4666/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:9eb9c2eea7278206b5c6c1441fdd9dc420c278ead3f3b2cc87f9b693698cc500", size = 1740287 }, + { url = "https://files.pythonhosted.org/packages/55/3a/cbc6b3b124859a11bc8055d3682c26999b393531ef926754a3445b99dfef/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:29be00c51972b04bf9d5c8f2d7f7314f48f96070ca40a873a53056e652e805f7", size = 1753011 }, + { url = "https://files.pythonhosted.org/packages/e0/30/836278675205d58c1368b21520eab9572457cf19afd23759216c04483048/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:90c06228a6c3a7c9f776fe4fc0b7ff647fffd3bed93779a6913c804ae00c1073", size = 1566359 }, + { url = "https://files.pythonhosted.org/packages/50/b4/8032cc9b82d17e4277704ba30509eaccb39329dc18d6a35f05e424439e32/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:a533ec132f05fd9a1d959e7f34184cd7d5e8511584848dab85faefbaac573069", size = 1785537 }, + { url = "https://files.pythonhosted.org/packages/17/7d/5873e98230bde59f493bf1f7c3e327486a4b5653fa401144704df5d00211/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1c946f10f413836f82ea4cfb90200d2a59578c549f00857e03111cf45ad01ca5", size = 1740752 }, + { url = "https://files.pythonhosted.org/packages/7b/f2/13e46e0df051494d7d3c68b7f72d071f48c384c12716fc294f75d5b1a064/aiohttp-3.13.4-cp313-cp313-win32.whl", hash = "sha256:48708e2706106da6967eff5908c78ca3943f005ed6bcb75da2a7e4da94ef8c70", size = 433187 }, + { url = "https://files.pythonhosted.org/packages/ea/c0/649856ee655a843c8f8664592cfccb73ac80ede6a8c8db33a25d810c12db/aiohttp-3.13.4-cp313-cp313-win_amd64.whl", hash = "sha256:74a2eb058da44fa3a877a49e2095b591d4913308bb424c418b77beb160c55ce3", size = 459778 }, + { url = 
"https://files.pythonhosted.org/packages/6d/29/6657cc37ae04cacc2dbf53fb730a06b6091cc4cbe745028e047c53e6d840/aiohttp-3.13.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:e0a2c961fc92abeff61d6444f2ce6ad35bb982db9fc8ff8a47455beacf454a57", size = 749363 }, + { url = "https://files.pythonhosted.org/packages/90/7f/30ccdf67ca3d24b610067dc63d64dcb91e5d88e27667811640644aa4a85d/aiohttp-3.13.4-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:153274535985a0ff2bff1fb6c104ed547cec898a09213d21b0f791a44b14d933", size = 499317 }, + { url = "https://files.pythonhosted.org/packages/93/13/e372dd4e68ad04ee25dafb050c7f98b0d91ea643f7352757e87231102555/aiohttp-3.13.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:351f3171e2458da3d731ce83f9e6b9619e325c45cbd534c7759750cabf453ad7", size = 500477 }, + { url = "https://files.pythonhosted.org/packages/e5/fe/ee6298e8e586096fb6f5eddd31393d8544f33ae0792c71ecbb4c2bef98ac/aiohttp-3.13.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f989ac8bc5595ff761a5ccd32bdb0768a117f36dd1504b1c2c074ed5d3f4df9c", size = 1737227 }, + { url = "https://files.pythonhosted.org/packages/b0/b9/a7a0463a09e1a3fe35100f74324f23644bfc3383ac5fd5effe0722a5f0b7/aiohttp-3.13.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d36fc1709110ec1e87a229b201dd3ddc32aa01e98e7868083a794609b081c349", size = 1694036 }, + { url = "https://files.pythonhosted.org/packages/57/7c/8972ae3fb7be00a91aee6b644b2a6a909aedb2c425269a3bfd90115e6f8f/aiohttp-3.13.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:42adaeea83cbdf069ab94f5103ce0787c21fb1a0153270da76b59d5578302329", size = 1786814 }, + { url = "https://files.pythonhosted.org/packages/93/01/c81e97e85c774decbaf0d577de7d848934e8166a3a14ad9f8aa5be329d28/aiohttp-3.13.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:92deb95469928cc41fd4b42a95d8012fa6df93f6b1c0a83af0ffbc4a5e218cde", size = 1866676 }, + { url = "https://files.pythonhosted.org/packages/5a/5f/5b46fe8694a639ddea2cd035bf5729e4677ea882cb251396637e2ef1590d/aiohttp-3.13.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0c0c7c07c4257ef3a1df355f840bc62d133bcdef5c1c5ba75add3c08553e2eed", size = 1740842 }, + { url = "https://files.pythonhosted.org/packages/20/a2/0d4b03d011cca6b6b0acba8433193c1e484efa8d705ea58295590fe24203/aiohttp-3.13.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f062c45de8a1098cb137a1898819796a2491aec4e637a06b03f149315dff4d8f", size = 1566508 }, + { url = "https://files.pythonhosted.org/packages/98/17/e689fd500da52488ec5f889effd6404dece6a59de301e380f3c64f167beb/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:76093107c531517001114f0ebdb4f46858ce818590363e3e99a4a2280334454a", size = 1700569 }, + { url = "https://files.pythonhosted.org/packages/d8/0d/66402894dbcf470ef7db99449e436105ea862c24f7ea4c95c683e635af35/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:6f6ec32162d293b82f8b63a16edc80769662fbd5ae6fbd4936d3206a2c2cc63b", size = 1707407 }, + { url = "https://files.pythonhosted.org/packages/2f/eb/af0ab1a3650092cbd8e14ef29e4ab0209e1460e1c299996c3f8288b3f1ff/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:5903e2db3d202a00ad9f0ec35a122c005e85d90c9836ab4cda628f01edf425e2", size = 1752214 }, + { url = "https://files.pythonhosted.org/packages/5a/bf/72326f8a98e4c666f292f03c385545963cc65e358835d2a7375037a97b57/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2d5bea57be7aca98dbbac8da046d99b5557c5cf4e28538c4c786313078aca09e", size = 1562162 }, + { url = "https://files.pythonhosted.org/packages/67/9f/13b72435f99151dd9a5469c96b3b5f86aa29b7e785ca7f35cf5e538f74c0/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = 
"sha256:bcf0c9902085976edc0232b75006ef38f89686901249ce14226b6877f88464fb", size = 1768904 }, + { url = "https://files.pythonhosted.org/packages/18/bc/28d4970e7d5452ac7776cdb5431a1164a0d9cf8bd2fffd67b4fb463aa56d/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3295f98bfeed2e867cab588f2a146a9db37a85e3ae9062abf46ba062bd29165", size = 1723378 }, + { url = "https://files.pythonhosted.org/packages/53/74/b32458ca1a7f34d65bdee7aef2036adbe0438123d3d53e2b083c453c24dd/aiohttp-3.13.4-cp314-cp314-win32.whl", hash = "sha256:a598a5c5767e1369d8f5b08695cab1d8160040f796c4416af76fd773d229b3c9", size = 438711 }, + { url = "https://files.pythonhosted.org/packages/40/b2/54b487316c2df3e03a8f3435e9636f8a81a42a69d942164830d193beb56a/aiohttp-3.13.4-cp314-cp314-win_amd64.whl", hash = "sha256:c555db4bc7a264bead5a7d63d92d41a1122fcd39cc62a4db815f45ad46f9c2c8", size = 464977 }, + { url = "https://files.pythonhosted.org/packages/47/fb/e41b63c6ce71b07a59243bb8f3b457ee0c3402a619acb9d2c0d21ef0e647/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:45abbbf09a129825d13c18c7d3182fecd46d9da3cfc383756145394013604ac1", size = 781549 }, + { url = "https://files.pythonhosted.org/packages/97/53/532b8d28df1e17e44c4d9a9368b78dcb6bf0b51037522136eced13afa9e8/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:74c80b2bc2c2adb7b3d1941b2b60701ee2af8296fc8aad8b8bc48bc25767266c", size = 514383 }, + { url = "https://files.pythonhosted.org/packages/1b/1f/62e5d400603e8468cd635812d99cb81cfdc08127a3dc474c647615f31339/aiohttp-3.13.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c97989ae40a9746650fa196894f317dafc12227c808c774929dda0ff873a5954", size = 518304 }, + { url = "https://files.pythonhosted.org/packages/90/57/2326b37b10896447e3c6e0cbef4fe2486d30913639a5cfd1332b5d870f82/aiohttp-3.13.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dae86be9811493f9990ef44fff1685f5c1a3192e9061a71a109d527944eed551", 
size = 1893433 }, + { url = "https://files.pythonhosted.org/packages/d2/b4/a24d82112c304afdb650167ef2fe190957d81cbddac7460bedd245f765aa/aiohttp-3.13.4-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1db491abe852ca2fa6cc48a3341985b0174b3741838e1341b82ac82c8bd9e871", size = 1755901 }, + { url = "https://files.pythonhosted.org/packages/9e/2d/0883ef9d878d7846287f036c162a951968f22aabeef3ac97b0bea6f76d5d/aiohttp-3.13.4-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0e5d701c0aad02a7dce72eef6b93226cf3734330f1a31d69ebbf69f33b86666e", size = 1876093 }, + { url = "https://files.pythonhosted.org/packages/ad/52/9204bb59c014869b71971addad6778f005daa72a96eed652c496789d7468/aiohttp-3.13.4-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8ac32a189081ae0a10ba18993f10f338ec94341f0d5df8fff348043962f3c6f8", size = 1970815 }, + { url = "https://files.pythonhosted.org/packages/d6/b5/e4eb20275a866dde0f570f411b36c6b48f7b53edfe4f4071aa1b0728098a/aiohttp-3.13.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:98e968cdaba43e45c73c3f306fca418c8009a957733bac85937c9f9cf3f4de27", size = 1816223 }, + { url = "https://files.pythonhosted.org/packages/d8/23/e98075c5bb146aa61a1239ee1ac7714c85e814838d6cebbe37d3fe19214a/aiohttp-3.13.4-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca114790c9144c335d538852612d3e43ea0f075288f4849cf4b05d6cd2238ce7", size = 1649145 }, + { url = "https://files.pythonhosted.org/packages/d6/c1/7bad8be33bb06c2bb224b6468874346026092762cbec388c3bdb65a368ee/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ea2e071661ba9cfe11eabbc81ac5376eaeb3061f6e72ec4cc86d7cdd1ffbdbbb", size = 1816562 }, + { url = 
"https://files.pythonhosted.org/packages/5c/10/c00323348695e9a5e316825969c88463dcc24c7e9d443244b8a2c9cf2eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:34e89912b6c20e0fd80e07fa401fd218a410aa1ce9f1c2f1dad6db1bd0ce0927", size = 1800333 }, + { url = "https://files.pythonhosted.org/packages/84/43/9b2147a1df3559f49bd723e22905b46a46c068a53adb54abdca32c4de180/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:0e217cf9f6a42908c52b46e42c568bd57adc39c9286ced31aaace614b6087965", size = 1820617 }, + { url = "https://files.pythonhosted.org/packages/a9/7f/b3481a81e7a586d02e99387b18c6dafff41285f6efd3daa2124c01f87eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:0c296f1221e21ba979f5ac1964c3b78cfde15c5c5f855ffd2caab337e9cd9182", size = 1643417 }, + { url = "https://files.pythonhosted.org/packages/8f/72/07181226bc99ce1124e0f89280f5221a82d3ae6a6d9d1973ce429d48e52b/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:d99a9d168ebaffb74f36d011750e490085ac418f4db926cce3989c8fe6cb6b1b", size = 1849286 }, + { url = "https://files.pythonhosted.org/packages/1a/e6/1b3566e103eca6da5be4ae6713e112a053725c584e96574caf117568ffef/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cb19177205d93b881f3f89e6081593676043a6828f59c78c17a0fd6c1fbed2ba", size = 1782635 }, + { url = "https://files.pythonhosted.org/packages/37/58/1b11c71904b8d079eb0c39fe664180dd1e14bebe5608e235d8bfbadc8929/aiohttp-3.13.4-cp314-cp314t-win32.whl", hash = "sha256:c606aa5656dab6552e52ca368e43869c916338346bfaf6304e15c58fb113ea30", size = 472537 }, + { url = "https://files.pythonhosted.org/packages/bc/8f/87c56a1a1977d7dddea5b31e12189665a140fdb48a71e9038ff90bb564ec/aiohttp-3.13.4-cp314-cp314t-win_amd64.whl", hash = "sha256:014dcc10ec8ab8db681f0d68e939d1e9286a5aa2b993cbbdb0db130853e02144", size = 506381 }, ] [[package]] @@ -3723,7 +3723,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.83.4" +version = "1.83.14" source = 
{ registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3739,9 +3739,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/03/c4/30469c06ae7437a4406bc11e3c433cfd380a6771068cca15ea918dcd158f/litellm-1.83.4.tar.gz", hash = "sha256:6458d2030a41229460b321adee00517a91dbd8e63213cc953d355cb41d16f2d4", size = 17733899 } +sdist = { url = "https://files.pythonhosted.org/packages/8d/7c/c095649380adc96c8630273c1768c2ad1e74aa2ee1dd8dd05d218a60569f/litellm-1.83.14.tar.gz", hash = "sha256:24aef9b47cdc424c833e32f3727f411741c690832cd1fe4405e0077144fe09c9", size = 14836599 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/bd/df19d3f8f6654535ee343a341fd921f81c411abf601a53e3eaef58129b02/litellm-1.83.4-py3-none-any.whl", hash = "sha256:17d7b4d48d47aca988ea4f762ddda5e7bd72cda3270192b22813d0330869d7b4", size = 16015555 }, + { url = "https://files.pythonhosted.org/packages/7f/5c/1b5691575420135e90578543b2bf219497caa33cfd0af64cb38f30288450/litellm-1.83.14-py3-none-any.whl", hash = "sha256:92b11ba2a32cf80707ddf388d18526696c7999a21b418c5e3b6eda1243d2cfdb", size = 16457054 }, ] [[package]] @@ -5124,7 +5124,7 @@ wheels = [ [[package]] name = "openai" -version = "2.30.0" +version = "2.24.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -5136,9 +5136,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/88/15/52580c8fbc16d0675d516e8749806eda679b16de1e4434ea06fb6feaa610/openai-2.30.0.tar.gz", hash = "sha256:92f7661c990bda4b22a941806c83eabe4896c3094465030dd882a71abe80c885", size = 676084 } +sdist = { url = "https://files.pythonhosted.org/packages/55/13/17e87641b89b74552ed408a92b231283786523edddc95f3545809fab673c/openai-2.24.0.tar.gz", hash = "sha256:1e5769f540dbd01cb33bc4716a23e67b9d695161a734aff9c5f925e2bf99a673", size = 658717 } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/2a/9e/5bfa2270f902d5b92ab7d41ce0475b8630572e71e349b2a4996d14bdda93/openai-2.30.0-py3-none-any.whl", hash = "sha256:9a5ae616888eb2748ec5e0c5b955a51592e0b201a11f4262db920f2a78c5231d", size = 1146656 }, + { url = "https://files.pythonhosted.org/packages/c9/30/844dc675ee6902579b8eef01ed23917cc9319a1c9c0c14ec6e39340c96d0/openai-2.24.0-py3-none-any.whl", hash = "sha256:fed30480d7d6c884303287bde864980a4b137b60553ffbcf9ab4a233b7a73d94", size = 1120122 }, ] [[package]] @@ -6780,11 +6780,11 @@ wheels = [ [[package]] name = "python-dotenv" -version = "1.0.1" +version = "1.2.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135 } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, + { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101 }, ] [[package]] @@ -8070,7 +8070,7 @@ requires-dist = [ { name = "langgraph", specifier = ">=1.1.3" }, { name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" }, { name = "linkup-sdk", specifier = ">=0.2.4" }, - { name = "litellm", specifier = ">=1.83.4" }, + { name = "litellm", specifier = ">=1.83.7" }, 
{ name = "llama-cloud-services", specifier = ">=0.6.25" }, { name = "markdown", specifier = ">=3.7" }, { name = "markdownify", specifier = ">=0.14.1" }, diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 416fd8633..175cae4ab 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -17,7 +17,6 @@ const demoPlans = [ "Self Hostable", "500 pages included to start", "3 million premium tokens to start", - "Earn up to 3,000+ bonus pages for free", "Includes access to OpenAI text, audio and image models", "Realtime Collaborative Group Chats with teammates", "Community support on Discord", diff --git a/surfsense_web/components/settings/more-pages-content.tsx b/surfsense_web/components/settings/more-pages-content.tsx index 944f7418f..8de61b0c7 100644 --- a/surfsense_web/components/settings/more-pages-content.tsx +++ b/surfsense_web/components/settings/more-pages-content.tsx @@ -1,21 +1,14 @@ "use client"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; -import { Check, ExternalLink, Mail } from "lucide-react"; +import { Check, ExternalLink } from "lucide-react"; import Link from "next/link"; import { useParams } from "next/navigation"; -import { useEffect, useState } from "react"; +import { useEffect } from "react"; import { toast } from "sonner"; import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from "@/components/ui/dialog"; import { Separator } from "@/components/ui/separator"; import { Skeleton } from "@/components/ui/skeleton"; import { Spinner } from "@/components/ui/spinner"; @@ -33,7 +26,6 @@ export function MorePagesContent() { const params = useParams(); const queryClient = 
useQueryClient(); const searchSpaceId = params?.search_space_id ?? ""; - const [claimOpen, setClaimOpen] = useState(false); useEffect(() => { trackIncentivePageViewed(); @@ -79,35 +71,10 @@ export function MorePagesContent() { <div className="text-center"> <h2 className="text-xl font-bold tracking-tight">Get Free Pages</h2> <p className="mt-1 text-sm text-muted-foreground"> - Claim your free page offer and earn bonus pages + Earn bonus pages by completing tasks </p> </div> - {/* 3k free offer */} - <Card className="border-emerald-500/30 bg-emerald-500/5"> - <CardContent className="flex items-center gap-3 p-4"> - <div className="flex h-10 w-10 shrink-0 items-center justify-center rounded-full bg-emerald-600 text-white text-xs font-bold"> - 3k - </div> - <div className="min-w-0 flex-1"> - <p className="text-sm font-semibold">Claim 3,000 Free Pages</p> - <p className="text-xs text-muted-foreground"> - Limited offer. Schedule a meeting or email us to claim. - </p> - </div> - <Button - size="sm" - className="bg-emerald-600 text-white hover:bg-emerald-700" - onClick={() => setClaimOpen(true)} - > - Claim - </Button> - </CardContent> - </Card> - - <Separator /> - - {/* Free tasks */} <div className="space-y-2"> <h3 className="text-sm font-semibold">Earn Bonus Pages</h3> {isLoading ? ( @@ -182,7 +149,6 @@ export function MorePagesContent() { <Separator /> - {/* Link to buy pages */} <div className="text-center"> <p className="text-sm text-muted-foreground">Need more?</p> {pageBuyingEnabled ? ( @@ -197,25 +163,6 @@ export function MorePagesContent() { </p> )} </div> - - {/* Claim 3k dialog */} - <Dialog open={claimOpen} onOpenChange={setClaimOpen}> - <DialogContent className="sm:max-w-md"> - <DialogHeader> - <DialogTitle>Claim 3,000 Free Pages</DialogTitle> - <DialogDescription> - Send us an email to claim your free 3,000 pages. Include your account email and - primary usecase for free pages. 
- </DialogDescription> - </DialogHeader> - <Button asChild className="w-full gap-2"> - <a href="mailto:rohan@surfsense.com?subject=Claim%203%2C000%20Free%20Pages&body=Hi%2C%20I'd%20like%20to%20claim%20the%203%2C000%20free%20pages%20offer.%0A%0AMy%20account%20email%3A%20"> - <Mail className="h-4 w-4" /> - rohan@surfsense.com - </a> - </Button> - </DialogContent> - </Dialog> </div> ); } From 5dd45a5740156a96018ca560f5f0b91886879830 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 17:41:52 +0530 Subject: [PATCH 43/68] refactor(router): add router_pool_eligible filter and rebuild() API --- .../app/services/llm_router_service.py | 47 ++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 4bce79a43..d624ff56c 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -207,6 +207,12 @@ class LLMRouterService: """ Initialize the router with global LLM configurations. + Configs with ``router_pool_eligible=False`` are skipped so that + dynamic OpenRouter entries stay out of the shared router pool used + by title-gen / sub-agent ``model="auto"`` flows. Those dynamic + entries are still available for user-facing Auto-mode thread pinning + via ``auto_model_pin_service``. + Args: global_configs: List of global LLM config dictionaries from YAML router_settings: Optional router settings (routing_strategy, num_retries, etc.) 
@@ -220,6 +226,8 @@ class LLMRouterService: model_list = [] premium_models: set[str] = set() for config in global_configs: + if config.get("router_pool_eligible") is False: + continue deployment = cls._config_to_deployment(config) if deployment: model_list.append(deployment) @@ -308,10 +316,45 @@ class LLMRouterService: logger.error(f"Failed to initialize LLM Router: {e}") instance._router = None + @classmethod + def rebuild( + cls, + global_configs: list[dict], + router_settings: dict | None = None, + ) -> None: + """Reset the router and re-run ``initialize`` with fresh configs. + + ``initialize`` short-circuits once it has run to avoid re-creating the + LiteLLM Router on every request; ``rebuild`` deliberately clears + ``_initialized`` so a caller (e.g. background OpenRouter refresh) + can force the pool to be rebuilt after catalogue changes. + """ + instance = cls.get_instance() + instance._initialized = False + instance._router = None + instance._model_list = [] + instance._premium_model_strings = set() + cls.initialize(global_configs, router_settings) + @classmethod def is_premium_model(cls, model_string: str) -> bool: - """Return True if *model_string* (as reported by LiteLLM) belongs to a - premium-tier deployment in the router pool.""" + """Return True if *model_string* belongs to a premium-tier deployment + in the LiteLLM router pool. + + Scope: only covers configs with ``router_pool_eligible`` truthy. That + includes static YAML premium configs AND dynamic OpenRouter *premium* + entries (which opt in at generation time). Dynamic OpenRouter *free* + entries and the virtual ``openrouter/free`` router are deliberately + kept out of the router pool — OpenRouter enforces free-tier limits + globally per account, so per-deployment router accounting can't + represent them correctly — and therefore return ``False`` here, which + matches their ``billing_tier="free"`` (no premium quota). 
+ + For per-request premium checks on an arbitrary config (static or + dynamic, pool or non-pool), read ``agent_config.is_premium`` instead; + that reflects the per-config ``billing_tier`` directly and is what + user-facing Auto-mode thread pinning uses to bill correctly. + """ instance = cls.get_instance() return model_string in instance._premium_model_strings From ccd7caf99f14411dffe5067cd3171357ab690808 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 17:42:21 +0530 Subject: [PATCH 44/68] feat(openrouter): derive billing tier per-model and stabilize config IDs --- .../openrouter_integration_service.py | 191 ++++++++++++++++-- 1 file changed, 173 insertions(+), 18 deletions(-) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 1245f73aa..2d6a42337 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -11,6 +11,7 @@ this service only manages the catalogue, not the inference path. """ import asyncio +import hashlib import logging import threading from typing import Any @@ -25,6 +26,56 @@ OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" # dynamic OpenRouter entries from hand-written YAML entries during refresh. _OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__" +# Fixed negative ID for the virtual ``openrouter/free`` auto-select entry. +# Chosen to sit far below any reasonable ``id_offset`` so it never collides +# with per-model stable IDs. +_FREE_ROUTER_ID = -9_999_999 + +# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides +# enough headroom to avoid frequent collisions for OpenRouter's catalogue +# (~300 models) while keeping IDs comfortably within Postgres INTEGER range. 
+_STABLE_ID_HASH_WIDTH = 9_000_000 + + +def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int: + """Derive a deterministic negative config ID from ``model_id``. + + The same ``model_id`` always hashes to the same base value so thread pins + survive catalogue churn (models appearing/disappearing/reordering between + refreshes). On collision we decrement until we find an unused slot; this + keeps the mapping stable for the first config that claimed a slot and + only shifts collisions, which is much less disruptive than the legacy + index-based scheme that reshuffled every ID when the catalogue changed. + """ + digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest() + base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH) + cid = base + while cid in taken: + cid -= 1 + taken.add(cid) + return cid + + +def _openrouter_tier(model: dict) -> str: + """Classify an OpenRouter model as ``"free"`` or ``"premium"``. + + Per OpenRouter's API contract, a model is free if: + - Its id ends with ``:free`` (OpenRouter's own free-variant convention), or + - Both ``pricing.prompt`` and ``pricing.completion`` are zero strings. + + Anything else (missing pricing, non-zero pricing) falls through to + ``"premium"`` so we never under-charge users. This derivation runs off the + already-cached /api/v1/models payload, so it adds no network cost. 
+ """ + if model.get("id", "").endswith(":free"): + return "free" + pricing = model.get("pricing") or {} + prompt = str(pricing.get("prompt", "")).strip() + completion = str(pricing.get("completion", "")).strip() + if prompt == "0" and completion == "0": + return "free" + return "premium" + def _is_text_output_model(model: dict) -> bool: """Return True if the model produces text output only (skip image/audio generators).""" @@ -109,24 +160,77 @@ async def _fetch_models_async() -> list[dict] | None: return None +def _build_free_router_config(settings: dict[str, Any]) -> dict[str, Any]: + """Build the virtual ``openrouter/free`` auto-select config entry. + + This exposes OpenRouter's Free Models Router as a single selectable + option. LiteLLM forwards ``openrouter/openrouter/free`` and OpenRouter + picks a capable free model per request (availability varies, account-wide + rate limit is ~20 req/min). + """ + return { + "id": _FREE_ROUTER_ID, + "name": "OpenRouter Free (Auto-Select)", + "description": ( + "OpenRouter picks a capable free model per request. " + "~20 req/min account-wide; availability varies." 
+ ), + "provider": "OPENROUTER", + "model_name": "openrouter/free", + "api_key": settings.get("api_key", ""), + "api_base": "", + "billing_tier": "free", + "rpm": settings.get("free_rpm", 20), + "tpm": settings.get("free_tpm", 100_000), + "anonymous_enabled": settings.get("anonymous_enabled_free", False), + "seo_enabled": False, + "seo_slug": None, + "quota_reserve_tokens": settings.get("quota_reserve_tokens", 4000), + "litellm_params": dict(settings.get("litellm_params") or {}), + "system_instructions": settings.get("system_instructions", ""), + "use_default_system_instructions": settings.get( + "use_default_system_instructions", True + ), + "citations_enabled": settings.get("citations_enabled", True), + "router_pool_eligible": False, + _OPENROUTER_DYNAMIC_MARKER: True, + } + + def _generate_configs( raw_models: list[dict], settings: dict[str, Any], ) -> list[dict]: - """ - Convert raw OpenRouter model entries into global LLM config dicts. + """Convert raw OpenRouter model entries into global LLM config dicts. - Models are sorted by ID for deterministic, stable ID assignment across - restarts and refreshes. + Tier (``billing_tier``) is derived per-model from OpenRouter's own API + signals via ``_openrouter_tier`` — there is no longer a uniform YAML + override. Config IDs are derived via ``_stable_config_id`` so they + survive catalogue churn across refreshes. + + Router-pool membership is tier-aware: + + - Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``) + so sub-agent ``model="auto"`` flows benefit from load balancing and + failover across the curated YAML configs and the OR premium passthrough. + - Free OR models and the virtual ``openrouter/free`` entry stay excluded + (``router_pool_eligible=False``). 
LiteLLM Router tracks rate limits per + deployment, but OpenRouter enforces a single global free-tier quota + (~20 RPM + 50-1000 daily requests account-wide across every ``:free`` + model), so rotating across many free deployments would only burn the + shared bucket faster. Free OR models remain fully available for user- + facing Auto-mode thread pinning via ``auto_model_pin_service``. """ id_offset: int = settings.get("id_offset", -10000) api_key: str = settings.get("api_key", "") - billing_tier: str = settings.get("billing_tier", "premium") - anonymous_enabled: bool = settings.get("anonymous_enabled", False) seo_enabled: bool = settings.get("seo_enabled", False) quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) rpm: int = settings.get("rpm", 200) - tpm: int = settings.get("tpm", 1000000) + tpm: int = settings.get("tpm", 1_000_000) + free_rpm: int = settings.get("free_rpm", 20) + free_tpm: int = settings.get("free_tpm", 100_000) + anon_paid: bool = settings.get("anonymous_enabled_paid", False) + anon_free: bool = settings.get("anonymous_enabled_free", False) litellm_params: dict = settings.get("litellm_params") or {} system_instructions: str = settings.get("system_instructions", "") use_default: bool = settings.get("use_default_system_instructions", True) @@ -142,19 +246,27 @@ def _generate_configs( and _is_allowed_model(m) and "/" in m.get("id", "") ] - text_models.sort(key=lambda m: m["id"]) configs: list[dict] = [] - for idx, model in enumerate(text_models): + + if settings.get("free_router_enabled", True) and api_key: + configs.append(_build_free_router_config(settings)) + + taken: set[int] = set() + if configs: + taken.add(_FREE_ROUTER_ID) + + for model in text_models: model_id: str = model["id"] name: str = model.get("name", model_id) + tier = _openrouter_tier(model) cfg: dict[str, Any] = { - "id": id_offset - idx, + "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter", - 
"billing_tier": billing_tier, - "anonymous_enabled": anonymous_enabled, + "billing_tier": tier, + "anonymous_enabled": anon_free if tier == "free" else anon_paid, "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, @@ -162,12 +274,18 @@ def _generate_configs( "model_name": model_id, "api_key": api_key, "api_base": "", - "rpm": rpm, - "tpm": tpm, + "rpm": free_rpm if tier == "free" else rpm, + "tpm": free_tpm if tier == "free" else tpm, "litellm_params": dict(litellm_params), "system_instructions": system_instructions, "use_default_system_instructions": use_default, "citations_enabled": citations_enabled, + # Premium OR deployments join the LiteLLM router pool so sub-agent + # model="auto" flows can load-balance / fail over across them. + # Free OR deployments stay out: OpenRouter's free tier is a single + # account-wide quota, so per-deployment routing can't spread load + # there — it just drains the shared bucket faster. + "router_pool_eligible": tier == "premium", _OPENROUTER_DYNAMIC_MARKER: True, } configs.append(cfg) @@ -220,11 +338,12 @@ class OpenRouterIntegrationService: self._configs_by_id = {c["id"]: c for c in self._configs} self._initialized = True + tier_counts = self._tier_counts(self._configs) logger.info( - "OpenRouter integration: loaded %d models (IDs %d to %d)", + "OpenRouter integration: loaded %d models (free=%d, premium=%d)", len(self._configs), - self._configs[0]["id"] if self._configs else 0, - self._configs[-1]["id"] if self._configs else 0, + tier_counts["free"], + tier_counts["premium"], ) return self._configs @@ -254,7 +373,43 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id - logger.info("OpenRouter refresh: updated to %d models", len(new_configs)) + tier_counts = self._tier_counts(new_configs) + logger.info( + "OpenRouter refresh: updated to %d models (free=%d, premium=%d)", + len(new_configs), + tier_counts["free"], + tier_counts["premium"], + ) + 
+ # Rebuild the LiteLLM router so freshly fetched configs flow through + # (the router filters dynamic OR entries out of its pool, but a + # refresh still needs to pick up any static-config edits and reset + # cached context-window profiles). + try: + from app.config import config as _app_config + from app.services.llm_router_service import LLMRouterService + from app.services.llm_router_service import ( + _router_instance_cache as _chat_router_cache, + ) + + LLMRouterService.rebuild( + _app_config.GLOBAL_LLM_CONFIGS, + getattr(_app_config, "ROUTER_SETTINGS", None), + ) + _chat_router_cache.clear() + except Exception as exc: + logger.warning( + "OpenRouter refresh: router rebuild skipped (%s)", exc + ) + + @staticmethod + def _tier_counts(configs: list[dict]) -> dict[str, int]: + counts = {"free": 0, "premium": 0} + for cfg in configs: + tier = str(cfg.get("billing_tier", "")).lower() + if tier in counts: + counts[tier] += 1 + return counts async def _refresh_loop(self, interval_hours: float) -> None: interval_sec = interval_hours * 3600 From 925c33abd18424d5d0837ccea8ca0288fd5a6c44 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 17:42:44 +0530 Subject: [PATCH 45/68] chore(config): deprecate billing_tier / anonymous_enabled, split anon flags --- surfsense_backend/app/config/__init__.py | 50 ++++++++++++++++--- .../app/config/global_llm_config.example.yaml | 50 ++++++++++++++----- 2 files changed, 81 insertions(+), 19 deletions(-) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index bd97d2bb1..11cbe24a7 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -194,6 +194,9 @@ def load_openrouter_integration_settings() -> dict | None: """ Load OpenRouter integration settings from the YAML config. 
+ Emits startup warnings for deprecated keys (``billing_tier``, + ``anonymous_enabled``) and seeds their replacements for back-compat. + Returns: dict with settings if present and enabled, None otherwise """ @@ -206,9 +209,31 @@ def load_openrouter_integration_settings() -> dict | None: with open(global_config_file, encoding="utf-8") as f: data = yaml.safe_load(f) settings = data.get("openrouter_integration") - if settings and settings.get("enabled"): - return settings - return None + if not settings or not settings.get("enabled"): + return None + + if "billing_tier" in settings: + print( + "Warning: openrouter_integration.billing_tier is deprecated; " + "tier is now derived per model from OpenRouter data " + "(':free' suffix or zero pricing). Remove this key." + ) + + if "anonymous_enabled" in settings: + print( + "Warning: openrouter_integration.anonymous_enabled is " + "deprecated; use anonymous_enabled_paid and/or " + "anonymous_enabled_free instead. Both new flags have been " + "seeded from the legacy value for back-compat." + ) + settings.setdefault( + "anonymous_enabled_paid", settings["anonymous_enabled"] + ) + settings.setdefault( + "anonymous_enabled_free", settings["anonymous_enabled"] + ) + + return settings except Exception as e: print(f"Warning: Failed to load OpenRouter integration settings: {e}") return None @@ -217,9 +242,14 @@ def load_openrouter_integration_settings() -> dict | None: def initialize_openrouter_integration(): """ If enabled, fetch all OpenRouter models and append them to - config.GLOBAL_LLM_CONFIGS as dynamic premium entries. - Should be called BEFORE initialize_llm_router() so the router - correctly excludes premium models from Auto mode. + config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier`` + is derived per-model from OpenRouter's API signals (``:free`` suffix or + zero pricing), so free OpenRouter models correctly skip premium quota. + + Should be called BEFORE initialize_llm_router(). 
Dynamic entries are + tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used + by title-gen / sub-agent flows) remains scoped to curated YAML configs, + while user-facing Auto-mode thread pinning still considers them. """ settings = load_openrouter_integration_settings() if not settings: @@ -235,9 +265,15 @@ def initialize_openrouter_integration(): if new_configs: config.GLOBAL_LLM_CONFIGS.extend(new_configs) + free_count = sum( + 1 for c in new_configs if c.get("billing_tier") == "free" + ) + premium_count = sum( + 1 for c in new_configs if c.get("billing_tier") == "premium" + ) print( f"Info: OpenRouter integration added {len(new_configs)} models " - f"(billing_tier={settings.get('billing_tier', 'premium')})" + f"(free={free_count}, premium={premium_count})" ) else: print("Info: OpenRouter integration enabled but no models fetched") diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 9aca0f022..d62b4a4a5 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -245,31 +245,57 @@ global_llm_configs: # ============================================================================= # When enabled, dynamically fetches ALL available models from the OpenRouter API # and injects them as global configs. This gives premium users access to any model -# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota. +# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota, +# while free-tier OpenRouter models show up with a green Free badge and do NOT +# consume premium quota. # Models are fetched at startup and refreshed periodically in the background. # All calls go through LiteLLM with the openrouter/ prefix. openrouter_integration: enabled: false api_key: "sk-or-your-openrouter-api-key" - # billing_tier: "premium" or "free". 
Controls whether users need premium tokens. - billing_tier: "premium" - # anonymous_enabled: set true to also show OpenRouter models to no-login users - anonymous_enabled: false + + # Tier is derived PER MODEL from OpenRouter's own API signals: + # - id ends with ":free" -> billing_tier=free + # - pricing.prompt AND pricing.completion == "0" -> billing_tier=free + # - otherwise -> billing_tier=premium + # No global billing_tier knob is honored; any legacy value emits a startup warning. + + # Anonymous access is split by tier so operators can expose only free + # models to no-login users without leaking paid inference. + anonymous_enabled_paid: false + anonymous_enabled_free: false + seo_enabled: false # quota_reserve_tokens: tokens reserved per call for quota enforcement quota_reserve_tokens: 4000 - # id_offset: starting negative ID for dynamically generated configs. - # Must not overlap with your static global_llm_configs IDs above. + # id_offset: base negative ID for dynamically generated configs. + # Model IDs are derived deterministically via BLAKE2b so they survive + # catalogue churn. Must not overlap with your static global_llm_configs IDs. id_offset: -10000 # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only) refresh_interval_hours: 24 - # rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing. - # OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled - # upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits). - # These values only matter if you set billing_tier to "free" (adding them to Auto mode). - # For premium-only models they are cosmetic. Set conservatively or match your account tier. + + # Rate limits for PAID OpenRouter models. These are used by LiteLLM Router + # for per-deployment accounting when OR premium models participate in the + # shared sub-agent "auto" pool. 
They do NOT cap OpenRouter itself — your + # real account limits live at https://openrouter.ai/settings/limits. rpm: 200 tpm: 1000000 + + # Rate limits for FREE OpenRouter models. Informational only: free OR + # models and openrouter/free are intentionally kept OUT of the LiteLLM + # Router pool, because OpenRouter enforces free-tier limits globally per + # account (~20 RPM + 50-1000 daily requests across every ":free" model + # combined) — per-deployment router accounting can't represent a shared + # bucket correctly. Free OR models stay fully available in the model + # selector and for user-facing Auto thread pinning. + free_rpm: 20 + free_tpm: 100000 + + # Expose openrouter/free as a single virtual "Free (Auto-Select)" entry. + # Recommended: keep true. OpenRouter picks a capable free model per request. + free_router_enabled: true + litellm_params: max_tokens: 16384 system_instructions: "" From 2019e90a04149cc491f0513d8c14f498792e2104 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 17:42:54 +0530 Subject: [PATCH 46/68] test(openrouter): cover pool filter, per-model tier, legacy config warnings --- .../services/test_llm_router_pool_filter.py | 215 ++++++++++++++++ .../test_openrouter_integration_service.py | 236 ++++++++++++++++++ .../services/test_openrouter_legacy_config.py | 110 ++++++++ 3 files changed, 561 insertions(+) create mode 100644 surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py create mode 100644 surfsense_backend/tests/unit/services/test_openrouter_integration_service.py create mode 100644 surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py new file mode 100644 index 000000000..0191025ec --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -0,0 +1,215 @@ 
+"""LLMRouterService pool-filter / rebuild tests. + +These tests focus on the *config plumbing* (which configs enter the router +pool, rebuild resets state correctly). They stub out the underlying +``litellm.Router`` so we don't need real API keys or network access. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from app.services.llm_router_service import LLMRouterService + +pytestmark = pytest.mark.unit + + +def _fake_yaml_config( + *, + id: int, + model_name: str, + billing_tier: str = "free", +) -> dict: + return { + "id": id, + "name": f"yaml-{id}", + "provider": "OPENAI", + "model_name": model_name, + "api_key": "sk-test", + "api_base": "", + "billing_tier": billing_tier, + "rpm": 100, + "tpm": 100_000, + "litellm_params": {}, + } + + +def _fake_openrouter_config( + *, + id: int, + model_name: str, + billing_tier: str, + router_pool_eligible: bool | None = None, +) -> dict: + """Build a synthetic dynamic-OR config dict for router-pool tests. + + Defaults mirror Strategy 3: premium OR enters the pool, free OR stays + out. Callers can override ``router_pool_eligible`` to simulate legacy + configs or to regression-test the filter mechanics directly. 
+ """ + if router_pool_eligible is None: + router_pool_eligible = billing_tier == "premium" + return { + "id": id, + "name": f"or-{id}", + "provider": "OPENROUTER", + "model_name": model_name, + "api_key": "sk-or-test", + "api_base": "", + "billing_tier": billing_tier, + "rpm": 20 if billing_tier == "free" else 200, + "tpm": 100_000 if billing_tier == "free" else 1_000_000, + "litellm_params": {}, + "router_pool_eligible": router_pool_eligible, + } + + +def _reset_router_singleton() -> None: + instance = LLMRouterService.get_instance() + instance._initialized = False + instance._router = None + instance._model_list = [] + instance._premium_model_strings = set() + + +def test_router_pool_includes_or_premium_excludes_or_free(): + """Strategy 3: premium OR joins the pool, free OR stays out. + + Dynamic OpenRouter premium entries opt into load balancing alongside + curated YAML configs. Dynamic OR free entries are intentionally kept + out because OpenRouter's free tier enforces a single account-global + quota bucket that per-deployment router accounting can't represent. + """ + _reset_router_singleton() + configs = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), + _fake_openrouter_config( + id=-10_001, model_name="openai/gpt-4o", billing_tier="premium" + ), + _fake_openrouter_config( + id=-10_002, + model_name="meta-llama/llama-3.3-70b:free", + billing_tier="free", + ), + ] + + with patch("app.services.llm_router_service.Router") as mock_router, patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb: + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + LLMRouterService.initialize(configs) + + pool_models = { + dep["litellm_params"]["model"] + for dep in LLMRouterService.get_instance()._model_list + } + # YAML premium + YAML free + dynamic OR premium are all in the pool. 
+ # Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced). + assert pool_models == { + "openai/gpt-4o", + "openai/gpt-4o-mini", + "openrouter/openai/gpt-4o", + } + + prem = LLMRouterService.get_instance()._premium_model_strings + # YAML premium is fingerprinted under both its model_string and its + # ``base_model`` form (existing behavior we don't want to regress). + assert "openai/gpt-4o" in prem + # Dynamic OR premium is now fingerprinted as premium so pool-level + # calls through the router are billed against premium quota. + assert "openrouter/openai/gpt-4o" in prem + assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True + # Dynamic OR free never enters the pool, so it's never counted as premium. + assert LLMRouterService.is_premium_model( + "openrouter/meta-llama/llama-3.3-70b:free" + ) is False + + +def test_router_pool_filter_mechanics_respect_override(): + """The ``router_pool_eligible`` filter itself works independently of tier. + + Regression guard: if a future refactor ever sets the flag False on a + premium config (e.g. for maintenance), that config MUST be skipped by + ``initialize`` even though its tier is premium. 
+ """ + _reset_router_singleton() + configs = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + _fake_openrouter_config( + id=-10_001, + model_name="openai/gpt-4o", + billing_tier="premium", + router_pool_eligible=False, # opt out despite being premium + ), + ] + + with patch("app.services.llm_router_service.Router") as mock_router, patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb: + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + LLMRouterService.initialize(configs) + + pool_models = { + dep["litellm_params"]["model"] + for dep in LLMRouterService.get_instance()._model_list + } + assert pool_models == {"openai/gpt-4o"} + assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False + + +def test_rebuild_refreshes_pool_after_configs_change(): + _reset_router_singleton() + configs_v1 = [ + _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), + ] + configs_v2 = configs_v1 + [ + _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), + ] + + with patch("app.services.llm_router_service.Router") as mock_router, patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb: + mock_ctx_fb.side_effect = lambda ml: (ml, None) + mock_router.return_value = object() + + LLMRouterService.initialize(configs_v1) + assert len(LLMRouterService.get_instance()._model_list) == 1 + + # ``initialize`` should be a no-op here (already initialized). + LLMRouterService.initialize(configs_v2) + assert len(LLMRouterService.get_instance()._model_list) == 1 + + # ``rebuild`` must clear the guard and re-run with the new configs. + LLMRouterService.rebuild(configs_v2) + assert len(LLMRouterService.get_instance()._model_list) == 2 + + +def test_auto_model_pin_candidates_include_dynamic_openrouter(): + """Dynamic OR configs must remain Auto-mode thread-pin candidates. 
+ + Guards against a future regression where someone adds the + ``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``. + """ + from app.config import config + from app.services.auto_model_pin_service import _global_candidates + + or_premium = _fake_openrouter_config( + id=-10_001, model_name="openai/gpt-4o", billing_tier="premium" + ) + or_free = _fake_openrouter_config( + id=-10_002, + model_name="meta-llama/llama-3.3-70b:free", + billing_tier="free", + ) + original = config.GLOBAL_LLM_CONFIGS + try: + config.GLOBAL_LLM_CONFIGS = [or_premium, or_free] + candidate_ids = {c["id"] for c in _global_candidates()} + assert candidate_ids == {-10_001, -10_002} + finally: + config.GLOBAL_LLM_CONFIGS = original diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py new file mode 100644 index 000000000..618edc23c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -0,0 +1,236 @@ +"""Unit tests for the dynamic OpenRouter integration.""" + +from __future__ import annotations + +import pytest + +from app.services.openrouter_integration_service import ( + _FREE_ROUTER_ID, + _OPENROUTER_DYNAMIC_MARKER, + _build_free_router_config, + _generate_configs, + _openrouter_tier, + _stable_config_id, +) + +pytestmark = pytest.mark.unit + + +def _minimal_openrouter_model( + *, + model_id: str, + pricing: dict | None = None, + name: str | None = None, +) -> dict: + """Return a synthetic OpenRouter /api/v1/models entry. + + The real API payload includes a lot of fields; we only populate what + ``_generate_configs`` actually inspects (architecture, tool support, + context, pricing, id). 
+ """ + return { + "id": model_id, + "name": name or model_id, + "architecture": {"output_modalities": ["text"]}, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"}, + } + + +# --------------------------------------------------------------------------- +# _openrouter_tier +# --------------------------------------------------------------------------- + + +def test_openrouter_tier_free_suffix(): + assert _openrouter_tier({"id": "foo/bar:free"}) == "free" + + +def test_openrouter_tier_zero_pricing(): + model = { + "id": "foo/bar", + "pricing": {"prompt": "0", "completion": "0"}, + } + assert _openrouter_tier(model) == "free" + + +def test_openrouter_tier_paid(): + model = { + "id": "foo/bar", + "pricing": {"prompt": "0.000003", "completion": "0.000015"}, + } + assert _openrouter_tier(model) == "premium" + + +def test_openrouter_tier_missing_pricing_is_premium(): + assert _openrouter_tier({"id": "foo/bar"}) == "premium" + assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium" + + +# --------------------------------------------------------------------------- +# _stable_config_id +# --------------------------------------------------------------------------- + + +def test_stable_config_id_deterministic(): + taken1: set[int] = set() + taken2: set[int] = set() + a = _stable_config_id("openai/gpt-4o", -10_000, taken1) + b = _stable_config_id("openai/gpt-4o", -10_000, taken2) + assert a == b + assert a < 0 + + +def test_stable_config_id_collision_decrements(): + """When two model_ids hash to the same slot, the second should decrement.""" + taken: set[int] = set() + a = _stable_config_id("openai/gpt-4o", -10_000, taken) + # Force a collision by pre-populating ``taken`` with a slot we know will be + # picked. 
+ taken_forced = {a} + b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced) + assert b != a + assert b == a - 1 + assert b in taken_forced + + +def test_stable_config_id_different_models_different_ids(): + taken: set[int] = set() + ids = { + _stable_config_id("openai/gpt-4o", -10_000, taken), + _stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken), + _stable_config_id("google/gemini-2.0-flash", -10_000, taken), + } + assert len(ids) == 3 + + +def test_stable_config_id_survives_catalogue_churn(): + """Removing a model should not shift other models' IDs (the bug we fix).""" + taken1: set[int] = set() + id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1) + _ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1) + id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1) + + taken2: set[int] = set() + id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2) + id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2) + + assert id_a1 == id_a2 + assert id_c1 == id_c2 + + +# --------------------------------------------------------------------------- +# _generate_configs +# --------------------------------------------------------------------------- + + +_SETTINGS_BASE: dict = { + "api_key": "sk-or-test", + "id_offset": -10_000, + "rpm": 200, + "tpm": 1_000_000, + "free_rpm": 20, + "free_tpm": 100_000, + "anonymous_enabled_paid": False, + "anonymous_enabled_free": True, + "quota_reserve_tokens": 4000, + "free_router_enabled": False, +} + + +def test_generate_configs_respects_tier(): + """Premium OR models opt into the router pool; free OR models stay out. + + Strategy-3 split: premium participates in LiteLLM Router load balancing, + free stays excluded because OpenRouter enforces a shared global free-tier + bucket that per-deployment router accounting can't represent. 
+ """ + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + _minimal_openrouter_model( + model_id="meta-llama/llama-3.3-70b-instruct:free", + pricing={"prompt": "0", "completion": "0"}, + ), + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + by_model = {c["model_name"]: c for c in cfgs} + + paid = by_model["openai/gpt-4o"] + assert paid["billing_tier"] == "premium" + assert paid["rpm"] == 200 + assert paid["tpm"] == 1_000_000 + assert paid["anonymous_enabled"] is False + assert paid["router_pool_eligible"] is True + assert paid[_OPENROUTER_DYNAMIC_MARKER] is True + + free = by_model["meta-llama/llama-3.3-70b-instruct:free"] + assert free["billing_tier"] == "free" + assert free["rpm"] == 20 + assert free["tpm"] == 100_000 + assert free["anonymous_enabled"] is True + assert free["router_pool_eligible"] is False + + +def test_generate_configs_includes_free_router_when_enabled(): + raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] + settings = {**_SETTINGS_BASE, "free_router_enabled": True} + cfgs = _generate_configs(raw, settings) + free_router = next( + (c for c in cfgs if c["model_name"] == "openrouter/free"), None + ) + assert free_router is not None + assert free_router["id"] == _FREE_ROUTER_ID + assert free_router["billing_tier"] == "free" + assert free_router["router_pool_eligible"] is False + assert free_router["anonymous_enabled"] is True + + +def test_generate_configs_excludes_free_router_when_disabled(): + raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] + settings = {**_SETTINGS_BASE, "free_router_enabled": False} + cfgs = _generate_configs(raw, settings) + assert not any(c["model_name"] == "openrouter/free" for c in cfgs) + + +def test_generate_configs_excludes_free_router_without_api_key(): + """Without an API key the free-router entry is useless; skip it.""" + raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] + settings = {**_SETTINGS_BASE, "free_router_enabled": True, "api_key": ""} + cfgs = 
_generate_configs(raw, settings) + assert not any(c["model_name"] == "openrouter/free" for c in cfgs) + + +def test_generate_configs_drops_non_text_and_non_tool_models(): + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + { # image-output model + "id": "openai/dall-e", + "architecture": {"output_modalities": ["image"]}, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.01", "completion": "0.01"}, + }, + { # text but no tool calling + "id": "openai/completion-only", + "architecture": {"output_modalities": ["text"]}, + "supported_parameters": [], + "context_length": 200_000, + "pricing": {"prompt": "0.01", "completion": "0.01"}, + }, + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + model_names = [c["model_name"] for c in cfgs] + assert "openai/gpt-4o" in model_names + assert "openai/dall-e" not in model_names + assert "openai/completion-only" not in model_names + + +def test_build_free_router_config_shape(): + cfg = _build_free_router_config(dict(_SETTINGS_BASE)) + assert cfg["provider"] == "OPENROUTER" + assert cfg["model_name"] == "openrouter/free" + assert cfg["id"] == _FREE_ROUTER_ID + assert cfg["billing_tier"] == "free" + assert cfg["router_pool_eligible"] is False + assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True diff --git a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py new file mode 100644 index 000000000..b3dd2bf18 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py @@ -0,0 +1,110 @@ +"""Tests for deprecated-key warnings and back-compat in +``load_openrouter_integration_settings``. 
+""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.unit + + +def _write_yaml(tmp_path: Path, body: str) -> Path: + cfg_dir = tmp_path / "app" / "config" + cfg_dir.mkdir(parents=True) + cfg_path = cfg_dir / "global_llm_config.yaml" + cfg_path.write_text(body, encoding="utf-8") + return cfg_path + + +def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + +def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + billing_tier: "premium" +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + captured = capsys.readouterr().out + assert settings is not None + assert "billing_tier is deprecated" in captured + + +def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + anonymous_enabled: true +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + captured = capsys.readouterr().out + assert settings is not None + assert settings["anonymous_enabled_paid"] is True + assert settings["anonymous_enabled_free"] is True + assert "anonymous_enabled is" in captured + assert "deprecated" in captured + + +def test_new_keys_take_priority_over_legacy_back_compat( + monkeypatch, tmp_path, capsys +): + """If both legacy and new keys are present, new keys win (setdefault).""" + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: true + api_key: "sk-or-test" + anonymous_enabled: true + 
anonymous_enabled_paid: false + anonymous_enabled_free: false +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + settings = load_openrouter_integration_settings() + capsys.readouterr() + assert settings is not None + assert settings["anonymous_enabled_paid"] is False + assert settings["anonymous_enabled_free"] is False + + +def test_disabled_integration_returns_none(monkeypatch, tmp_path): + _write_yaml( + tmp_path, + """ +openrouter_integration: + enabled: false + api_key: "sk-or-test" +""".lstrip(), + ) + _patch_base_dir(monkeypatch, tmp_path) + + from app.config import load_openrouter_integration_settings + + assert load_openrouter_integration_settings() is None From 4d34b56c4da4e3a935eaaa1b6cb6321597088802 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 18:09:50 +0530 Subject: [PATCH 47/68] docs(router): drop reference to virtual openrouter/free in is_premium_model --- surfsense_backend/app/services/llm_router_service.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index d624ff56c..060e01675 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -344,11 +344,11 @@ class LLMRouterService: Scope: only covers configs with ``router_pool_eligible`` truthy. That includes static YAML premium configs AND dynamic OpenRouter *premium* entries (which opt in at generation time). 
Dynamic OpenRouter *free* - entries and the virtual ``openrouter/free`` router are deliberately - kept out of the router pool — OpenRouter enforces free-tier limits - globally per account, so per-deployment router accounting can't - represent them correctly — and therefore return ``False`` here, which - matches their ``billing_tier="free"`` (no premium quota). + entries are deliberately kept out of the router pool — OpenRouter + enforces free-tier limits globally per account, so per-deployment + router accounting can't represent them correctly — and therefore + return ``False`` here, which matches their ``billing_tier="free"`` + (no premium quota). For per-request premium checks on an arbitrary config (static or dynamic, pool or non-pool), read ``agent_config.is_premium`` instead; From 680a1c1c38d090c54f790adbdf35e6beed5d7566 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 18:16:47 +0530 Subject: [PATCH 48/68] refactor(openrouter): remove virtual openrouter/free auto-select entry --- .../app/config/global_llm_config.example.yaml | 16 ++-- .../openrouter_integration_service.py | 78 +++++-------------- .../test_openrouter_integration_service.py | 56 +++++-------- 3 files changed, 45 insertions(+), 105 deletions(-) diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index d62b4a4a5..79cbe1e51 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -283,19 +283,15 @@ openrouter_integration: tpm: 1000000 # Rate limits for FREE OpenRouter models. 
Informational only: free OR - # models and openrouter/free are intentionally kept OUT of the LiteLLM - # Router pool, because OpenRouter enforces free-tier limits globally per - # account (~20 RPM + 50-1000 daily requests across every ":free" model - # combined) — per-deployment router accounting can't represent a shared - # bucket correctly. Free OR models stay fully available in the model - # selector and for user-facing Auto thread pinning. + # models are intentionally kept OUT of the LiteLLM Router pool, because + # OpenRouter enforces free-tier limits globally per account (~20 RPM + + # 50-1000 daily requests across every ":free" model combined) — + # per-deployment router accounting can't represent a shared bucket + # correctly. Free OR models stay fully available in the model selector + # and for user-facing Auto thread pinning. free_rpm: 20 free_tpm: 100000 - # Expose openrouter/free as a single virtual "Free (Auto-Select)" entry. - # Recommended: keep true. OpenRouter picks a capable free model per request. - free_router_enabled: true - litellm_params: max_tokens: 16384 system_instructions: "" diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 2d6a42337..06b7becdc 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -26,11 +26,6 @@ OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" # dynamic OpenRouter entries from hand-written YAML entries during refresh. _OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__" -# Fixed negative ID for the virtual ``openrouter/free`` auto-select entry. -# Chosen to sit far below any reasonable ``id_offset`` so it never collides -# with per-model stable IDs. -_FREE_ROUTER_ID = -9_999_999 - # Width of the hash space used by ``_stable_config_id``. 
9_000_000 provides # enough headroom to avoid frequent collisions for OpenRouter's catalogue # (~300 models) while keeping IDs comfortably within Postgres INTEGER range. @@ -107,6 +102,11 @@ _EXCLUDED_MODEL_IDS: set[str] = { # Deep-research models reject standard params (temperature, etc.) "openai/o3-deep-research", "openai/o4-mini-deep-research", + # OpenRouter's own meta-router over free models. We already enumerate every + # concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread + # pinning handles churn via the repair path, so exposing an additional + # indirection layer would only duplicate the capability with an opaque slug. + "openrouter/free", } _EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",) @@ -160,43 +160,6 @@ async def _fetch_models_async() -> list[dict] | None: return None -def _build_free_router_config(settings: dict[str, Any]) -> dict[str, Any]: - """Build the virtual ``openrouter/free`` auto-select config entry. - - This exposes OpenRouter's Free Models Router as a single selectable - option. LiteLLM forwards ``openrouter/openrouter/free`` and OpenRouter - picks a capable free model per request (availability varies, account-wide - rate limit is ~20 req/min). - """ - return { - "id": _FREE_ROUTER_ID, - "name": "OpenRouter Free (Auto-Select)", - "description": ( - "OpenRouter picks a capable free model per request. " - "~20 req/min account-wide; availability varies." 
- ), - "provider": "OPENROUTER", - "model_name": "openrouter/free", - "api_key": settings.get("api_key", ""), - "api_base": "", - "billing_tier": "free", - "rpm": settings.get("free_rpm", 20), - "tpm": settings.get("free_tpm", 100_000), - "anonymous_enabled": settings.get("anonymous_enabled_free", False), - "seo_enabled": False, - "seo_slug": None, - "quota_reserve_tokens": settings.get("quota_reserve_tokens", 4000), - "litellm_params": dict(settings.get("litellm_params") or {}), - "system_instructions": settings.get("system_instructions", ""), - "use_default_system_instructions": settings.get( - "use_default_system_instructions", True - ), - "citations_enabled": settings.get("citations_enabled", True), - "router_pool_eligible": False, - _OPENROUTER_DYNAMIC_MARKER: True, - } - - def _generate_configs( raw_models: list[dict], settings: dict[str, Any], @@ -213,13 +176,18 @@ def _generate_configs( - Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``) so sub-agent ``model="auto"`` flows benefit from load balancing and failover across the curated YAML configs and the OR premium passthrough. - - Free OR models and the virtual ``openrouter/free`` entry stay excluded - (``router_pool_eligible=False``). LiteLLM Router tracks rate limits per - deployment, but OpenRouter enforces a single global free-tier quota - (~20 RPM + 50-1000 daily requests account-wide across every ``:free`` - model), so rotating across many free deployments would only burn the - shared bucket faster. Free OR models remain fully available for user- - facing Auto-mode thread pinning via ``auto_model_pin_service``. + - Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM + Router tracks rate limits per deployment, but OpenRouter enforces a + single global free-tier quota (~20 RPM + 50-1000 daily requests + account-wide across every ``:free`` model), so rotating across many + free deployments would only burn the shared bucket faster. 
Free OR + models remain fully available for user-facing Auto-mode thread pinning + via ``auto_model_pin_service``. + + OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream + via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer + because our own Auto (Fastest) pin + 24 h refresh + repair logic already + cover the catalogue-churn case. """ id_offset: int = settings.get("id_offset", -10000) api_key: str = settings.get("api_key", "") @@ -248,13 +216,7 @@ def _generate_configs( ] configs: list[dict] = [] - - if settings.get("free_router_enabled", True) and api_key: - configs.append(_build_free_router_config(settings)) - taken: set[int] = set() - if configs: - taken.add(_FREE_ROUTER_ID) for model in text_models: model_id: str = model["id"] @@ -382,9 +344,9 @@ class OpenRouterIntegrationService: ) # Rebuild the LiteLLM router so freshly fetched configs flow through - # (the router filters dynamic OR entries out of its pool, but a - # refresh still needs to pick up any static-config edits and reset - # cached context-window profiles). + # (dynamic OR premium entries now opt into the pool, free ones stay + # out; a refresh also needs to pick up any static-config edits and + # reset cached context-window profiles). 
try: from app.config import config as _app_config from app.services.llm_router_service import LLMRouterService diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 618edc23c..d3921729d 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -5,9 +5,7 @@ from __future__ import annotations import pytest from app.services.openrouter_integration_service import ( - _FREE_ROUTER_ID, _OPENROUTER_DYNAMIC_MARKER, - _build_free_router_config, _generate_configs, _openrouter_tier, _stable_config_id, @@ -135,7 +133,6 @@ _SETTINGS_BASE: dict = { "anonymous_enabled_paid": False, "anonymous_enabled_free": True, "quota_reserve_tokens": 4000, - "free_router_enabled": False, } @@ -172,33 +169,26 @@ def test_generate_configs_respects_tier(): assert free["router_pool_eligible"] is False -def test_generate_configs_includes_free_router_when_enabled(): - raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] - settings = {**_SETTINGS_BASE, "free_router_enabled": True} - cfgs = _generate_configs(raw, settings) - free_router = next( - (c for c in cfgs if c["model_name"] == "openrouter/free"), None - ) - assert free_router is not None - assert free_router["id"] == _FREE_ROUTER_ID - assert free_router["billing_tier"] == "free" - assert free_router["router_pool_eligible"] is False - assert free_router["anonymous_enabled"] is True +def test_generate_configs_excludes_upstream_openrouter_free_router(): + """OpenRouter's own ``openrouter/free`` meta-router must never become a card. 
- -def test_generate_configs_excludes_free_router_when_disabled(): - raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] - settings = {**_SETTINGS_BASE, "free_router_enabled": False} - cfgs = _generate_configs(raw, settings) - assert not any(c["model_name"] == "openrouter/free" for c in cfgs) - - -def test_generate_configs_excludes_free_router_without_api_key(): - """Without an API key the free-router entry is useless; skip it.""" - raw = [_minimal_openrouter_model(model_id="openai/gpt-4o")] - settings = {**_SETTINGS_BASE, "free_router_enabled": True, "api_key": ""} - cfgs = _generate_configs(raw, settings) - assert not any(c["model_name"] == "openrouter/free" for c in cfgs) + The upstream API returns this as a first-class zero-priced model, so + without an explicit blocklist entry it would slip through every other + filter (text output, tool calling, 200k context, non-Amazon) and land + in the selector as a duplicate of the concrete ``:free`` cards. The + exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that. 
+ """ + raw = [ + _minimal_openrouter_model(model_id="openai/gpt-4o"), + _minimal_openrouter_model( + model_id="openrouter/free", + pricing={"prompt": "0", "completion": "0"}, + ), + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + model_names = {c["model_name"] for c in cfgs} + assert "openrouter/free" not in model_names + assert "openai/gpt-4o" in model_names def test_generate_configs_drops_non_text_and_non_tool_models(): @@ -226,11 +216,3 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): assert "openai/completion-only" not in model_names -def test_build_free_router_config_shape(): - cfg = _build_free_router_config(dict(_SETTINGS_BASE)) - assert cfg["provider"] == "OPENROUTER" - assert cfg["model_name"] == "openrouter/free" - assert cfg["id"] == _FREE_ROUTER_ID - assert cfg["billing_tier"] == "free" - assert cfg["router_pool_eligible"] is False - assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True From 1863f2832b203d101a159653ff2198d59b93ddfc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 18:43:45 +0530 Subject: [PATCH 49/68] fix(LayoutShell): add 'isolate' class to main content panel --- surfsense_web/components/layout/ui/shell/LayoutShell.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surfsense_web/components/layout/ui/shell/LayoutShell.tsx b/surfsense_web/components/layout/ui/shell/LayoutShell.tsx index d41dd9e6d..207d27f7b 100644 --- a/surfsense_web/components/layout/ui/shell/LayoutShell.tsx +++ b/surfsense_web/components/layout/ui/shell/LayoutShell.tsx @@ -132,7 +132,7 @@ function MainContentPanel({ const isDocumentTab = activeTab?.type === "document"; return ( - <div className="relative flex flex-1 flex-col min-w-0"> + <div className="relative isolate flex flex-1 flex-col min-w-0"> <TabBar onTabSwitch={onTabSwitch} onNewChat={onNewChat} From 421a4d7d0807f17da7c29290e96e96ff73cc5e72 Mon Sep 17 00:00:00 2001 From: Anish Sarkar 
<104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 19:32:42 +0530 Subject: [PATCH 50/68] refactor(auto_model_pin): simplify thread-level pinning by removing unused fields and indexes --- ...38_add_thread_auto_model_pinning_fields.py | 31 +++++----------- surfsense_backend/app/db.py | 13 +++---- .../app/routes/search_spaces_routes.py | 6 +-- .../app/services/auto_model_pin_service.py | 37 ++++++++----------- .../services/test_auto_model_pin_service.py | 28 +++----------- 5 files changed, 37 insertions(+), 78 deletions(-) diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index 3972b84b9..fba621a0c 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -4,10 +4,12 @@ Revision ID: 138 Revises: 137 Create Date: 2026-04-30 -Add thread-level fields to persist Auto (Fastest) model pinning metadata: -- pinned_llm_config_id: concrete resolved config id used for this thread -- pinned_auto_mode: auto policy identifier (currently "auto_fastest") -- pinned_at: timestamp when the pin was created/refreshed +Add a single thread-level column to persist the Auto (Fastest) model pin: +- pinned_llm_config_id: concrete resolved global LLM config id used for this + thread. NULL means "no pin; Auto will resolve on next turn". + +The column is unindexed: all reads are by new_chat_threads.id (primary key), +so a secondary index would be dead write amplification. 
""" from __future__ import annotations @@ -27,29 +29,14 @@ def upgrade() -> None: "ALTER TABLE new_chat_threads " "ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER" ) - op.execute( - "ALTER TABLE new_chat_threads " - "ADD COLUMN IF NOT EXISTS pinned_auto_mode VARCHAR(32)" - ) - op.execute( - "ALTER TABLE new_chat_threads " - "ADD COLUMN IF NOT EXISTS pinned_at TIMESTAMP WITH TIME ZONE" - ) - - op.execute( - "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_llm_config_id " - "ON new_chat_threads (pinned_llm_config_id)" - ) - op.execute( - "CREATE INDEX IF NOT EXISTS ix_new_chat_threads_pinned_auto_mode " - "ON new_chat_threads (pinned_auto_mode)" - ) def downgrade() -> None: + # Drop any shape the thread row may be carrying. The extra columns and + # indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS + # makes each statement a safe no-op on the lean shape. op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode") op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id") - op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at") op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode") op.execute( diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index ca3334f8b..2fe478d9b 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -638,13 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) - # Auto model pinning metadata: - # - pinned_llm_config_id stores the concrete resolved model config id. - # - pinned_auto_mode indicates which auto policy produced the pin. - # This allows Auto (Fastest) to resolve once per thread and stay stable. 
- pinned_llm_config_id = Column(Integer, nullable=True, index=True) - pinned_auto_mode = Column(String(32), nullable=True, index=True) - pinned_at = Column(TIMESTAMP(timezone=True), nullable=True) + # Auto (Fastest) model pin for this thread: concrete resolved global LLM + # config id. NULL means no pin; Auto will resolve on the next turn. + # Single-writer invariant: only app.services.auto_model_pin_service sets + # or clears this column (plus bulk clears when a search space's + # agent_llm_id changes). Unindexed: all reads are by primary key. + pinned_llm_config_id = Column(Integer, nullable=True) # Relationships search_space = relationship("SearchSpace", back_populates="new_chat_threads") diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 7944e7d66..72715ea5b 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -803,11 +803,7 @@ async def update_llm_preferences( await session.execute( update(NewChatThread) .where(NewChatThread.search_space_id == search_space_id) - .values( - pinned_llm_config_id=None, - pinned_auto_mode=None, - pinned_at=None, - ) + .values(pinned_llm_config_id=None) ) logger.info( "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 6b69c91ea..1a2061492 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -2,8 +2,14 @@ Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we resolve that virtual mode to one concrete global LLM config exactly once and -persist the chosen config id on ``new_chat_threads`` so subsequent turns are -stable. 
+persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so +subsequent turns are stable. + +Single-writer invariant: this module is the only writer of +``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in +``search_spaces_routes`` when a search space's ``agent_llm_id`` changes). +Therefore a non-NULL value unambiguously means "this thread has an +Auto-resolved pin"; no separate source/policy column is needed. """ from __future__ import annotations @@ -11,7 +17,6 @@ from __future__ import annotations import hashlib import logging from dataclasses import dataclass -from datetime import UTC, datetime from uuid import UUID from sqlalchemy import select @@ -90,10 +95,10 @@ async def resolve_or_get_pinned_llm_config_id( selected_llm_config_id: int, force_repin_free: bool = False, ) -> AutoPinResolution: - """Resolve Auto (Fastest) to one concrete config id and persist pin metadata. + """Resolve Auto (Fastest) to one concrete config id and persist the pin. - For non-auto selections, this function clears existing auto pin metadata and - returns the selected id as-is. + For non-auto selections, this function clears any existing pin and returns + the selected id as-is. """ thread = ( ( @@ -113,16 +118,10 @@ async def resolve_or_get_pinned_llm_config_id( f"Thread {thread_id} does not belong to search space {search_space_id}" ) - # Explicit model selected: clear stale auto pin metadata. + # Explicit model selected: clear any stale pin. 
if selected_llm_config_id != AUTO_FASTEST_ID: - if ( - thread.pinned_llm_config_id is not None - or thread.pinned_auto_mode is not None - or thread.pinned_at is not None - ): + if thread.pinned_llm_config_id is not None: thread.pinned_llm_config_id = None - thread.pinned_auto_mode = None - thread.pinned_at = None await session.commit() return AutoPinResolution( resolved_llm_config_id=selected_llm_config_id, @@ -135,12 +134,11 @@ async def resolve_or_get_pinned_llm_config_id( raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} - # Reuse existing valid pin without re-checking current quota (no silent tier switch), - # unless the caller explicitly requests a forced repin to free. + # Reuse an existing valid pin without re-checking current quota (no silent + # tier switch), unless the caller explicitly requests a forced repin to free. pinned_id = thread.pinned_llm_config_id if ( not force_repin_free - and thread.pinned_auto_mode == AUTO_FASTEST_MODE and pinned_id is not None and int(pinned_id) in candidate_by_id ): @@ -159,11 +157,10 @@ async def resolve_or_get_pinned_llm_config_id( ) if pinned_id is not None: logger.info( - "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s pinned_auto_mode=%s", + "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, search_space_id, pinned_id, - thread.pinned_auto_mode, ) premium_eligible = ( @@ -184,8 +181,6 @@ async def resolve_or_get_pinned_llm_config_id( selected_tier = _tier_of(selected_cfg) thread.pinned_llm_config_id = selected_id - thread.pinned_auto_mode = AUTO_FASTEST_MODE - thread.pinned_at = datetime.now(UTC) await session.commit() if force_repin_free: diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 0a2342e05..2094ea6dd 100644 --- 
a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -6,7 +6,6 @@ from types import SimpleNamespace import pytest from app.services.auto_model_pin_service import ( - AUTO_FASTEST_MODE, resolve_or_get_pinned_llm_config_id, ) @@ -45,14 +44,11 @@ def _thread( *, search_space_id: int = 10, pinned_llm_config_id: int | None = None, - pinned_auto_mode: str | None = None, ): return SimpleNamespace( id=1, search_space_id=search_space_id, pinned_llm_config_id=pinned_llm_config_id, - pinned_auto_mode=pinned_auto_mode, - pinned_at=None, ) @@ -93,8 +89,6 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): ) assert result.resolved_llm_config_id in {-1, -2} assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id - assert session.thread.pinned_auto_mode == AUTO_FASTEST_MODE - assert session.thread.pinned_at is not None assert session.commit_count == 1 @@ -102,9 +96,7 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -228,9 +220,7 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -275,9 +265,7 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): from app.config import config - session = 
_FakeSession( - _thread(pinned_llm_config_id=-1, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-1)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -325,9 +313,7 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-2, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-2)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", @@ -345,8 +331,6 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): ) assert result.resolved_llm_config_id == 7 assert session.thread.pinned_llm_config_id is None - assert session.thread.pinned_auto_mode is None - assert session.thread.pinned_at is None assert session.commit_count == 1 @@ -354,9 +338,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): from app.config import config - session = _FakeSession( - _thread(pinned_llm_config_id=-999, pinned_auto_mode=AUTO_FASTEST_MODE) - ) + session = _FakeSession(_thread(pinned_llm_config_id=-999)) monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", From d9058b73f5306f6dc40ba553cec92cf659246d1a Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 23:37:49 +0530 Subject: [PATCH 51/68] feat(auto_pin): add pure-function quality scoring module --- .../app/services/quality_score.py | 382 ++++++++++++++++++ .../tests/unit/services/test_quality_score.py | 342 ++++++++++++++++ 2 files changed, 724 insertions(+) create mode 100644 surfsense_backend/app/services/quality_score.py create mode 100644 surfsense_backend/tests/unit/services/test_quality_score.py diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py 
new file mode 100644 index 000000000..8f6c75d56 --- /dev/null +++ b/surfsense_backend/app/services/quality_score.py @@ -0,0 +1,382 @@ +"""Pure-function quality scoring for Auto (Fastest) model selection. + +This module is import-free of any service / request-path dependencies. All +numbers are computed once during the OpenRouter refresh tick (or YAML load) +and cached on the cfg dict, so the chat hot path only does a precomputed +sort and a SHA256 pick. + +Score components (0-100 scale, higher is better): + +* ``static_score_or`` – derived from the bulk ``/api/v1/models`` payload + (provider prestige + ``created`` recency + pricing band + context window + + capabilities + narrow tiny/legacy slug penalty). +* ``static_score_yaml`` – same shape for hand-curated YAML configs, plus + an operator-trust bonus (the operator deliberately picked this model). +* ``aggregate_health`` – run on per-model ``/api/v1/models/{id}/endpoints`` + responses; returns ``(gated, score_or_none)``. + +The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in +:mod:`app.services.openrouter_integration_service` because that's the only +caller that sees both halves. +""" + +from __future__ import annotations + +# --------------------------------------------------------------------------- +# Tunables (constants, not flags) +# --------------------------------------------------------------------------- + +# Top-K size for deterministic spread inside the locked tier. +_QUALITY_TOP_K: int = 5 + +# Hard health gate: any cfg whose best non-null uptime is below this % +# is excluded from Auto-mode selection entirely. +_HEALTH_GATE_UPTIME_PCT: float = 90.0 + +# Health/static blend weight when a cfg has fresh /endpoints data. +_HEALTH_BLEND_WEIGHT: float = 0.5 + +# Static bonus applied to YAML cfgs because the operator hand-picked them. +_OPERATOR_TRUST_BONUS: int = 20 + +# /endpoints fan-out is bounded per refresh tick. 
+_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50 +_HEALTH_ENRICH_TOP_N_FREE: int = 30 +_HEALTH_ENRICH_CONCURRENCY: int = 15 +_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0 + +# If at least this fraction of /endpoints fetches fail in a refresh cycle, +# fall back to the previous cycle's last-good cache instead of writing +# partial / stale health values. +_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25 + +# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise +# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with +# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and +# blanket-penalising them suppresses high-quality picks. +_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = ( + "-1b-", + "-1.2b-", + "-1.5b-", + "-2b-", + "-3b-", + "gemma-3n", + "lfm-", + "-base", + "-distill", + ":nitro", + "-preview", +) + + +# --------------------------------------------------------------------------- +# Provider prestige tables +# --------------------------------------------------------------------------- + +# OpenRouter-side provider slug (the prefix before ``/`` in the model id). +# Tiers are coarse: frontier labs > strong open / fast-moving labs > +# specialist labs > everything else. +PROVIDER_PRESTIGE_OR: dict[str, int] = { + # Frontier labs + "openai": 50, + "anthropic": 50, + "google": 50, + "x-ai": 50, + # Strong open / fast-moving labs + "deepseek": 38, + "qwen": 38, + "meta-llama": 38, + "mistralai": 38, + "cohere": 38, + "nvidia": 38, + "alibaba": 38, + # Specialist / regional / strong second-tier + "microsoft": 28, + "01-ai": 28, + "minimax": 28, + "moonshot": 28, + "z-ai": 28, + "nousresearch": 28, + "ai21": 28, + "perplexity": 28, + # Smaller / niche providers + "liquid": 18, + "cognitivecomputations": 18, + "venice": 18, + "inflection": 18, +} + +# YAML provider field (the upstream API shape the operator selected). 
+PROVIDER_PRESTIGE_YAML: dict[str, int] = { + "AZURE_OPENAI": 50, + "OPENAI": 50, + "ANTHROPIC": 50, + "GOOGLE": 50, + "VERTEX_AI": 50, + "GEMINI": 50, + "XAI": 50, + "MISTRAL": 38, + "DEEPSEEK": 38, + "COHERE": 38, + "GROQ": 30, + "TOGETHER_AI": 28, + "FIREWORKS_AI": 28, + "PERPLEXITY": 28, + "MINIMAX": 28, + "BEDROCK": 28, + "OPENROUTER": 25, + "OLLAMA": 12, + "CUSTOM": 12, +} + + +# --------------------------------------------------------------------------- +# Pure scoring helpers +# --------------------------------------------------------------------------- + +# Calibrated against the live /api/v1/models bulk dump. Frontier models +# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5, +# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band; +# anything older trails off. +_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = ( + (60, 20), + (180, 16), + (365, 12), + (540, 9), + (730, 6), + (1095, 3), +) + + +def created_recency_signal(created_ts: int | None, now_ts: int) -> int: + """Return 0-20 based on how recently the model was published. + + Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for + YAML cfgs). Models without a usable timestamp get 0 (we don't penalise, + we just don't reward). + """ + if created_ts is None or created_ts <= 0 or now_ts <= 0: + return 0 + age_days = max(0, (now_ts - int(created_ts)) // 86_400) + for cutoff, score in _RECENCY_BANDS_DAYS: + if age_days <= cutoff: + return score + return 0 + + +def pricing_band( + prompt: str | float | int | None, + completion: str | float | int | None, +) -> int: + """Return 0-15 based on combined prompt+completion cost per 1M tokens. + + Higher-priced models tend to be the larger / more capable ones. A free + model returns 0 (we use other signals to rank free-vs-free instead). + Uncoercible inputs are treated as 0 rather than raising. 
+ """ + + def _to_float(value) -> float: + if value is None: + return 0.0 + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + p = _to_float(prompt) + c = _to_float(completion) + total_per_million = (p + c) * 1_000_000 + + if total_per_million >= 20.0: + return 15 + if total_per_million >= 5.0: + return 12 + if total_per_million >= 1.0: + return 9 + if total_per_million >= 0.3: + return 6 + if total_per_million >= 0.05: + return 4 + if total_per_million > 0.0: + return 2 + return 0 + + +def context_signal(ctx: int | None) -> int: + """Return 0-10 based on the model's context window.""" + if not ctx or ctx <= 0: + return 0 + if ctx >= 1_000_000: + return 10 + if ctx >= 400_000: + return 8 + if ctx >= 200_000: + return 6 + if ctx >= 128_000: + return 4 + if ctx >= 100_000: + return 2 + return 0 + + +def capabilities_signal(supported_parameters: list[str] | None) -> int: + """Return 0-5 for capabilities that matter for our agent flows.""" + if not supported_parameters: + return 0 + params = set(supported_parameters) + score = 0 + if "tools" in params: + score += 2 + if "structured_outputs" in params or "response_format" in params: + score += 2 + if "reasoning" in params or "include_reasoning" in params: + score += 1 + return min(score, 5) + + +def slug_penalty(model_id: str) -> int: + """Return a non-positive number; matches the narrow tiny/legacy patterns.""" + if not model_id: + return 0 + needle = model_id.lower() + for pattern in _TINY_LEGACY_PENALTY_PATTERNS: + if pattern in needle: + return -10 + return 0 + + +def _provider_prestige_or(model_id: str) -> int: + if "/" not in model_id: + return 0 + slug = model_id.split("/", 1)[0].lower() + return PROVIDER_PRESTIGE_OR.get(slug, 15) + + +def static_score_or(or_model: dict, *, now_ts: int) -> int: + """Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale.""" + model_id = str(or_model.get("id", "")) + pricing = or_model.get("pricing") or {} + + score = ( + 
_provider_prestige_or(model_id) + + created_recency_signal(or_model.get("created"), now_ts) + + pricing_band(pricing.get("prompt"), pricing.get("completion")) + + context_signal(or_model.get("context_length")) + + capabilities_signal(or_model.get("supported_parameters")) + + slug_penalty(model_id) + ) + return max(0, min(100, int(score))) + + +def static_score_yaml(cfg: dict) -> int: + """Score a YAML-curated cfg on a 0-100 scale. + + Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately + listed this model. Pricing / context fall through to lazy ``litellm`` + lookups; failures are silent (we just lose those sub-points). + """ + provider = str(cfg.get("provider", "")).upper() + base = PROVIDER_PRESTIGE_YAML.get(provider, 15) + + model_name = cfg.get("model_name") or "" + litellm_params = cfg.get("litellm_params") or {} + lookup_name = ( + litellm_params.get("base_model") + or litellm_params.get("model") + or model_name + ) + + ctx = 0 + p_cost: float = 0.0 + c_cost: float = 0.0 + try: + from litellm import get_model_info # lazy: avoid cold-import cost + + info = get_model_info(lookup_name) or {} + ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0) + p_cost = float(info.get("input_cost_per_token") or 0.0) + c_cost = float(info.get("output_cost_per_token") or 0.0) + except Exception: + # Unknown to litellm — that's fine for prestige+operator-bonus weighting. 
+ pass + + score = ( + base + + _OPERATOR_TRUST_BONUS + + pricing_band(p_cost, c_cost) + + context_signal(ctx) + + slug_penalty(str(model_name)) + ) + return max(0, min(100, int(score))) + + +# --------------------------------------------------------------------------- +# Health aggregation +# --------------------------------------------------------------------------- + + +def _coerce_pct(value) -> float | None: + try: + if value is None: + return None + f = float(value) + except (TypeError, ValueError): + return None + if f < 0: + return None + # OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it + # as a 0-100 percentage. Normalise. + return f * 100.0 if f <= 1.0 else f + + +def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]: + """Pick the best (highest) non-null uptime across all endpoints. + + Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` > + ``uptime_last_5m``. Returns ``(uptime_pct, window_used)``. + """ + for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"): + values = [_coerce_pct(ep.get(window)) for ep in endpoints] + values = [v for v in values if v is not None] + if values: + return max(values), window + return None, None + + +def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]: + """Aggregate a model's per-endpoint health into ``(gated, score_or_none)``. + + Hard gate (returns ``(True, None)``): + * ``endpoints`` empty, + * no endpoint reports ``status == 0`` (OK), or + * best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``. + + On a pass, returns a 0-100 health score blending uptime, status, and a + freshness-weighted recent uptime sample. 
+ """ + if not endpoints: + return True, None + + any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints) + if not any_ok: + return True, None + + best_uptime, _ = _best_uptime(endpoints) + if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT: + return True, None + + # Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing. + freshness = None + for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"): + values = [_coerce_pct(ep.get(window)) for ep in endpoints] + values = [v for v in values if v is not None] + if values: + freshness = max(values) + break + + uptime_term = best_uptime + status_term = 100.0 if any_ok else 0.0 + freshness_term = freshness if freshness is not None else best_uptime + + score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term + return False, max(0.0, min(100.0, score)) diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py new file mode 100644 index 000000000..fbc91521d --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -0,0 +1,342 @@ +"""Unit tests for the Auto (Fastest) quality scoring module.""" + +from __future__ import annotations + +import time + +import pytest + +from app.services.quality_score import ( + _HEALTH_GATE_UPTIME_PCT, + _OPERATOR_TRUST_BONUS, + aggregate_health, + capabilities_signal, + context_signal, + created_recency_signal, + pricing_band, + slug_penalty, + static_score_or, + static_score_yaml, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# created_recency_signal +# --------------------------------------------------------------------------- + + +def test_created_recency_signal_recent_model_scores_high(): + now = 1_750_000_000 # ~mid-2025 + one_month_ago = now - (30 * 86_400) + assert created_recency_signal(one_month_ago, now) == 20 + + +def 
test_created_recency_signal_old_model_scores_zero(): + now = 1_750_000_000 + five_years_ago = now - (5 * 365 * 86_400) + assert created_recency_signal(five_years_ago, now) == 0 + + +def test_created_recency_signal_missing_timestamp_is_neutral(): + now = 1_750_000_000 + assert created_recency_signal(None, now) == 0 + assert created_recency_signal(0, now) == 0 + + +def test_created_recency_signal_monotonic_decay(): + now = 1_750_000_000 + scores = [ + created_recency_signal(now - days * 86_400, now) + for days in (30, 120, 300, 500, 700, 1000, 1500) + ] + assert scores == sorted(scores, reverse=True) + + +# --------------------------------------------------------------------------- +# pricing_band +# --------------------------------------------------------------------------- + + +def test_pricing_band_free_returns_zero(): + assert pricing_band("0", "0") == 0 + assert pricing_band(0.0, 0.0) == 0 + assert pricing_band(None, None) == 0 + + +def test_pricing_band_handles_unparseable(): + assert pricing_band("not-a-number", "0") == 0 + assert pricing_band({}, []) == 0 # type: ignore[arg-type] + + +def test_pricing_band_premium_tiers_increase_with_price(): + cheap = pricing_band("0.0000003", "0.0000005") + mid = pricing_band("0.000003", "0.000015") + flagship = pricing_band("0.00001", "0.00005") + assert 0 < cheap < mid < flagship + + +# --------------------------------------------------------------------------- +# context_signal +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "ctx,expected", + [ + (1_500_000, 10), + (1_000_000, 10), + (500_000, 8), + (200_000, 6), + (128_000, 4), + (100_000, 2), + (50_000, 0), + (0, 0), + (None, 0), + ], +) +def test_context_signal_bands(ctx, expected): + assert context_signal(ctx) == expected + + +# --------------------------------------------------------------------------- +# capabilities_signal +# --------------------------------------------------------------------------- 
+ + +def test_capabilities_signal_caps_at_five(): + assert capabilities_signal( + ["tools", "structured_outputs", "reasoning", "include_reasoning"] + ) <= 5 + + +def test_capabilities_signal_tools_only(): + assert capabilities_signal(["tools"]) == 2 + + +def test_capabilities_signal_empty(): + assert capabilities_signal(None) == 0 + assert capabilities_signal([]) == 0 + + +# --------------------------------------------------------------------------- +# slug_penalty +# --------------------------------------------------------------------------- + + +def test_slug_penalty_demotes_tiny_models(): + assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0 + assert slug_penalty("liquid/lfm-7b") < 0 + assert slug_penalty("google/gemma-3n-e4b-it") < 0 + + +def test_slug_penalty_skips_capable_mini_nano_lite_models(): + """Critical Option C+ regression: don't penalise modern frontier + models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.).""" + assert slug_penalty("openai/gpt-5-mini") == 0 + assert slug_penalty("openai/gpt-5-nano") == 0 + assert slug_penalty("google/gemini-2.5-flash-lite") == 0 + assert slug_penalty("anthropic/claude-haiku-4.5") == 0 + + +def test_slug_penalty_demotes_legacy_variants(): + assert slug_penalty("openai/o1-preview") < 0 + assert slug_penalty("foo/bar-base") < 0 + assert slug_penalty("foo/bar-distill") < 0 + + +def test_slug_penalty_empty_input(): + assert slug_penalty("") == 0 + + +# --------------------------------------------------------------------------- +# static_score_or +# --------------------------------------------------------------------------- + + +def _or_model( + *, + model_id: str, + created: int | None = None, + prompt: str = "0.000003", + completion: str = "0.000015", + context: int = 200_000, + params: list[str] | None = None, +) -> dict: + return { + "id": model_id, + "created": created, + "pricing": {"prompt": prompt, "completion": completion}, + "context_length": context, + "supported_parameters": params if 
params is not None else ["tools"], + } + + +def test_static_score_or_frontier_premium_beats_free_tiny(): + now = 1_750_000_000 + frontier = _or_model( + model_id="openai/gpt-5", + created=now - (60 * 86_400), + prompt="0.000005", + completion="0.000020", + context=400_000, + params=["tools", "structured_outputs", "reasoning"], + ) + tiny_free = _or_model( + model_id="meta-llama/llama-3.2-1b-instruct:free", + created=now - (5 * 365 * 86_400), + prompt="0", + completion="0", + context=128_000, + params=["tools"], + ) + assert static_score_or(frontier, now_ts=now) > static_score_or( + tiny_free, now_ts=now + ) + + +def test_static_score_or_score_is_clamped_0_to_100(): + now = int(time.time()) + score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now) + assert 0 <= score <= 100 + + +def test_static_score_or_unknown_provider_is_neutral_not_zero(): + now = int(time.time()) + score = static_score_or( + _or_model(model_id="some-new-lab/some-model"), + now_ts=now, + ) + assert score > 0 + + +def test_static_score_or_recent_release_beats_year_old_same_provider(): + now = 1_750_000_000 + fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400)) + old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400)) + assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now) + + +# --------------------------------------------------------------------------- +# static_score_yaml +# --------------------------------------------------------------------------- + + +def test_static_score_yaml_includes_operator_bonus(): + cfg = { + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "litellm_params": {"base_model": "azure/gpt-5"}, + } + score = static_score_yaml(cfg) + assert score >= _OPERATOR_TRUST_BONUS + + +def test_static_score_yaml_unknown_provider_still_carries_bonus(): + cfg = { + "provider": "SOME_NEW_PROVIDER", + "model_name": "weird-model", + } + score = static_score_yaml(cfg) + assert score >= _OPERATOR_TRUST_BONUS 
+ + +def test_static_score_yaml_clamped_0_to_100(): + cfg = { + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "litellm_params": {"base_model": "azure/gpt-5"}, + } + assert 0 <= static_score_yaml(cfg) <= 100 + + +# --------------------------------------------------------------------------- +# aggregate_health +# --------------------------------------------------------------------------- + + +def test_aggregate_health_gates_when_uptime_below_threshold(): + """Live data showed Venice-routed cfgs at 53-68%; this guards that the + 90% gate excludes them.""" + venice_endpoints = [ + { + "status": 0, + "uptime_last_30m": 0.55, + "uptime_last_1d": 0.60, + "uptime_last_5m": 0.50, + }, + { + "status": 0, + "uptime_last_30m": 0.65, + "uptime_last_1d": 0.68, + "uptime_last_5m": 0.62, + }, + ] + gated, score = aggregate_health(venice_endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_passes_for_healthy_provider(): + healthy = [ + { + "status": 0, + "uptime_last_30m": 0.99, + "uptime_last_1d": 0.995, + "uptime_last_5m": 0.99, + }, + ] + gated, score = aggregate_health(healthy) + assert gated is False + assert score is not None + assert score >= _HEALTH_GATE_UPTIME_PCT + + +def test_aggregate_health_picks_best_endpoint_across_multiple(): + """Multi-endpoint aggregation should reward the best non-null uptime.""" + mixed = [ + {"status": 0, "uptime_last_30m": 0.55}, + {"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate + ] + gated, score = aggregate_health(mixed) + assert gated is False + assert score is not None + + +def test_aggregate_health_empty_endpoints_gated(): + gated, score = aggregate_health([]) + assert gated is True + assert score is None + + +def test_aggregate_health_no_status_zero_gated(): + """Even with high uptime, no OK status means the cfg is broken upstream.""" + endpoints = [ + {"status": 1, "uptime_last_30m": 0.99}, + {"status": 2, "uptime_last_30m": 0.98}, + ] + gated, score = 
aggregate_health(endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_all_uptime_null_gated(): + endpoints = [ + {"status": 0, "uptime_last_30m": None, "uptime_last_1d": None}, + ] + gated, score = aggregate_health(endpoints) + assert gated is True + assert score is None + + +def test_aggregate_health_pct_normalisation(): + """OpenRouter returns 0-1 fractions; some endpoints surface 0-100% + percentages. Both should reach the same gate decision.""" + fraction_form = [{"status": 0, "uptime_last_30m": 0.95}] + pct_form = [{"status": 0, "uptime_last_30m": 95.0}] + g1, s1 = aggregate_health(fraction_form) + g2, s2 = aggregate_health(pct_form) + assert g1 == g2 == False # noqa: E712 + assert s1 is not None and s2 is not None + assert abs(s1 - s2) < 0.5 From c229b4356ac7112576e98397b5eb304b3ca8eefa Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 23:38:21 +0530 Subject: [PATCH 52/68] feat(config): stamp Auto (Fastest) ranking metadata on YAML configs --- surfsense_backend/app/config/__init__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 11cbe24a7..b3eff571e 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -63,6 +63,27 @@ def load_global_llm_configs(): else: seen_slugs[slug] = cfg.get("id", 0) + # Stamp Auto (Fastest) ranking metadata. YAML configs are always + # Tier A — operator-curated, locked first when premium-eligible. + # The OpenRouter refresh tick later re-stamps health for any cfg + # whose provider == "OPENROUTER" via _enrich_health. 
+ try: + from app.services.quality_score import static_score_yaml + + for cfg in configs: + cfg["auto_pin_tier"] = "A" + static_q = static_score_yaml(cfg) + cfg["quality_score_static"] = static_q + cfg["quality_score"] = static_q + cfg["quality_score_health"] = None + # YAML cfgs whose provider is OPENROUTER are also subject + # to health gating against their own /endpoints data — a + # hand-picked dead OR model is still dead. _enrich_health + # re-stamps health_gated for them on the next refresh tick. + cfg["health_gated"] = False + except Exception as e: + print(f"Warning: Failed to score global LLM configs: {e}") + return configs except Exception as e: print(f"Warning: Failed to load global LLM configs: {e}") From 1eedcaa55178134ce9c7f45c11707a7406bdb291 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 23:38:40 +0530 Subject: [PATCH 53/68] feat(openrouter): blend per-model /endpoints health into quality score --- .../openrouter_integration_service.py | 231 ++++++++++++ .../services/test_or_health_enrichment.py | 331 ++++++++++++++++++ 2 files changed, 562 insertions(+) create mode 100644 surfsense_backend/tests/unit/services/test_or_health_enrichment.py diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 06b7becdc..9c3eaa5ea 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -14,13 +14,28 @@ import asyncio import hashlib import logging import threading +import time from typing import Any import httpx +from app.services.quality_score import ( + _HEALTH_BLEND_WEIGHT, + _HEALTH_ENRICH_CONCURRENCY, + _HEALTH_ENRICH_TOP_N_FREE, + _HEALTH_ENRICH_TOP_N_PREMIUM, + _HEALTH_FAIL_RATIO_FALLBACK, + _HEALTH_FETCH_TIMEOUT_SEC, + aggregate_health, + static_score_or, +) + logger = logging.getLogger(__name__) 
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" +OPENROUTER_ENDPOINTS_URL_TEMPLATE = ( + "https://openrouter.ai/api/v1/models/{model_id}/endpoints" +) # Sentinel value stored on each generated config so we can distinguish # dynamic OpenRouter entries from hand-written YAML entries during refresh. @@ -217,12 +232,15 @@ def _generate_configs( configs: list[dict] = [] taken: set[int] = set() + now_ts = int(time.time()) for model in text_models: model_id: str = model["id"] name: str = model.get("name", model_id) tier = _openrouter_tier(model) + static_q = static_score_or(model, now_ts=now_ts) + cfg: dict[str, Any] = { "id": _stable_config_id(model_id, id_offset, taken), "name": name, @@ -249,6 +267,15 @@ def _generate_configs( # there — it just drains the shared bucket faster. "router_pool_eligible": tier == "premium", _OPENROUTER_DYNAMIC_MARKER: True, + # Auto (Fastest) ranking metadata. ``quality_score`` is initialised + # to the static score and gets re-blended with health on the next + # ``_enrich_health`` pass (synchronous on refresh, deferred on cold + # start so startup latency is unchanged). + "auto_pin_tier": "B" if tier == "premium" else "C", + "quality_score_static": static_q, + "quality_score_health": None, + "quality_score": static_q, + "health_gated": False, } configs.append(cfg) @@ -267,6 +294,12 @@ class OpenRouterIntegrationService: self._configs_by_id: dict[int, dict] = {} self._initialized = False self._refresh_task: asyncio.Task | None = None + # Last-good per-model health snapshot. Survives across refresh + # cycles so a transient OpenRouter /endpoints outage doesn't drop + # every cfg back to static-only scoring. 
+ # Shape: {model_name: {"gated": bool, "score": float | None}} + self._health_cache: dict[str, dict[str, Any]] = {} + self._enrich_task: asyncio.Task | None = None @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -307,6 +340,20 @@ class OpenRouterIntegrationService: tier_counts["free"], tier_counts["premium"], ) + + # Schedule the first health-enrichment pass as a deferred task so + # cold-start latency is unchanged. Only valid when an event loop is + # already running (e.g. FastAPI lifespan); Celery worker init is + # fully sync so we silently skip — its first refresh tick (or the + # next refresh from the web process) will populate health data. + try: + loop = asyncio.get_running_loop() + self._enrich_task = loop.create_task( + self._enrich_health_safely(self._configs) + ) + except RuntimeError: + pass + return self._configs # ------------------------------------------------------------------ @@ -343,6 +390,13 @@ class OpenRouterIntegrationService: tier_counts["premium"], ) + # Re-blend health scores against the freshly fetched catalogue. Also + # re-stamps health for any YAML-curated cfg with provider==OPENROUTER + # so a hand-picked dead OR model is gated like a dynamic one. + await self._enrich_health_safely( + static_configs + new_configs, log_summary=True + ) + # Rebuild the LiteLLM router so freshly fetched configs flow through # (dynamic OR premium entries now opt into the pool, free ones stay # out; a refresh also needs to pick up any static-config edits and @@ -373,6 +427,183 @@ class OpenRouterIntegrationService: counts[tier] += 1 return counts + # ------------------------------------------------------------------ + # Auto (Fastest) health enrichment + # ------------------------------------------------------------------ + + async def _enrich_health_safely( + self, configs: list[dict], *, log_summary: bool = True + ) -> None: + """Wrapper around ``_enrich_health`` that swallows all errors. 
+ + Health enrichment is best-effort: any failure must leave cfgs in + their static-only state and never break refresh / startup. + """ + try: + await self._enrich_health(configs, log_summary=log_summary) + except Exception: + logger.exception("OpenRouter health enrichment failed") + + async def _enrich_health( + self, configs: list[dict], *, log_summary: bool = True + ) -> None: + """Fetch per-model ``/endpoints`` data for the top OR cfgs and blend + the resulting health score into ``cfg["quality_score"]``. + + Bounded fan-out: top-N per tier by ``quality_score_static`` only, + with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the + outbound HTTP. Misses fall back to a per-model last-good cache; if + the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep + the entire previous cycle's cache for this run. + """ + or_cfgs = [ + c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER" + ] + if not or_cfgs: + return + + premium_pool = sorted( + [ + c + for c in or_cfgs + if str(c.get("billing_tier", "")).lower() == "premium" + ], + key=lambda c: -int(c.get("quality_score_static") or 0), + )[:_HEALTH_ENRICH_TOP_N_PREMIUM] + free_pool = sorted( + [ + c + for c in or_cfgs + if str(c.get("billing_tier", "")).lower() == "free" + ], + key=lambda c: -int(c.get("quality_score_static") or 0), + )[:_HEALTH_ENRICH_TOP_N_FREE] + # De-duplicate while preserving order: a cfg shouldn't fall in both + # tiers, but defensive code is cheap here. 
+ seen_ids: set[int] = set() + selected: list[dict] = [] + for cfg in premium_pool + free_pool: + cid = int(cfg.get("id", 0)) + if cid in seen_ids: + continue + seen_ids.add(cid) + selected.append(cfg) + + if not selected: + return + + api_key = str(self._settings.get("api_key") or "") + semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY) + + async with httpx.AsyncClient( + timeout=_HEALTH_FETCH_TIMEOUT_SEC + ) as client: + results = await asyncio.gather( + *( + self._fetch_endpoints(client, semaphore, api_key, cfg) + for cfg in selected + ) + ) + + fail_count = sum(1 for _, _, err in results if err is not None) + fail_ratio = fail_count / len(results) if results else 0.0 + degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK + if degraded: + logger.warning( + "auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d " + "using_last_good_cache=true", + fail_ratio, + len(results), + ) + + # Per-cfg health update. + for cfg, endpoints, err in results: + model_name = str(cfg.get("model_name", "")) + if not degraded and err is None and endpoints is not None: + gated, h_score = aggregate_health(endpoints) + cfg["health_gated"] = bool(gated) + cfg["quality_score_health"] = h_score + self._health_cache[model_name] = { + "gated": bool(gated), + "score": h_score, + } + else: + cached = self._health_cache.get(model_name) + if cached is not None: + cfg["health_gated"] = bool(cached.get("gated", False)) + cfg["quality_score_health"] = cached.get("score") + # else: keep current values (initial defaults from + # _generate_configs / load_global_llm_configs). + + # Blend health into the final score for every OR cfg, including + # those outside the enriched top-N (they fall through to static). 
+ gated_count = 0 + by_provider: dict[str, int] = {} + for cfg in or_cfgs: + static_q = int(cfg.get("quality_score_static") or 0) + h = cfg.get("quality_score_health") + if h is not None and not cfg.get("health_gated"): + blended = ( + _HEALTH_BLEND_WEIGHT * float(h) + + (1 - _HEALTH_BLEND_WEIGHT) * static_q + ) + cfg["quality_score"] = round(blended) + else: + cfg["quality_score"] = static_q + + if cfg.get("health_gated"): + gated_count += 1 + model_id = str(cfg.get("model_name", "")) + provider_slug = ( + model_id.split("/", 1)[0] if "/" in model_id else "unknown" + ) + by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1 + + if log_summary: + logger.info( + "auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f " + "total_enriched=%d", + gated_count, + dict(sorted(by_provider.items(), key=lambda kv: -kv[1])), + fail_ratio, + len(selected), + ) + + @staticmethod + async def _fetch_endpoints( + client: httpx.AsyncClient, + semaphore: asyncio.Semaphore, + api_key: str, + cfg: dict, + ) -> tuple[dict, list[dict] | None, Exception | None]: + """Fetch ``/api/v1/models/{id}/endpoints`` for one cfg. + + Returns ``(cfg, endpoints, err)`` so the caller can keep batched + results aligned with their cfgs without raising. 
+ """ + model_id = str(cfg.get("model_name", "")) + if not model_id: + return cfg, None, ValueError("missing model_name") + + url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + + async with semaphore: + try: + resp = await client.get(url, headers=headers) + resp.raise_for_status() + data = resp.json() + except Exception as exc: + return cfg, None, exc + + payload = data.get("data") if isinstance(data, dict) else None + if not isinstance(payload, dict): + return cfg, None, ValueError("malformed endpoints payload") + endpoints = payload.get("endpoints") + if not isinstance(endpoints, list): + return cfg, [], None + return cfg, endpoints, None + async def _refresh_loop(self, interval_hours: float) -> None: interval_sec = interval_hours * 3600 while True: diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py new file mode 100644 index 000000000..1c74aa928 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -0,0 +1,331 @@ +"""Unit tests for the OpenRouter ``_enrich_health`` background task.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from app.services.openrouter_integration_service import ( + OpenRouterIntegrationService, +) +from app.services.quality_score import ( + _HEALTH_FAIL_RATIO_FALLBACK, +) + +pytestmark = pytest.mark.unit + + +def _or_cfg( + *, + cid: int, + model_name: str, + tier: str = "premium", + static_score: int = 50, +) -> dict: + return { + "id": cid, + "provider": "OPENROUTER", + "model_name": model_name, + "billing_tier": tier, + "auto_pin_tier": "B" if tier == "premium" else "C", + "quality_score_static": static_score, + "quality_score_health": None, + "quality_score": static_score, + "health_gated": False, + } + + +class _StubResponse: + def __init__(self, *, payload: dict, status_code: 
int = 200): + self._payload = payload + self.status_code = status_code + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._payload + + +class _StubAsyncClient: + """Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``.""" + + def __init__(self, responder): + self._responder = responder + self.requests: list[str] = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, url: str, headers: dict | None = None) -> _StubResponse: + self.requests.append(url) + return self._responder(url) + + +def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient: + """Replace ``httpx.AsyncClient`` for the duration of the test.""" + client = _StubAsyncClient(responder) + monkeypatch.setattr( + "app.services.openrouter_integration_service.httpx.AsyncClient", + lambda *_args, **_kwargs: client, + ) + return client + + +def _healthy_payload() -> dict: + return { + "data": { + "endpoints": [ + { + "status": 0, + "uptime_last_30m": 0.99, + "uptime_last_1d": 0.995, + "uptime_last_5m": 0.99, + } + ] + } + } + + +def _unhealthy_payload() -> dict: + return { + "data": { + "endpoints": [ + { + "status": 0, + "uptime_last_30m": 0.55, + "uptime_last_1d": 0.62, + "uptime_last_5m": 0.50, + } + ] + } + } + + +# --------------------------------------------------------------------------- +# Bounded fan-out + happy path +# --------------------------------------------------------------------------- + + +async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch): + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + _or_cfg(cid=-2, model_name="venice/dead-model", static_score=60), + ] + + def responder(url: str) -> _StubResponse: + if "anthropic" in url: + return _StubResponse(payload=_healthy_payload()) + return 
_StubResponse(payload=_unhealthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {"api_key": ""} + await service._enrich_health(cfgs) + + healthy = next(c for c in cfgs if c["id"] == -1) + gated = next(c for c in cfgs if c["id"] == -2) + + assert healthy["health_gated"] is False + assert healthy["quality_score_health"] is not None + assert healthy["quality_score"] >= healthy["quality_score_static"] + + assert gated["health_gated"] is True + assert gated["quality_score"] == gated["quality_score_static"] + + +async def test_enrich_health_only_touches_or_provider(monkeypatch): + """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" + yaml_cfg = { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score_static": 80, + "quality_score": 80, + "health_gated": False, + } + or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku") + + requests: list[str] = [] + + def responder(url: str) -> _StubResponse: + requests.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health([yaml_cfg, or_cfg]) + + assert all("anthropic/claude-haiku" in r for r in requests) + # YAML cfg is untouched. 
+ assert yaml_cfg["quality_score"] == 80 + assert yaml_cfg["health_gated"] is False + + +# --------------------------------------------------------------------------- +# Failure ratio fallback +# --------------------------------------------------------------------------- + + +async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high( + monkeypatch, +): + """If >= 25% of fetches fail, keep last-good cache instead of writing + partial data.""" + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + _or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80), + _or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65), + _or_cfg(cid=-4, model_name="venice/something", static_score=50), + ] + + service = OpenRouterIntegrationService() + service._settings = {} + # Pre-seed last-good cache with a known-healthy snapshot. + service._health_cache = { + "anthropic/claude-haiku": {"gated": False, "score": 95.0}, + } + + def all_fail(_url: str) -> _StubResponse: + return _StubResponse(payload={}, status_code=500) + + _patch_async_client(monkeypatch, all_fail) + await service._enrich_health(cfgs) + + # Above threshold ⇒ degraded; last-good cache wins for the cached cfg. + cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku") + assert cached_hit["quality_score_health"] == 95.0 + assert cached_hit["health_gated"] is False + # Confirm the threshold constant we're testing against is real. 
+ assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0 + + +async def test_enrich_health_keeps_static_only_with_no_cache_and_failures( + monkeypatch, +): + """If a fetch fails and there's no last-good cache, the cfg keeps its + static-only ``quality_score`` and is *not* gated by default.""" + cfgs = [ + _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70), + ] + + def fail(_url: str) -> _StubResponse: + return _StubResponse(payload={}, status_code=500) + + _patch_async_client(monkeypatch, fail) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health(cfgs) + + cfg = cfgs[0] + assert cfg["health_gated"] is False + assert cfg["quality_score"] == cfg["quality_score_static"] + assert cfg["quality_score_health"] is None + + +# --------------------------------------------------------------------------- +# Last-good cache: success populates, next failure reuses +# --------------------------------------------------------------------------- + + +async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure( + monkeypatch, +): + cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70) + + service = OpenRouterIntegrationService() + service._settings = {} + + def healthy(_url: str) -> _StubResponse: + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, healthy) + await service._enrich_health([cfg]) + + assert "anthropic/claude-haiku" in service._health_cache + cached_score = service._health_cache["anthropic/claude-haiku"]["score"] + assert cached_score is not None + + # Next cycle: enough other healthy cfgs so failure ratio stays below + # the 25% threshold even when this one fails individually. 
+ other_cfgs = [ + _or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60) + for i in range(10) + ] + cfg["quality_score_health"] = None + cfg["quality_score"] = cfg["quality_score_static"] + + def mixed(url: str) -> _StubResponse: + if "anthropic" in url: + return _StubResponse(payload={}, status_code=500) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, mixed) + await service._enrich_health([cfg, *other_cfgs]) + + assert cfg["quality_score_health"] == cached_score + assert cfg["health_gated"] is False + + +# --------------------------------------------------------------------------- +# Bounded fan-out: respects top-N caps +# --------------------------------------------------------------------------- + + +async def test_enrich_health_bounds_premium_fanout(monkeypatch): + """Top-N premium cap is honoured even when many cfgs are present.""" + from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM + + cfgs = [ + _or_cfg( + cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i + ) + for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20) + ] + + seen: list[str] = [] + + def responder(url: str) -> _StubResponse: + seen.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + service._settings = {} + await service._enrich_health(cfgs) + + assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM + + +async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): + """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" + yaml_cfg: dict[str, Any] = { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "billing_tier": "premium", + } + requests: list[str] = [] + + def responder(url: str) -> _StubResponse: + requests.append(url) + return _StubResponse(payload=_healthy_payload()) + + _patch_async_client(monkeypatch, responder) + + service = OpenRouterIntegrationService() + 
service._settings = {} + await service._enrich_health([yaml_cfg]) + assert requests == [] From 4bef75d2986b0c46d79b3104dfa4f71dbba5c7fa Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 1 May 2026 23:38:53 +0530 Subject: [PATCH 54/68] feat(auto_pin): quality-aware tier-locked selection with health gate --- .../app/services/auto_model_pin_service.py | 56 ++- .../services/test_auto_model_pin_service.py | 336 ++++++++++++++++++ 2 files changed, 387 insertions(+), 5 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 1a2061492..94aa6b734 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -24,6 +24,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.config import config from app.db import NewChatThread +from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService logger = logging.getLogger(__name__) @@ -49,8 +50,16 @@ def _is_usable_global_config(cfg: dict) -> bool: def _global_candidates() -> list[dict]: + """Return Auto-eligible global cfgs. + + Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime + below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers + can't be picked as the thread's pin. 
+ """ candidates = [ - cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) + cfg + for cfg in config.GLOBAL_LLM_CONFIGS + if _is_usable_global_config(cfg) and not cfg.get("health_gated") ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) @@ -59,10 +68,26 @@ def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() -def _deterministic_pick(candidates: list[dict], thread_id: int) -> dict: +def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: + """Pick a config with quality-first ranking + deterministic spread. + + Tier policy is lock-first: prefer Tier A (operator-curated YAML) + cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no + Tier A cfg is eligible after upstream filters. Within the locked + pool, sort by ``quality_score`` and pick from the top-K via + ``SHA256(thread_id)`` so different new threads spread across the + best models without ever picking a low-ranked one. + + Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for + structured logging in the caller. 
+ """ + tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")] + pool = tier_a if tier_a else eligible + pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0)) + top_k = pool[:_QUALITY_TOP_K] digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() - idx = int.from_bytes(digest[:8], "big") % len(candidates) - return candidates[idx] + idx = int.from_bytes(digest[:8], "big") % len(top_k) + return top_k[idx], len(top_k) def _to_uuid(user_id: str | UUID | None) -> UUID | None: @@ -150,6 +175,15 @@ async def resolve_or_get_pinned_llm_config_id( pinned_id, _tier_of(pinned_cfg), ) + logger.info( + "auto_pin_resolved thread_id=%s config_id=%s tier=%s " + "auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True", + thread_id, + pinned_id, + _tier_of(pinned_cfg), + pinned_cfg.get("auto_pin_tier", "?"), + int(pinned_cfg.get("quality_score") or 0), + ) return AutoPinResolution( resolved_llm_config_id=int(pinned_id), resolved_tier=_tier_of(pinned_cfg), @@ -176,7 +210,7 @@ async def resolve_or_get_pinned_llm_config_id( "Auto mode could not find an eligible LLM config for this user and quota state" ) - selected_cfg = _deterministic_pick(eligible, thread_id) + selected_cfg, top_k_size = _select_pin(eligible, thread_id) selected_id = int(selected_cfg["id"]) selected_tier = _tier_of(selected_cfg) @@ -211,6 +245,18 @@ async def resolve_or_get_pinned_llm_config_id( selected_tier, premium_eligible, ) + + logger.info( + "auto_pin_resolved thread_id=%s config_id=%s tier=%s " + "auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False", + thread_id, + selected_id, + selected_tier, + selected_cfg.get("auto_pin_tier", "?"), + int(selected_cfg.get("quality_score") or 0), + top_k_size, + ) + return AutoPinResolution( resolved_llm_config_id=selected_id, resolved_tier=selected_tier, diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py 
b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 2094ea6dd..be9d7f721 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -365,3 +365,339 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): assert result.resolved_llm_config_id == -2 assert session.thread.pinned_llm_config_id == -2 assert session.commit_count == 1 + + +# --------------------------------------------------------------------------- +# Quality-aware pin selection (Auto Fastest upgrade) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_health_gated_config_is_excluded_from_selection(monkeypatch): + """A cfg flagged ``health_gated`` must never be picked even if it has + the highest score among eligible cfgs.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "venice/dead-model", + "api_key": "k1", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 95, + "health_gated": True, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-flash", + "api_key": "k1", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 60, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): + 
"""Premium-eligible users with Tier A available should never spill to + Tier B even if a B cfg ranks higher by ``quality_score``.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k-yaml", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 70, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "openai/gpt-5", + "api_key": "k-or", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 95, + "health_gated": False, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.resolved_tier == "premium" + + +@pytest.mark.asyncio +async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch): + """Free-only user with no Tier A free cfg should pick from Tier C.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k-yaml", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 100, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-flash:free", + "api_key": "k-or", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 60, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + 
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_top_k_picks_only_high_score_models(monkeypatch): + """Different thread IDs should spread across top-K, never pick the + obvious low-quality cfg even when it sits in the candidate list.""" + from app.config import config + + high_score_cfgs = [ + { + "id": -i, + "provider": "AZURE_OPENAI", + "model_name": f"gpt-x-{i}", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 90, + "health_gated": False, + } + for i in range(1, 6) # 5 high-quality Tier A cfgs + ] + low_score_trap = { + "id": -99, + "provider": "AZURE_OPENAI", + "model_name": "tiny-legacy", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 10, + "health_gated": False, + } + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + high_score_cfgs + [low_score_trap], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + high_score_ids = {c["id"] for c in high_score_cfgs} + seen = set() + for thread_id in range(1, 50): + session = _FakeSession(_thread()) + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=thread_id, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + seen.add(result.resolved_llm_config_id) + assert result.resolved_llm_config_id != -99, ( + "low-score trap cfg should never be picked" + ) + assert result.resolved_llm_config_id in high_score_ids + + # Spread across at least a couple of top-K cfgs. 
+ assert len(seen) > 1 + + +@pytest.mark.asyncio +async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): + """An *already* pinned cfg that later flips to ``health_gated`` should + still not be reused — gated cfgs are filtered out of the candidate + pool, which forces a repair to a healthy cfg. + + This guards the no-silent-tier-switch invariant: we don't keep using + a known-broken model just because the thread happened to be pinned + to it before the gate fired.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "venice/dead-model", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "B", + "quality_score": 50, + "health_gated": True, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): + """Existing pin reuse must short-circuit the new tier/score logic.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + 
"quality_score": 50, # lower than -2 + "health_gated": False, + }, + { + "id": -2, + "provider": "AZURE_OPENAI", + "model_name": "gpt-5-pro", + "api_key": "k", + "billing_tier": "premium", + "auto_pin_tier": "A", + "quality_score": 99, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True + assert session.commit_count == 0 From f65b3be1ce72e311dffd03de2d60e0fe73f2aef8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 00:57:52 +0530 Subject: [PATCH 55/68] feat(auto_model_pin): implement runtime cooldown for error handling and enhance candidate selection --- .../app/services/auto_model_pin_service.py | 64 ++- .../app/tasks/chat/stream_new_chat.py | 380 ++++++++++++++---- .../services/test_auto_model_pin_service.py | 112 ++++++ .../unit/test_stream_new_chat_contract.py | 16 + 4 files changed, 486 insertions(+), 86 deletions(-) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 94aa6b734..05a54b257 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -16,6 +16,8 @@ from __future__ import annotations import hashlib import logging +import threading +import time from dataclasses import dataclass from uuid import UUID @@ -31,6 +33,13 @@ logger = logging.getLogger(__name__) AUTO_FASTEST_ID = 0 AUTO_FASTEST_MODE = "auto_fastest" 
+_RUNTIME_COOLDOWN_SECONDS = 600 + +# In-memory runtime cooldown map for configs that recently hard-failed at +# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps +# the same unhealthy config from being reselected immediately during repair. +_runtime_cooldown_until: dict[int, float] = {} +_runtime_cooldown_lock = threading.Lock() @dataclass @@ -49,17 +58,68 @@ def _is_usable_global_config(cfg: dict) -> bool: ) +def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: + now = time.time() if now_ts is None else now_ts + stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] + for cid in stale: + _runtime_cooldown_until.pop(cid, None) + + +def _is_runtime_cooled_down(config_id: int) -> bool: + with _runtime_cooldown_lock: + _prune_runtime_cooldowns() + return config_id in _runtime_cooldown_until + + +def mark_runtime_cooldown( + config_id: int, + *, + reason: str = "rate_limited", + cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS, +) -> None: + """Temporarily suppress a config from Auto selection. + + Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned + config that is currently unhealthy does not get immediately reused on the + same thread during repair. + """ + if cooldown_seconds <= 0: + cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS + until = time.time() + int(cooldown_seconds) + with _runtime_cooldown_lock: + _runtime_cooldown_until[int(config_id)] = until + _prune_runtime_cooldowns() + logger.info( + "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s", + config_id, + reason, + cooldown_seconds, + ) + + +def clear_runtime_cooldown(config_id: int | None = None) -> None: + """Test/ops helper to clear runtime cooldown entries.""" + with _runtime_cooldown_lock: + if config_id is None: + _runtime_cooldown_until.clear() + return + _runtime_cooldown_until.pop(int(config_id), None) + + def _global_candidates() -> list[dict]: """Return Auto-eligible global cfgs. 
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers - can't be picked as the thread's pin. + can't be picked as the thread's pin. Also excludes configs currently + in runtime cooldown (e.g. temporary 429 bursts). """ candidates = [ cfg for cfg in config.GLOBAL_LLM_CONFIGS - if _is_usable_global_config(cfg) and not cfg.get("health_gated") + if _is_usable_global_config(cfg) + and not cfg.get("health_gated") + and not _is_runtime_cooled_down(int(cfg.get("id", 0))) ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 5abcb63eb..8f596927d 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -64,7 +64,10 @@ from app.db import ( shielded_async_session, ) from app.prompts import TITLE_GENERATION_PROMPT -from app.services.auto_model_pin_service import resolve_or_get_pinned_llm_config_id +from app.services.auto_model_pin_service import ( + mark_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) from app.services.chat_session_state_service import ( clear_ai_responding, set_ai_responding, @@ -414,6 +417,60 @@ def _parse_error_payload(message: str) -> dict[str, Any] | None: return None +def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("code")) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def _is_provider_rate_limited(exc: BaseException) -> bool: + """Best-effort detection for provider-side runtime throttling. 
+ + Covers LiteLLM/OpenRouter shapes like: + - class name contains ``RateLimit`` + - nested payload ``{"error": {"code": 429}}`` + - nested payload ``{"error": {"type": "rate_limit_error"}}`` + """ + raw = str(exc) + lowered = raw.lower() + if "ratelimit" in type(exc).__name__.lower(): + return True + parsed = _parse_error_payload(raw) + provider_code = _extract_provider_error_code(parsed) + if provider_code == 429: + return True + + provider_error_type = "" + if parsed: + top_type = parsed.get("type") + if isinstance(top_type, str): + provider_error_type = top_type.lower() + nested = parsed.get("error") + if isinstance(nested, dict): + nested_type = nested.get("type") + if isinstance(nested_type, str): + provider_error_type = nested_type.lower() + if provider_error_type == "rate_limit_error": + return True + + return ( + "rate limited" in lowered + or "rate-limited" in lowered + or "temporarily rate-limited upstream" in lowered + ) + + def _classify_stream_exception( exc: Exception, *, @@ -449,19 +506,7 @@ def _classify_stream_exception( None, ) - parsed = _parse_error_payload(raw) - provider_error_type = "" - if parsed: - top_type = parsed.get("type") - if isinstance(top_type, str): - provider_error_type = top_type.lower() - nested = parsed.get("error") - if isinstance(nested, dict): - nested_type = nested.get("type") - if isinstance(nested_type, str): - provider_error_type = nested_type.lower() - - if provider_error_type == "rate_limit_error": + if _is_provider_rate_limited(exc): return ( "rate_limited", "RATE_LIMITED", @@ -2671,54 +2716,144 @@ async def stream_new_chat( _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=input_state, - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking", - initial_step_id=initial_step_id, - initial_step_title=initial_title, - initial_step_items=initial_items, - 
fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_new_chat] First agent event in %.3fs (time since stream start), " - "%.3fs (total since request start) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, - ) - _first_event_logged = True - yield sse - - # Inject title update mid-stream as soon as the background task finishes - if title_task is not None and title_task.done() and not title_emitted: - generated_title, title_usage = title_task.result() - if title_usage: - accumulator.add(**title_usage) - if generated_title: - async with shielded_async_session() as title_session: - title_thread_result = await title_session.execute( - select(NewChatThread).filter(NewChatThread.id == chat_id) + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=input_state, + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking", + initial_step_id=initial_step_id, + initial_step_title=initial_title, + initial_step_items=initial_items, + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_new_chat] First agent event in %.3fs (time since stream start), " + "%.3fs (total since request start) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, ) - title_thread = title_thread_result.scalars().first() - if title_thread: - title_thread.title = generated_title - 
await title_session.commit() - yield streaming_service.format_thread_title_update( - chat_id, generated_title + _first_event_logged = True + yield sse + + # Inject title update mid-stream as soon as the background + # task finishes. + if title_task is not None and title_task.done() and not title_emitted: + generated_title, title_usage = title_task.result() + if title_usage: + accumulator.add(**title_usage) + if generated_title: + async with shielded_async_session() as title_session: + title_thread_result = await title_session.execute( + select(NewChatThread).filter( + NewChatThread.id == chat_id + ) + ) + title_thread = title_thread_result.scalars().first() + if title_thread: + title_thread.title = generated_title + await title_session.commit() + yield streaming_service.format_thread_title_update( + chat_id, generated_title + ) + title_emitted = True + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) + ) + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, ) - title_emitted = True + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + raise stream_exc + + # Title generation uses the initial llm object. After a runtime + # repin we keep the stream focused on response recovery and skip + # title generation for this turn. 
+ if title_task is not None and not title_task.done(): + title_task.cancel() + title_task = None + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + disabled_tools=disabled_tools, + mentioned_document_ids=mentioned_document_ids, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_new_chat] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." 
+ ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)", @@ -3265,31 +3400,108 @@ async def stream_resume_chat( _t_stream_start = time.perf_counter() _first_event_logged = False - async for sse in _stream_agent_events( - agent=agent, - config=config, - input_data=Command(resume={"decisions": decisions}), - streaming_service=streaming_service, - result=stream_result, - step_prefix="thinking-resume", - fallback_commit_search_space_id=search_space_id, - fallback_commit_created_by_id=user_id, - fallback_commit_filesystem_mode=( - filesystem_selection.mode - if filesystem_selection - else FilesystemMode.CLOUD - ), - fallback_commit_thread_id=chat_id, - ): - if not _first_event_logged: - _perf_log.info( - "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", - time.perf_counter() - _t_stream_start, - time.perf_counter() - _t_total, - chat_id, + runtime_rate_limit_recovered = False + while True: + try: + async for sse in _stream_agent_events( + agent=agent, + config=config, + input_data=Command(resume={"decisions": decisions}), + streaming_service=streaming_service, + result=stream_result, + step_prefix="thinking-resume", + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), + fallback_commit_thread_id=chat_id, + ): + if not _first_event_logged: + _perf_log.info( + "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)", + time.perf_counter() - _t_stream_start, + time.perf_counter() - _t_total, + chat_id, + ) + _first_event_logged = True + yield sse + break + except Exception as stream_exc: + can_runtime_recover = ( + not runtime_rate_limit_recovered + and requested_llm_config_id == 0 + 
and llm_config_id < 0 + and not _first_event_logged + and _is_provider_rate_limited(stream_exc) ) - _first_event_logged = True - yield sse + if not can_runtime_recover: + raise + + runtime_rate_limit_recovered = True + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, + reason="provider_rate_limited", + ) + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + ) + ).resolved_llm_config_id + + llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + if llm_load_error: + raise stream_exc + + _t0 = time.perf_counter() + agent = await create_surfsense_deep_agent( + llm=llm, + search_space_id=search_space_id, + db_session=session, + connector_service=connector_service, + checkpointer=checkpointer, + user_id=user_id, + thread_id=chat_id, + agent_config=agent_config, + firecrawl_api_key=firecrawl_api_key, + thread_visibility=visibility, + filesystem_selection=filesystem_selection, + ) + _perf_log.info( + "[stream_resume] Runtime rate-limit recovery repinned " + "config_id=%s -> %s and rebuilt agent in %.3fs", + previous_config_id, + llm_config_id, + time.perf_counter() - _t0, + ) + _log_chat_stream_error( + flow="resume", + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model hit runtime rate limit; switched to " + "another eligible model and retried." 
+ ), + extra={ + "auto_runtime_recover": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + continue _perf_log.info( "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)", time.perf_counter() - _t_stream_start, diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index be9d7f721..8261fdfe0 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -6,12 +6,21 @@ from types import SimpleNamespace import pytest from app.services.auto_model_pin_service import ( + clear_runtime_cooldown, + mark_runtime_cooldown, resolve_or_get_pinned_llm_config_id, ) pytestmark = pytest.mark.unit +@pytest.fixture(autouse=True) +def _clear_runtime_cooldown_map(): + clear_runtime_cooldown() + yield + clear_runtime_cooldown() + + @dataclass class _FakeQuotaResult: allowed: bool @@ -701,3 +710,106 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): assert result.resolved_llm_config_id == -1 assert result.from_existing_pin is True assert session.commit_count == 0 + + +@pytest.mark.asyncio +async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): + """A runtime-cooled config should be excluded from candidate reuse. + + This enables one-shot recovery from transient provider 429 bursts: we can + mark the pinned cfg as cooled down and force a repair to another eligible + cfg on the next resolution. 
+ """ + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + + +@pytest.mark.asyncio +async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + ], + ) + + async def _must_not_call(*_args, **_kwargs): + raise AssertionError("premium_get_usage should not run on healthy pin reuse") + + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _must_not_call, + ) + + mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600) + 
clear_runtime_cooldown(-1) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + assert result.from_existing_pin is True diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 5e6ad6abd..ed69ca348 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -159,6 +159,22 @@ def test_stream_exception_classifies_rate_limited(): assert extra is None +def test_stream_exception_classifies_openrouter_429_payload(): + exc = Exception( + 'OpenrouterException - {"error":{"message":"Provider returned error","code":429,' + '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}' + ) + kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( + exc, flow_label="chat" + ) + assert kind == "rate_limited" + assert code == "RATE_LIMITED" + assert severity == "warn" + assert is_expected is True + assert "temporarily rate-limited" in user_message + assert extra is None + + def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( From 25ccc959cf59018c3937be22b23ffc7a35fb7391 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 01:35:30 +0530 Subject: [PATCH 56/68] feat(busy_mutex): enhance thread lock management to prevent stale middleware interference --- .../agents/new_chat/middleware/busy_mutex.py | 37 ++++++++++--- .../app/services/auto_model_pin_service.py | 10 +++- .../app/tasks/chat/stream_new_chat.py | 9 ++++ .../unit/agents/new_chat/test_busy_mutex.py | 34 ++++++++++++ 
.../services/test_auto_model_pin_service.py | 53 +++++++++++++++++++ 5 files changed, 134 insertions(+), 9 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py index d61a56533..06a27bc96 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -61,6 +61,9 @@ class _ThreadLockManager: self._cancel_events: dict[str, asyncio.Event] = {} self._cancel_requested_at_ms: dict[str, int] = {} self._cancel_attempt_count: dict[str, int] = {} + # Monotonic per-thread epoch used to prevent stale middleware + # teardown from releasing a newer turn's lock. + self._turn_epoch: dict[str, int] = {} def lock_for(self, thread_id: str) -> asyncio.Lock: lock = self._locks.get(thread_id) @@ -107,6 +110,14 @@ class _ThreadLockManager: self._cancel_requested_at_ms.pop(thread_id, None) self._cancel_attempt_count.pop(thread_id, None) + def bump_turn_epoch(self, thread_id: str) -> int: + epoch = self._turn_epoch.get(thread_id, 0) + 1 + self._turn_epoch[thread_id] = epoch + return epoch + + def current_turn_epoch(self, thread_id: str) -> int: + return self._turn_epoch.get(thread_id, 0) + def end_turn(self, thread_id: str) -> None: """Best-effort terminal cleanup for a thread turn. @@ -114,6 +125,10 @@ class _ThreadLockManager: finally-blocks where middleware teardown might be skipped due to abort or disconnect edge-cases. """ + # Invalidate any in-flight middleware holder first. This guarantees a + # stale ``aafter_agent`` from an older attempt cannot unlock a newer + # retry that already acquired the lock for the same thread. 
+ self.bump_turn_epoch(thread_id) lock = self._locks.get(thread_id) if lock is not None and lock.locked(): lock.release() @@ -178,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo super().__init__() self._require_thread_id = require_thread_id self.tools = [] - # Per-call locks owned by this middleware. We track them as - # an instance attribute so ``aafter_agent`` knows which lock - # to release. - self._held_locks: dict[str, asyncio.Lock] = {} + # Per-call lock ownership tracked as (lock, epoch). ``aafter_agent`` + # only releases when its epoch still matches the manager's current + # epoch for the thread, preventing stale unlock races. + self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {} @staticmethod def _thread_id(runtime: Runtime[ContextT]) -> str | None: @@ -232,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo if lock.locked(): raise BusyError(request_id=thread_id) await lock.acquire() - self._held_locks[thread_id] = lock + epoch = manager.bump_turn_epoch(thread_id) + self._held_locks[thread_id] = (lock, epoch) # Reset the cancel event so this turn starts fresh reset_cancel(thread_id) return None @@ -246,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo thread_id = self._thread_id(runtime) if thread_id is None: return None - lock = self._held_locks.pop(thread_id, None) - if lock is not None and lock.locked(): + held = self._held_locks.pop(thread_id, None) + if held is None: + return None + lock, held_epoch = held + if held_epoch != manager.current_turn_epoch(thread_id): + # Stale teardown from an older attempt (e.g. runtime-recovery path + # already advanced epoch). Do not touch current lock/cancel state. + return None + if lock.locked(): lock.release() # Always clear cancel event between turns so a stale signal # doesn't leak into the next request. 
diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 05a54b257..f6a223866 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -179,6 +179,7 @@ async def resolve_or_get_pinned_llm_config_id( user_id: str | UUID | None, selected_llm_config_id: int, force_repin_free: bool = False, + exclude_config_ids: set[int] | None = None, ) -> AutoPinResolution: """Resolve Auto (Fastest) to one concrete config id and persist the pin. @@ -214,9 +215,14 @@ async def resolve_or_get_pinned_llm_config_id( from_existing_pin=False, ) - candidates = _global_candidates() + excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} + candidates = [ + c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids + ] if not candidates: - raise ValueError("No usable global LLM configs are available for Auto mode") + raise ValueError( + "No usable global LLM configs are available for Auto mode" + ) candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 8f596927d..dbfd5e2ea 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -2784,6 +2784,10 @@ async def stream_new_chat( runtime_rate_limit_recovered = True previous_config_id = llm_config_id + # The failed attempt may still hold the per-thread busy mutex + # (middleware teardown can lag behind raised provider errors). + # Force release before we retry within the same request. 
+ end_turn(str(chat_id)) mark_runtime_cooldown( previous_config_id, reason="provider_rate_limited", @@ -2796,6 +2800,7 @@ async def stream_new_chat( search_space_id=search_space_id, user_id=user_id, selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, ) ).resolved_llm_config_id @@ -3442,6 +3447,9 @@ async def stream_resume_chat( runtime_rate_limit_recovered = True previous_config_id = llm_config_id + # Ensure the same-request recovery retry does not trip the + # BusyMutex lock retained by the failed attempt. + end_turn(str(chat_id)) mark_runtime_cooldown( previous_config_id, reason="provider_rate_limited", @@ -3453,6 +3461,7 @@ async def stream_resume_chat( search_space_id=search_space_id, user_id=user_id, selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, ) ).resolved_llm_config_id diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py index c923dc499..f0161f605 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -118,3 +118,37 @@ async def test_end_turn_force_clears_lock_and_cancel_state() -> None: assert not manager.lock_for(thread_id).locked() assert not get_cancel_event(thread_id).is_set() assert is_cancel_requested(thread_id) is False + + +@pytest.mark.asyncio +async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None: + """A stale aafter call from attempt A must not unlock attempt B. + + Repro flow: + 1) attempt A acquires thread lock + 2) forced end_turn clears A so retry can proceed + 3) attempt B acquires same thread lock + 4) stale attempt-A aafter runs late + + Expected: B lock remains held. 
+ """ + thread_id = "stale-aafter-lock" + runtime = _Runtime(thread_id) + attempt_a = BusyMutexMiddleware() + attempt_b = BusyMutexMiddleware() + + await attempt_a.abefore_agent({}, runtime) + lock = manager.lock_for(thread_id) + assert lock.locked() + + end_turn(thread_id) + assert not lock.locked() + + await attempt_b.abefore_agent({}, runtime) + assert lock.locked() + + # Stale cleanup from attempt A must not release attempt B's lock. + await attempt_a.aafter_agent({}, runtime) + assert lock.locked() + + await attempt_b.aafter_agent({}, runtime) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 8261fdfe0..8696a8829 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -813,3 +813,56 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): ) assert result.resolved_llm_config_id == -1 assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch): + """Runtime retry should never repin the just-failed config.""" + from app.config import config + + session = _FakeSession(_thread(pinned_llm_config_id=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [ + { + "id": -1, + "provider": "OPENROUTER", + "model_name": "google/gemma-4-26b-a4b-it:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 90, + "health_gated": False, + }, + { + "id": -2, + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash:free", + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "C", + "quality_score": 80, + "health_gated": False, + }, + ], + ) + + async def _blocked(*_args, **_kwargs): + return _FakeQuotaResult(allowed=False) + + monkeypatch.setattr( + 
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _blocked, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id="00000000-0000-0000-0000-000000000001", + selected_llm_config_id=0, + exclude_config_ids={-1}, + ) + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False From 14686cdf829e62b4a5b62f088faf462948aaa416 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:07:16 +0530 Subject: [PATCH 57/68] feat(auto_pin): add short-TTL healthy-status cache for preflight reuse --- .../app/services/auto_model_pin_service.py | 57 +++++++++++++++++++ .../services/test_auto_model_pin_service.py | 53 +++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index f6a223866..b2acd6f56 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -34,6 +34,7 @@ logger = logging.getLogger(__name__) AUTO_FASTEST_ID = 0 AUTO_FASTEST_MODE = "auto_fastest" _RUNTIME_COOLDOWN_SECONDS = 600 +_HEALTHY_TTL_SECONDS = 45 # In-memory runtime cooldown map for configs that recently hard-failed at # provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps @@ -41,6 +42,13 @@ _RUNTIME_COOLDOWN_SECONDS = 600 _runtime_cooldown_until: dict[int, float] = {} _runtime_cooldown_lock = threading.Lock() +# Short-TTL "recently healthy" cache for configs that just passed a runtime +# preflight ping. Lets back-to-back turns on the same model skip the probe +# without eroding correctness — entries auto-expire and are wiped any time +# the same config is cooled down or the OR catalogue is refreshed. 
+_healthy_until: dict[int, float] = {} +_healthy_lock = threading.Lock() + @dataclass class AutoPinResolution: @@ -89,6 +97,9 @@ def mark_runtime_cooldown( with _runtime_cooldown_lock: _runtime_cooldown_until[int(config_id)] = until _prune_runtime_cooldowns() + # A cooled cfg can never be "recently healthy"; drop any stale credit so + # the next turn that resolves to it (after cooldown) re-runs preflight. + clear_healthy(int(config_id)) logger.info( "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s", config_id, @@ -106,6 +117,52 @@ def clear_runtime_cooldown(config_id: int | None = None) -> None: _runtime_cooldown_until.pop(int(config_id), None) +def _prune_healthy(now_ts: float | None = None) -> None: + now = time.time() if now_ts is None else now_ts + stale = [cid for cid, until in _healthy_until.items() if until <= now] + for cid in stale: + _healthy_until.pop(cid, None) + + +def is_recently_healthy(config_id: int) -> bool: + """Return True if ``config_id`` passed preflight within the TTL window.""" + with _healthy_lock: + _prune_healthy() + return int(config_id) in _healthy_until + + +def mark_healthy( + config_id: int, + *, + ttl_seconds: int = _HEALTHY_TTL_SECONDS, +) -> None: + """Record that ``config_id`` just passed a preflight probe. + + Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The + healthy state is intentionally process-local — it's a latency hint, not a + correctness primitive — so multi-worker drift is acceptable. + """ + if ttl_seconds <= 0: + ttl_seconds = _HEALTHY_TTL_SECONDS + until = time.time() + int(ttl_seconds) + with _healthy_lock: + _healthy_until[int(config_id)] = until + _prune_healthy() + + +def clear_healthy(config_id: int | None = None) -> None: + """Drop one (or all) healthy-cache entries. + + Called from runtime cooldown and OR catalogue refresh so a freshly cooled + or replaced config never carries stale "healthy" credit. 
+ """ + with _healthy_lock: + if config_id is None: + _healthy_until.clear() + return + _healthy_until.pop(int(config_id), None) + + def _global_candidates() -> list[dict]: """Return Auto-eligible global cfgs. diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index 8696a8829..d333f0b7a 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -6,7 +6,10 @@ from types import SimpleNamespace import pytest from app.services.auto_model_pin_service import ( + clear_healthy, clear_runtime_cooldown, + is_recently_healthy, + mark_healthy, mark_runtime_cooldown, resolve_or_get_pinned_llm_config_id, ) @@ -17,8 +20,10 @@ pytestmark = pytest.mark.unit @pytest.fixture(autouse=True) def _clear_runtime_cooldown_map(): clear_runtime_cooldown() + clear_healthy() yield clear_runtime_cooldown() + clear_healthy() @dataclass @@ -866,3 +871,51 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa ) assert result.resolved_llm_config_id == -2 assert result.from_existing_pin is False + + +# --------------------------------------------------------------------------- +# Healthy-status cache (preflight TTL companion) +# --------------------------------------------------------------------------- + + +def test_mark_healthy_then_is_recently_healthy_true_within_ttl(): + mark_healthy(-42, ttl_seconds=60) + assert is_recently_healthy(-42) is True + + +def test_healthy_expires_after_ttl(monkeypatch): + import app.services.auto_model_pin_service as svc + + real_time = svc.time.time + base = real_time() + + monkeypatch.setattr(svc.time, "time", lambda: base) + mark_healthy(-7, ttl_seconds=10) + assert is_recently_healthy(-7) is True + + monkeypatch.setattr(svc.time, "time", lambda: base + 11) + assert is_recently_healthy(-7) is False + + +def 
test_mark_runtime_cooldown_invalidates_healthy_cache(): + mark_healthy(-9, ttl_seconds=60) + assert is_recently_healthy(-9) is True + + mark_runtime_cooldown(-9, reason="test", cooldown_seconds=60) + assert is_recently_healthy(-9) is False + + +def test_clear_healthy_removes_single_entry(): + mark_healthy(-11, ttl_seconds=60) + mark_healthy(-12, ttl_seconds=60) + clear_healthy(-11) + assert is_recently_healthy(-11) is False + assert is_recently_healthy(-12) is True + + +def test_clear_healthy_no_args_drops_all_entries(): + mark_healthy(-21, ttl_seconds=60) + mark_healthy(-22, ttl_seconds=60) + clear_healthy() + assert is_recently_healthy(-21) is False + assert is_recently_healthy(-22) is False From 2764fa5e30185c3e22f59a10df19c7db7d0a25bd Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:07:30 +0530 Subject: [PATCH 58/68] feat(openrouter): clear healthy-status cache on catalogue refresh --- .../app/services/openrouter_integration_service.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 9c3eaa5ea..67dbb6690 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -382,6 +382,18 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id + # Catalogue churn invalidates per-config "recently healthy" credit + # earned by the previous turn's preflight. Drop the whole table so + # the next turn re-probes against the freshly loaded configs. 
+ try: + from app.services.auto_model_pin_service import clear_healthy + + clear_healthy() + except Exception: + logger.debug( + "OpenRouter refresh: clear_healthy import skipped", exc_info=True + ) + tier_counts = self._tier_counts(new_configs) logger.info( "OpenRouter refresh: updated to %d models (free=%d, premium=%d)", From 7c1c394fe4768c05babc0330e2f8955e82167046 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:07:44 +0530 Subject: [PATCH 59/68] feat(stream_new_chat): add lightweight LLM preflight probe for auto-pin --- .../unit/test_stream_new_chat_contract.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index ed69ca348..6a1b4c13b 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -175,6 +175,68 @@ def test_stream_exception_classifies_openrouter_429_payload(): assert extra is None +@pytest.mark.asyncio +async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch): + """``_preflight_llm`` is best-effort. + + - On rate-limit shaped exceptions (provider 429) it MUST re-raise so the + caller can drive the cooldown/repin branch. + - On any other transient failure it MUST swallow the error so the normal + stream path continues without surfacing preflight noise to the user. 
+ """ + from types import SimpleNamespace + + from app.tasks.chat.stream_new_chat import _preflight_llm + + class _RateLimitedExc(Exception): + """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers.""" + + rate_calls: list[dict] = [] + other_calls: list[dict] = [] + + async def _fake_acompletion_429(**kwargs): + rate_calls.append(kwargs) + raise _RateLimitedExc("simulated 429") + + async def _fake_acompletion_other(**kwargs): + other_calls.append(kwargs) + raise RuntimeError("some unrelated transient failure") + + fake_llm = SimpleNamespace( + model="openrouter/google/gemma-4-31b-it:free", + api_key="test", + api_base=None, + ) + + import litellm # type: ignore[import-not-found] + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429) + with pytest.raises(_RateLimitedExc): + await _preflight_llm(fake_llm) + assert len(rate_calls) == 1 + assert rate_calls[0]["max_tokens"] == 1 + assert rate_calls[0]["stream"] is False + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other) + # MUST NOT raise: non-rate-limit failures are swallowed. + await _preflight_llm(fake_llm) + assert len(other_calls) == 1 + + +@pytest.mark.asyncio +async def test_preflight_skipped_for_auto_router_model(): + """Router-mode ``model='auto'`` has no single deployment to ping; the + LiteLLM router itself owns per-deployment rate-limit accounting, so the + preflight helper must short-circuit instead of issuing a probe.""" + from types import SimpleNamespace + + from app.tasks.chat.stream_new_chat import _preflight_llm + + fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None) + # Should return without raising or making any LiteLLM call. 
+ await _preflight_llm(fake_llm) + + def test_stream_exception_classifies_thread_busy(): exc = BusyError(request_id="thread-123") kind, code, severity, is_expected, user_message, extra = _classify_stream_exception( From 789d8ce62ed173a8a2e98b1fe3d9a14f620beb69 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:08:34 +0530 Subject: [PATCH 60/68] feat(stream_new_chat): wire preflight + early repin into auto-mode flow --- .../app/tasks/chat/stream_new_chat.py | 215 ++++++++++++++++++ 1 file changed, 215 insertions(+) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index dbfd5e2ea..07d14afeb 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -65,6 +65,8 @@ from app.db import ( ) from app.prompts import TITLE_GENERATION_PROMPT from app.services.auto_model_pin_service import ( + is_recently_healthy, + mark_healthy, mark_runtime_cooldown, resolve_or_get_pinned_llm_config_id, ) @@ -471,6 +473,54 @@ def _is_provider_rate_limited(exc: BaseException) -> bool: ) +_PREFLIGHT_TIMEOUT_SEC: float = 2.5 +_PREFLIGHT_MAX_TOKENS: int = 1 + + +async def _preflight_llm(llm: Any) -> None: + """Issue a minimal completion to confirm the pinned model isn't 429'ing. + + Used before agent build / planner / classifier / title-gen so a known-bad + free OpenRouter deployment is detected and repinned before it cascades + into multiple wasted internal calls. The probe is intentionally cheap: + one token, low timeout, tagged ``surfsense:internal`` so token tracking + and SSE pipelines treat it as overhead rather than user output. + + Raises the original exception when the provider responds with a + rate-limit-shaped error so the caller can drive the cooldown/repin + branch via :func:`_is_provider_rate_limited`. 
Other transient failures + are swallowed — the caller continues to the normal stream path and the + in-stream recovery loop remains the safety net. + """ + from litellm import acompletion + + model = getattr(llm, "model", None) + if not model or model == "auto": + # Auto-mode router doesn't have a single deployment to ping; the + # router itself handles per-deployment rate-limit accounting. + return + + try: + await acompletion( + model=model, + messages=[{"role": "user", "content": "ping"}], + api_key=getattr(llm, "api_key", None), + api_base=getattr(llm, "api_base", None), + max_tokens=_PREFLIGHT_MAX_TOKENS, + timeout=_PREFLIGHT_TIMEOUT_SEC, + stream=False, + metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]}, + ) + except Exception as exc: + if _is_provider_rate_limited(exc): + raise + logging.getLogger(__name__).debug( + "auto_pin_preflight non_rate_limit_error model=%s err=%s", + model, + exc, + ) + + def _classify_stream_exception( exc: Exception, *, @@ -2371,6 +2421,92 @@ async def stream_new_chat( yield streaming_service.format_done() return + # Auto-mode preflight ping. Runs ONLY for thread-pinned auto cfgs + # (negative ids selected via ``resolve_or_get_pinned_llm_config_id``) + # whose health hasn't already been confirmed within the TTL window. + # Detecting a 429 here lets us repin BEFORE the planner/classifier/ + # title-generation LLM calls fan out and each independently hit the + # same upstream rate limit. 
+ if ( + requested_llm_config_id == 0 + and llm_config_id < 0 + and not is_recently_healthy(llm_config_id) + ): + _t_preflight = time.perf_counter() + try: + await _preflight_llm(llm) + mark_healthy(llm_config_id) + _perf_log.info( + "[stream_new_chat] auto_pin_preflight ok config_id=%s " + "took=%.3fs", + llm_config_id, + time.perf_counter() - _t_preflight, + ) + except Exception as preflight_exc: + if not _is_provider_rate_limited(preflight_exc): + raise + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, reason="preflight_rate_limited" + ) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error or not llm: + yield _emit_stream_error( + message=llm_load_error or "Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + # Trust the freshly-resolved cfg for the remainder of this + # turn rather than recursing into another preflight; the + # in-stream 429 recovery loop is still in place as the + # safety net if even this fallback hits an upstream cap. + mark_healthy(llm_config_id) + _log_chat_stream_error( + flow=flow, + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model failed preflight; switched to another " + "eligible model and continuing." 
+ ), + extra={ + "auto_runtime_recover": True, + "preflight": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + # Create connector service _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) @@ -3327,6 +3463,85 @@ async def stream_resume_chat( yield streaming_service.format_done() return + # Auto-mode preflight ping (resume path). Mirrors ``stream_new_chat``: + # one cheap probe before the agent is rebuilt so a 429'd pin gets + # repinned without burning planner/classifier/title calls first. + if ( + requested_llm_config_id == 0 + and llm_config_id < 0 + and not is_recently_healthy(llm_config_id) + ): + _t_preflight = time.perf_counter() + try: + await _preflight_llm(llm) + mark_healthy(llm_config_id) + _perf_log.info( + "[stream_resume] auto_pin_preflight ok config_id=%s " + "took=%.3fs", + llm_config_id, + time.perf_counter() - _t_preflight, + ) + except Exception as preflight_exc: + if not _is_provider_rate_limited(preflight_exc): + raise + previous_config_id = llm_config_id + mark_runtime_cooldown( + previous_config_id, reason="preflight_rate_limited" + ) + try: + llm_config_id = ( + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + selected_llm_config_id=0, + exclude_config_ids={previous_config_id}, + ) + ).resolved_llm_config_id + except ValueError as pin_error: + yield _emit_stream_error( + message=str(pin_error), + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) + if llm_load_error or not llm: + yield _emit_stream_error( + message=llm_load_error or "Failed to create LLM instance", + error_kind="server_error", + error_code="SERVER_ERROR", + ) + yield streaming_service.format_done() + return + mark_healthy(llm_config_id) + 
_log_chat_stream_error( + flow="resume", + error_kind="rate_limited", + error_code="RATE_LIMITED", + severity="info", + is_expected=True, + request_id=request_id, + thread_id=chat_id, + search_space_id=search_space_id, + user_id=user_id, + message=( + "Auto-pinned model failed preflight; switched to another " + "eligible model and continuing." + ), + extra={ + "auto_runtime_recover": True, + "preflight": True, + "previous_config_id": previous_config_id, + "fallback_config_id": llm_config_id, + }, + ) + _t0 = time.perf_counter() connector_service = ConnectorService(session, search_space_id=search_space_id) From d14fed43c6f92e03907974c3ebb6318d77d3a0f9 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 02:45:27 +0530 Subject: [PATCH 61/68] feat(documents): add endpoint to retrieve document by virtual path --- .../app/routes/documents_routes.py | 45 ++++++ .../app/tasks/chat/stream_new_chat.py | 24 +-- .../unit/test_stream_new_chat_contract.py | 34 ++++ .../components/assistant-ui/markdown-text.tsx | 150 ++++++++++++------ .../lib/apis/documents-api.service.ts | 12 ++ 5 files changed, 206 insertions(+), 59 deletions(-) diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index f558481cf..f1ca3b6bf 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -745,6 +745,51 @@ async def search_document_titles( ) from e +@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead) +async def get_document_by_virtual_path( + search_space_id: int, + virtual_path: str, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """Resolve a knowledge-base document id by exact virtual path.""" + try: + await check_permission( + session, + user, + search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read 
documents in this search space", + ) + + result = await session.execute( + select( + Document.id, + Document.title, + Document.document_type, + ).filter( + Document.search_space_id == search_space_id, + Document.document_metadata["virtual_path"].as_string() == virtual_path, + ) + ) + row = result.first() + if row is None: + raise HTTPException(status_code=404, detail="Document not found") + + return DocumentTitleRead( + id=row.id, + title=row.title, + document_type=row.document_type, + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to resolve document by virtual path: {e!s}", + ) from e + + @router.get("/documents/status", response_model=DocumentStatusBatchResponse) async def get_documents_status( search_space_id: int, diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 07d14afeb..53f237f06 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -304,20 +304,17 @@ def _tool_output_has_error(tool_output: Any) -> bool: return False -def _extract_resolved_file_path(*, tool_name: str, tool_output: Any) -> str | None: +def _extract_resolved_file_path( + *, tool_name: str, tool_output: Any, tool_input: Any | None = None +) -> str | None: if isinstance(tool_output, dict): path_value = tool_output.get("path") if isinstance(path_value, str) and path_value.strip(): return path_value.strip() - text = _tool_output_to_text(tool_output) - if tool_name == "write_file": - match = re.search(r"Updated file\s+(.+)$", text.strip()) - if match: - return match.group(1).strip() - if tool_name == "edit_file": - match = re.search(r"in '([^']+)'", text) - if match: - return match.group(1).strip() + if tool_name in ("write_file", "edit_file") and isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip(): + return 
file_path.strip() return None @@ -714,6 +711,7 @@ async def _stream_agent_events( # fallback path only and never re-pops a chunk we already streamed. pending_tool_call_chunks: list[dict[str, Any]] = [] lc_tool_call_id_by_run: dict[str, str] = {} + file_path_by_run: dict[str, str] = {} # parity_v2 only: live tool-call argument streaming. ``index_to_meta`` # is keyed by the chunk's ``index`` field — LangChain @@ -892,6 +890,10 @@ async def _stream_agent_events( tool_input = event.get("data", {}).get("input", {}) if tool_name in ("write_file", "edit_file"): result.write_attempted = True + if isinstance(tool_input, dict): + file_path = tool_input.get("file_path") + if isinstance(file_path, str) and file_path.strip() and run_id: + file_path_by_run[run_id] = file_path.strip() if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) @@ -1298,6 +1300,7 @@ async def _stream_agent_events( run_id = event.get("run_id", "") tool_name = event.get("name", "unknown_tool") raw_output = event.get("data", {}).get("output", "") + staged_file_path = file_path_by_run.pop(run_id, None) if run_id else None if tool_name == "update_memory": called_update_memory = True @@ -1811,6 +1814,7 @@ async def _stream_agent_events( resolved_path = _extract_resolved_file_path( tool_name=tool_name, tool_output=tool_output, + tool_input={"file_path": staged_file_path} if staged_file_path else None, ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 6a1b4c13b..3676601f4 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -13,6 +13,7 @@ from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, _contract_enforcement_active, + _extract_resolved_file_path, 
_evaluate_file_contract_outcome, _log_chat_stream_error, _tool_output_has_error, @@ -28,6 +29,39 @@ def test_tool_output_error_detection(): assert not _tool_output_has_error({"result": "Updated file /notes.md"}) +def test_extract_resolved_file_path_prefers_structured_path(): + assert ( + _extract_resolved_file_path( + tool_name="write_file", + tool_output={"status": "completed", "path": "/docs/note.md"}, + tool_input=None, + ) + == "/docs/note.md" + ) + + +def test_extract_resolved_file_path_falls_back_to_tool_input(): + assert ( + _extract_resolved_file_path( + tool_name="edit_file", + tool_output={"status": "completed", "result": "updated"}, + tool_input={"file_path": "/docs/edited.md"}, + ) + == "/docs/edited.md" + ) + + +def test_extract_resolved_file_path_does_not_parse_result_text(): + assert ( + _extract_resolved_file_path( + tool_name="write_file", + tool_output={"result": "Updated file /docs/from-text.md"}, + tool_input=None, + ) + is None + ) + + def test_file_write_contract_outcome_reasons(): result = StreamResult(intent_detected="file_write") passed, reason = _evaluate_file_contract_outcome(result) diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index 4842e5979..bfbc3a423 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -30,8 +30,10 @@ import { TableRow, } from "@/components/ui/table"; import { useElectronAPI } from "@/hooks/use-platform"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; +import { toast } from "sonner"; function MarkdownCodeBlockSkeleton() { return ( @@ -194,6 +196,89 @@ function isVirtualFilePathToken(value: string): boolean { return segments.length >= 2; } +function isStandaloneDocumentsPathText(node: ReactNode): string 
| null { + if (typeof node !== "string") return null; + const value = node.trim(); + if (!value.startsWith("/documents/")) return null; + if (value.includes(" ")) return null; + const normalized = value.replace(/\/+$/, ""); + const leaf = normalized.split("/").filter(Boolean).at(-1) ?? ""; + if (!leaf || !leaf.includes(".")) return null; + return value; +} + +function FilePathLink({ + path, + className, +}: { + path: string; + className?: string; +}) { + const openEditorPanel = useSetAtom(openEditorPanelAtom); + const params = useParams(); + const electronAPI = useElectronAPI(); + const searchSpaceIdParam = params?.search_space_id; + const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) + ? Number(searchSpaceIdParam[0]) + : Number(searchSpaceIdParam); + const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) ? parsedSearchSpaceId : undefined; + + return ( + <button + type="button" + className={cn( + "cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80", + className + )} + onClick={(event) => { + event.preventDefault(); + event.stopPropagation(); + void (async () => { + if (electronAPI) { + let resolvedLocalPath = path; + if (electronAPI.getAgentFilesystemMounts) { + try { + const mounts = (await electronAPI.getAgentFilesystemMounts( + resolvedSearchSpaceId + )) as AgentFilesystemMount[]; + resolvedLocalPath = normalizeLocalVirtualPathForEditor(path, mounts); + } catch { + // Fall back to the raw path if mount lookup fails. 
+ } + } + openEditorPanel({ + kind: "local_file", + localFilePath: resolvedLocalPath, + title: resolvedLocalPath.split("/").pop() || resolvedLocalPath, + searchSpaceId: resolvedSearchSpaceId, + }); + return; + } + + if (!resolvedSearchSpaceId || !path.startsWith("/documents/")) return; + try { + const doc = await documentsApiService.getDocumentByVirtualPath({ + search_space_id: resolvedSearchSpaceId, + virtual_path: path, + }); + openEditorPanel({ + kind: "document", + documentId: doc.id, + searchSpaceId: resolvedSearchSpaceId, + title: doc.title, + }); + } catch { + toast.error("Document not found in knowledge base."); + } + })(); + }} + title="Open in editor panel" + > + {path} + </button> + ); +} + function MarkdownImage({ src, alt }: { src?: string; alt?: string }) { if (!src) return null; @@ -311,9 +396,14 @@ const defaultComponents = memoizeMarkdownComponents({ }, p: function P({ className, children, ...props }) { const urlMap = useCitationUrlMap(); + const standalonePath = isStandaloneDocumentsPathText(children); return ( <p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}> - {processChildrenWithCitations(children, urlMap)} + {standalonePath ? ( + <FilePathLink path={standalonePath} /> + ) : ( + processChildrenWithCitations(children, urlMap) + )} </p> ); }, @@ -400,8 +490,6 @@ const defaultComponents = memoizeMarkdownComponents({ code: function Code({ className, children, ...props }) { const isCodeBlock = useIsMarkdownCodeBlock(); const { resolvedTheme } = useTheme(); - const openEditorPanel = useSetAtom(openEditorPanelAtom); - const params = useParams(); const electronAPI = useElectronAPI(); const language = /language-(\w+)/.exec(className || "")?.[1] ?? 
"text"; const codeString = String(children).replace(/\n$/, ""); @@ -418,53 +506,17 @@ const defaultComponents = memoizeMarkdownComponents({ const isLikelyFolder = inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes("."); const isLocalPath = - !!electronAPI && - isVirtualFilePathToken(inlineValue) && - !inlineValue.startsWith("//") && - !isLikelyFolder; - const displayLocalPath = inlineValue.replace(/^\/+/, ""); - const searchSpaceIdParam = params?.search_space_id; - const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) - ? Number(searchSpaceIdParam[0]) - : Number(searchSpaceIdParam); + (isVirtualFilePathToken(inlineValue) && + !inlineValue.startsWith("//") && + !isLikelyFolder && + !!electronAPI) || + (isVirtualFilePathToken(inlineValue) && + !inlineValue.startsWith("//") && + !isLikelyFolder && + !electronAPI && + inlineValue.startsWith("/documents/")); if (isLocalPath) { - return ( - <button - type="button" - className={cn( - "cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80" - )} - onClick={(event) => { - event.preventDefault(); - event.stopPropagation(); - void (async () => { - let resolvedLocalPath = inlineValue; - const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) - ? parsedSearchSpaceId - : undefined; - if (electronAPI?.getAgentFilesystemMounts) { - try { - const mounts = (await electronAPI.getAgentFilesystemMounts( - resolvedSearchSpaceId - )) as AgentFilesystemMount[]; - resolvedLocalPath = normalizeLocalVirtualPathForEditor(inlineValue, mounts); - } catch { - // Fall back to the raw inline path if mount lookup fails. 
- } - } - openEditorPanel({ - kind: "local_file", - localFilePath: resolvedLocalPath, - title: resolvedLocalPath.split("/").pop() || resolvedLocalPath, - searchSpaceId: resolvedSearchSpaceId, - }); - })(); - }} - title="Open in editor panel" - > - {displayLocalPath} - </button> - ); + return <FilePathLink path={inlineValue} className="text-[0.9em]" />; } return ( <code diff --git a/surfsense_web/lib/apis/documents-api.service.ts b/surfsense_web/lib/apis/documents-api.service.ts index 0cd81c0b7..949e3b29f 100644 --- a/surfsense_web/lib/apis/documents-api.service.ts +++ b/surfsense_web/lib/apis/documents-api.service.ts @@ -28,6 +28,7 @@ import { getSurfsenseDocsRequest, getSurfsenseDocsResponse, type SearchDocumentsRequest, + documentTitleRead, type SearchDocumentTitlesRequest, searchDocumentsRequest, searchDocumentsResponse, @@ -269,6 +270,17 @@ class DocumentsApiService { ); }; + getDocumentByVirtualPath = async (request: { + search_space_id: number; + virtual_path: string; + }) => { + const params = new URLSearchParams({ + search_space_id: String(request.search_space_id), + virtual_path: request.virtual_path, + }); + return baseApiService.get(`/api/v1/documents/by-virtual-path?${params.toString()}`, documentTitleRead); + }; + /** * Get document type counts */ From e9d964514bdd1585f051616c90db924978341f26 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:31:03 +0530 Subject: [PATCH 62/68] feat(alembic): add user table to zero_publication for selective replication of usage metrics --- .../139_add_user_to_zero_publication.py | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py diff --git a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py new file mode 100644 index 000000000..5b8bc29b0 --- /dev/null +++ 
b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py @@ -0,0 +1,158 @@ +"""add user table to zero_publication with column list + +Adds the "user" table to zero_publication with a column-list publication +so that only the 5 fields driving the live usage meters are replicated +through WAL -> zero-cache -> browser IndexedDB: + + id, pages_limit, pages_used, + premium_tokens_limit, premium_tokens_used + +Sensitive columns (hashed_password, email, oauth_account, display_name, +avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT +included in the publication, so they never enter WAL replication. + +Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency +(it is already DEFAULT today since "user" was never in the +TABLES_WITH_FULL_IDENTITY list of migration 117). + +IMPORTANT - before AND after running this migration: + 1. Stop zero-cache (it holds replication locks that will deadlock DDL) + 2. Run: alembic upgrade head + 3. Delete / reset the zero-cache data volume + 4. Restart zero-cache (it will do a fresh initial sync) + +Revision ID: 139 +Revises: 138 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "139" +down_revision: str | None = "138" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +PUBLICATION_NAME = "zero_publication" + +# Document column list as left by migration 117. Must match exactly. +DOCUMENT_COLS = [ + "id", + "title", + "document_type", + "search_space_id", + "folder_id", + "created_by_id", + "status", + "created_at", + "updated_at", +] + +# Five fields needed by the live usage meters (sidebar Tokens/Pages, +# Buy Tokens content). Keep this list narrow on purpose: anything added +# here flows into WAL and IndexedDB for every connected browser. 
+USER_COLS = [ + "id", + "pages_limit", + "pages_used", + "premium_tokens_limit", + "premium_tokens_used", +] + + +def _terminate_blocked_pids(conn, table: str) -> None: + """Kill backends whose locks on *table* would block our AccessExclusiveLock.""" + conn.execute( + sa.text( + "SELECT pg_terminate_backend(l.pid) " + "FROM pg_locks l " + "JOIN pg_class c ON c.oid = l.relation " + "WHERE c.relname = :tbl " + " AND l.pid != pg_backend_pid()" + ), + {"tbl": table}, + ) + + +def _has_zero_version(conn, table: str) -> bool: + return ( + conn.execute( + sa.text( + "SELECT 1 FROM information_schema.columns " + "WHERE table_name = :tbl AND column_name = '_0_version'" + ), + {"tbl": table}, + ).fetchone() + is not None + ) + + +def _build_publication_ddl(documents_has_zero_ver: bool, user_has_zero_ver: bool) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else []) + doc_col_list = ", ".join(doc_cols) + user_col_list = ", ".join(user_cols) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state, " + f'"user" ({user_col_list})' + ) + + +def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str: + doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) + doc_col_list = ", ".join(doc_cols) + return ( + f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + f"notifications, " + f"documents ({doc_col_list}), " + f"folders, " + f"search_source_connectors, " + f"new_chat_messages, " + f"chat_comments, " + f"chat_session_state" + ) + + +def upgrade() -> None: + conn = op.get_bind() + # asyncpg requires LOCK TABLE inside a transaction block. 
Alembic already + # opened one via context.begin_transaction(), but the driver still errors + # unless we use an explicit SAVEPOINT (nested transaction) for this block. + tx = conn.begin_nested() if conn.in_transaction() else conn.begin() + with tx: + conn.execute(sa.text("SET lock_timeout = '10s'")) + + _terminate_blocked_pids(conn, "user") + conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE')) + + # Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of + # migration 117, so this is already DEFAULT. Re-assert anyway so + # the column-list publication stays valid (DEFAULT identity only + # requires the PK to be in the column list). + conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT')) + + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + + documents_has_zero_ver = _has_zero_version(conn, "documents") + user_has_zero_ver = _has_zero_version(conn, "user") + + conn.execute( + sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver)) + ) + + +def downgrade() -> None: + conn = op.get_bind() + conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")) + documents_has_zero_ver = _has_zero_version(conn, "documents") + conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver))) From 05eef5a7db42f215fdbcc6115fbe609641b72c7f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:31:50 +0530 Subject: [PATCH 63/68] feat(zero): add userTable + queries.user.me() synced query --- surfsense_web/zero/queries/index.ts | 2 ++ surfsense_web/zero/queries/user.ts | 11 +++++++++++ surfsense_web/zero/schema/index.ts | 2 ++ surfsense_web/zero/schema/user.ts | 11 +++++++++++ 4 files changed, 26 insertions(+) create mode 100644 surfsense_web/zero/queries/user.ts create mode 100644 surfsense_web/zero/schema/user.ts diff --git a/surfsense_web/zero/queries/index.ts b/surfsense_web/zero/queries/index.ts index 
bc332114e..fbf1bd76e 100644 --- a/surfsense_web/zero/queries/index.ts +++ b/surfsense_web/zero/queries/index.ts @@ -3,6 +3,7 @@ import { chatSessionQueries, commentQueries, messageQueries } from "./chat"; import { connectorQueries, documentQueries } from "./documents"; import { folderQueries } from "./folders"; import { notificationQueries } from "./inbox"; +import { userQueries } from "./user"; export const queries = defineQueries({ notifications: notificationQueries, @@ -12,4 +13,5 @@ export const queries = defineQueries({ messages: messageQueries, comments: commentQueries, chatSession: chatSessionQueries, + user: userQueries, }); diff --git a/surfsense_web/zero/queries/user.ts b/surfsense_web/zero/queries/user.ts new file mode 100644 index 000000000..30e71a482 --- /dev/null +++ b/surfsense_web/zero/queries/user.ts @@ -0,0 +1,11 @@ +import { defineQuery } from "@rocicorp/zero"; +import { z } from "zod"; +import { zql } from "../schema/index"; + +export const userQueries = { + me: defineQuery(z.object({}), ({ ctx }) => { + const userId = ctx?.userId; + if (!userId) return zql.user.where("id", "__none__").one(); + return zql.user.where("id", userId).one(); + }), +}; diff --git a/surfsense_web/zero/schema/index.ts b/surfsense_web/zero/schema/index.ts index bba561580..3cca0f24a 100644 --- a/surfsense_web/zero/schema/index.ts +++ b/surfsense_web/zero/schema/index.ts @@ -3,6 +3,7 @@ import { chatCommentTable, chatSessionStateTable, newChatMessageTable } from "./ import { documentTable, searchSourceConnectorTable } from "./documents"; import { folderTable } from "./folders"; import { notificationTable } from "./inbox"; +import { userTable } from "./user"; const chatCommentRelationships = relationships(chatCommentTable, ({ one }) => ({ message: one({ @@ -34,6 +35,7 @@ export const schema = createSchema({ newChatMessageTable, chatCommentTable, chatSessionStateTable, + userTable, ], relationships: [chatCommentRelationships, newChatMessageRelationships], }); diff --git 
a/surfsense_web/zero/schema/user.ts b/surfsense_web/zero/schema/user.ts new file mode 100644 index 000000000..0e6234db5 --- /dev/null +++ b/surfsense_web/zero/schema/user.ts @@ -0,0 +1,11 @@ +import { number, string, table } from "@rocicorp/zero"; + +export const userTable = table("user") + .columns({ + id: string(), + pagesLimit: number().from("pages_limit"), + pagesUsed: number().from("pages_used"), + premiumTokensLimit: number().from("premium_tokens_limit"), + premiumTokensUsed: number().from("premium_tokens_used"), + }) + .primaryKey("id"); From 2a14c0528251e03a8e2ecff92c558e1628af5f27 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:32:05 +0530 Subject: [PATCH 64/68] feat(sidebar): live premium tokens meter via Zero --- .../ui/sidebar/PremiumTokenUsageDisplay.tsx | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx index a4d760dba..a3f028858 100644 --- a/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx +++ b/surfsense_web/components/layout/ui/sidebar/PremiumTokenUsageDisplay.tsx @@ -1,23 +1,18 @@ "use client"; -import { useQuery } from "@tanstack/react-query"; +import { useQuery } from "@rocicorp/zero/react"; import { Progress } from "@/components/ui/progress"; import { useIsAnonymous } from "@/contexts/anonymous-mode"; -import { stripeApiService } from "@/lib/apis/stripe-api.service"; +import { queries } from "@/zero/queries"; export function PremiumTokenUsageDisplay() { const isAnonymous = useIsAnonymous(); - const { data: tokenStatus } = useQuery({ - queryKey: ["token-status"], - queryFn: () => stripeApiService.getTokenStatus(), - staleTime: 60_000, - enabled: !isAnonymous, - }); + const [me] = useQuery(queries.user.me({})); - if (!tokenStatus) return null; + if (isAnonymous || !me) return 
null; const usagePercentage = Math.min( - (tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100, + (me.premiumTokensUsed / Math.max(me.premiumTokensLimit, 1)) * 100, 100 ); @@ -31,8 +26,7 @@ export function PremiumTokenUsageDisplay() { <div className="space-y-1.5"> <div className="flex justify-between items-center text-xs"> <span className="text-muted-foreground"> - {formatTokens(tokenStatus.premium_tokens_used)} /{" "} - {formatTokens(tokenStatus.premium_tokens_limit)} tokens + {formatTokens(me.premiumTokensUsed)} / {formatTokens(me.premiumTokensLimit)} tokens </span> <span className="font-medium">{usagePercentage.toFixed(0)}%</span> </div> From 6b06416d4761007cd6a4551313d7038cfef52cc7 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:32:19 +0530 Subject: [PATCH 65/68] feat(sidebar): live pages meter via Zero for authenticated users --- .../layout/providers/LayoutDataProvider.tsx | 9 --------- .../ui/sidebar/AuthenticatedPageUsageDisplay.tsx | 15 +++++++++++++++ .../components/layout/ui/sidebar/Sidebar.tsx | 6 ++---- 3 files changed, 17 insertions(+), 13 deletions(-) create mode 100644 surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index afd888f48..d70a7ade4 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -681,14 +681,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid } }, [chatToRename, newChatTitle, queryClient, searchSpaceId, tSidebar]); - // Page usage - const pageUsage = user - ? 
{ - pagesUsed: user.pages_used, - pagesLimit: user.pages_limit, - } - : undefined; - // Detect if we're on the chat page (needs overflow-hidden for chat's own scroll) const isChatPage = pathname?.includes("/new-chat") ?? false; @@ -723,7 +715,6 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid onManageMembers={handleManageMembers} onUserSettings={handleUserSettings} onLogout={handleLogout} - pageUsage={pageUsage} theme={theme} setTheme={setTheme} isChatPage={isChatPage} diff --git a/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx b/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx new file mode 100644 index 000000000..ad31d50bb --- /dev/null +++ b/surfsense_web/components/layout/ui/sidebar/AuthenticatedPageUsageDisplay.tsx @@ -0,0 +1,15 @@ +"use client"; + +import { useQuery } from "@rocicorp/zero/react"; +import { useIsAnonymous } from "@/contexts/anonymous-mode"; +import { queries } from "@/zero/queries"; +import { PageUsageDisplay } from "./PageUsageDisplay"; + +export function AuthenticatedPageUsageDisplay() { + const isAnonymous = useIsAnonymous(); + const [me] = useQuery(queries.user.me({})); + + if (isAnonymous || !me) return null; + + return <PageUsageDisplay pagesUsed={me.pagesUsed} pagesLimit={me.pagesLimit} />; +} diff --git a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx index adad52792..d5038ea05 100644 --- a/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx +++ b/surfsense_web/components/layout/ui/sidebar/Sidebar.tsx @@ -12,9 +12,9 @@ import { useIsAnonymous } from "@/contexts/anonymous-mode"; import { cn } from "@/lib/utils"; import { SIDEBAR_MIN_WIDTH } from "../../hooks/useSidebarResize"; import type { ChatItem, NavItem, PageUsage, SearchSpace, User } from "../../types/layout.types"; +import { AuthenticatedPageUsageDisplay } from "./AuthenticatedPageUsageDisplay"; import { ChatListItem } 
from "./ChatListItem"; import { NavSection } from "./NavSection"; -import { PageUsageDisplay } from "./PageUsageDisplay"; import { PremiumTokenUsageDisplay } from "./PremiumTokenUsageDisplay"; import { SidebarButton } from "./SidebarButton"; import { SidebarCollapseButton } from "./SidebarCollapseButton"; @@ -338,9 +338,7 @@ function SidebarUsageFooter({ return ( <div className="px-3 py-3 border-t space-y-3"> <PremiumTokenUsageDisplay /> - {pageUsage && ( - <PageUsageDisplay pagesUsed={pageUsage.pagesUsed} pagesLimit={pageUsage.pagesLimit} /> - )} + <AuthenticatedPageUsageDisplay /> <div className="space-y-0.5"> <Link href={`/dashboard/${searchSpaceId}/more-pages`} From 38a4742ec688d741da301f31f6657e731f3d3033 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:32:37 +0530 Subject: [PATCH 66/68] feat(settings): live buy-tokens meter via Zero --- .../settings/buy-tokens-content.tsx | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/surfsense_web/components/settings/buy-tokens-content.tsx b/surfsense_web/components/settings/buy-tokens-content.tsx index 649a50639..e7fac4255 100644 --- a/surfsense_web/components/settings/buy-tokens-content.tsx +++ b/surfsense_web/components/settings/buy-tokens-content.tsx @@ -1,5 +1,6 @@ "use client"; +import { useQuery as useZeroQuery } from "@rocicorp/zero/react"; import { useMutation, useQuery } from "@tanstack/react-query"; import { Minus, Plus } from "lucide-react"; import { useParams } from "next/navigation"; @@ -11,6 +12,7 @@ import { Spinner } from "@/components/ui/spinner"; import { stripeApiService } from "@/lib/apis/stripe-api.service"; import { AppError } from "@/lib/error"; import { cn } from "@/lib/utils"; +import { queries } from "@/zero/queries"; const TOKEN_PACK_SIZE = 1_000_000; const PRICE_PER_PACK_USD = 1; @@ -21,11 +23,15 @@ export function BuyTokensContent() { const searchSpaceId = 
Number(params?.search_space_id); const [quantity, setQuantity] = useState(1); + // Server config flag: stays on REST, not per-user. const { data: tokenStatus } = useQuery({ queryKey: ["token-status"], queryFn: () => stripeApiService.getTokenStatus(), }); + // Live per-user usage via Zero. + const [me] = useZeroQuery(queries.user.me({})); + const purchaseMutation = useMutation({ mutationFn: stripeApiService.createTokenCheckoutSession, onSuccess: (response) => { @@ -54,12 +60,11 @@ export function BuyTokensContent() { ); } - const usagePercentage = tokenStatus - ? Math.min( - (tokenStatus.premium_tokens_used / Math.max(tokenStatus.premium_tokens_limit, 1)) * 100, - 100 - ) - : 0; + const used = me?.premiumTokensUsed ?? 0; + const limit = me?.premiumTokensLimit ?? 0; + // Mirrors the backend formula in stripe_routes.py:608 (max(0, limit - used)). + const remaining = Math.max(0, limit - used); + const usagePercentage = me ? Math.min((used / Math.max(limit, 1)) * 100, 100) : 0; return ( <div className="w-full space-y-5"> @@ -68,18 +73,17 @@ export function BuyTokensContent() { <p className="mt-1 text-sm text-muted-foreground">$1 per 1M tokens, pay as you go</p> </div> - {tokenStatus && ( + {me && ( <div className="rounded-lg border bg-muted/20 p-3 space-y-1.5"> <div className="flex justify-between items-center text-xs"> <span className="text-muted-foreground"> - {tokenStatus.premium_tokens_used.toLocaleString()} /{" "} - {tokenStatus.premium_tokens_limit.toLocaleString()} premium tokens + {used.toLocaleString()} / {limit.toLocaleString()} premium tokens </span> <span className="font-medium">{usagePercentage.toFixed(0)}%</span> </div> <Progress value={usagePercentage} className="h-1.5" /> <p className="text-[11px] text-muted-foreground"> - {tokenStatus.premium_tokens_remaining.toLocaleString()} tokens remaining + {remaining.toLocaleString()} tokens remaining </p> </div> )} From b9b4d0b3777667bb6aa59dacc6120d8ae8eb2783 Mon Sep 17 00:00:00 2001 From: Anish Sarkar 
<104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:32:58 +0530 Subject: [PATCH 67/68] chore(usage): stop polling /users/me and token-status for live fields --- .../[search_space_id]/purchase-success/page.tsx | 9 --------- surfsense_web/atoms/user/user-query.atoms.ts | 5 ++++- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx index 67d9edab0..85bc4aaa6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/purchase-success/page.tsx @@ -1,11 +1,8 @@ "use client"; -import { useQueryClient } from "@tanstack/react-query"; import { CheckCircle2 } from "lucide-react"; import Link from "next/link"; import { useParams } from "next/navigation"; -import { useEffect } from "react"; -import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms"; import { Button } from "@/components/ui/button"; import { Card, @@ -18,14 +15,8 @@ import { export default function PurchaseSuccessPage() { const params = useParams(); - const queryClient = useQueryClient(); const searchSpaceId = String(params.search_space_id ?? 
""); - useEffect(() => { - void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY }); - void queryClient.invalidateQueries({ queryKey: ["token-status"] }); - }, [queryClient]); - return ( <div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8"> <Card className="w-full max-w-lg"> diff --git a/surfsense_web/atoms/user/user-query.atoms.ts b/surfsense_web/atoms/user/user-query.atoms.ts index 8e196c9c7..a59811324 100644 --- a/surfsense_web/atoms/user/user-query.atoms.ts +++ b/surfsense_web/atoms/user/user-query.atoms.ts @@ -8,7 +8,10 @@ const userQueryFn = () => userApiService.getMe(); export const currentUserAtom = atomWithQuery(() => { return { queryKey: USER_QUERY_KEY, - staleTime: 5 * 60 * 1000, + // Live-changing numeric fields (pages_*, premium_tokens_*) are now + // pushed via Zero (queries.user.me()), so /users/me only needs to + // fire once per session for the static profile fields. + staleTime: Infinity, enabled: !!getBearerToken(), queryFn: userQueryFn, }; From cd25175b8459994b7dc982be1de5eb22b5bb7d32 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 2 May 2026 03:36:13 +0530 Subject: [PATCH 68/68] chore: ran linting --- .../139_add_user_to_zero_publication.py | 4 +- surfsense_backend/app/config/__init__.py | 4 +- .../app/services/auto_model_pin_service.py | 4 +- .../openrouter_integration_service.py | 26 +++---------- .../app/services/quality_score.py | 10 ++--- .../app/tasks/chat/stream_new_chat.py | 24 ++++++++---- .../services/test_auto_model_pin_service.py | 2 +- .../services/test_llm_router_pool_filter.py | 37 ++++++++++++------- .../test_openrouter_integration_service.py | 2 - .../services/test_openrouter_legacy_config.py | 4 +- .../tests/unit/services/test_quality_score.py | 9 +++-- .../unit/test_stream_new_chat_contract.py | 8 ++-- .../components/assistant-ui/markdown-text.tsx | 14 +++---- .../lib/apis/documents-api.service.ts | 12 +++--- 14 files changed, 78 
insertions(+), 82 deletions(-) diff --git a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py index 5b8bc29b0..83c96a429 100644 --- a/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py +++ b/surfsense_backend/alembic/versions/139_add_user_to_zero_publication.py @@ -90,7 +90,9 @@ def _has_zero_version(conn, table: str) -> bool: ) -def _build_publication_ddl(documents_has_zero_ver: bool, user_has_zero_ver: bool) -> str: +def _build_publication_ddl( + documents_has_zero_ver: bool, user_has_zero_ver: bool +) -> str: doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else []) user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else []) doc_col_list = ", ".join(doc_cols) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index b3eff571e..675b05d2c 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -286,9 +286,7 @@ def initialize_openrouter_integration(): if new_configs: config.GLOBAL_LLM_CONFIGS.extend(new_configs) - free_count = sum( - 1 for c in new_configs if c.get("billing_tier") == "free" - ) + free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free") premium_count = sum( 1 for c in new_configs if c.get("billing_tier") == "premium" ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index b2acd6f56..3a2c681b7 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -277,9 +277,7 @@ async def resolve_or_get_pinned_llm_config_id( c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids ] if not candidates: - raise ValueError( - "No usable global LLM configs are available for Auto mode" - ) + raise ValueError("No usable global LLM configs 
are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 67dbb6690..7e856d015 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -405,9 +405,7 @@ class OpenRouterIntegrationService: # Re-blend health scores against the freshly fetched catalogue. Also # re-stamps health for any YAML-curated cfg with provider==OPENROUTER # so a hand-picked dead OR model is gated like a dynamic one. - await self._enrich_health_safely( - static_configs + new_configs, log_summary=True - ) + await self._enrich_health_safely(static_configs + new_configs, log_summary=True) # Rebuild the LiteLLM router so freshly fetched configs flow through # (dynamic OR premium entries now opt into the pool, free ones stay @@ -415,8 +413,8 @@ class OpenRouterIntegrationService: # reset cached context-window profiles). 
try: from app.config import config as _app_config - from app.services.llm_router_service import LLMRouterService from app.services.llm_router_service import ( + LLMRouterService, _router_instance_cache as _chat_router_cache, ) @@ -426,9 +424,7 @@ class OpenRouterIntegrationService: ) _chat_router_cache.clear() except Exception as exc: - logger.warning( - "OpenRouter refresh: router rebuild skipped (%s)", exc - ) + logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc) @staticmethod def _tier_counts(configs: list[dict]) -> dict[str, int]: @@ -475,19 +471,11 @@ class OpenRouterIntegrationService: return premium_pool = sorted( - [ - c - for c in or_cfgs - if str(c.get("billing_tier", "")).lower() == "premium" - ], + [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"], key=lambda c: -int(c.get("quality_score_static") or 0), )[:_HEALTH_ENRICH_TOP_N_PREMIUM] free_pool = sorted( - [ - c - for c in or_cfgs - if str(c.get("billing_tier", "")).lower() == "free" - ], + [c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"], key=lambda c: -int(c.get("quality_score_static") or 0), )[:_HEALTH_ENRICH_TOP_N_FREE] # De-duplicate while preserving order: a cfg shouldn't fall in both @@ -507,9 +495,7 @@ class OpenRouterIntegrationService: api_key = str(self._settings.get("api_key") or "") semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY) - async with httpx.AsyncClient( - timeout=_HEALTH_FETCH_TIMEOUT_SEC - ) as client: + async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client: results = await asyncio.gather( *( self._fetch_endpoints(client, semaphore, api_key, cfg) diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 8f6c75d56..2fb37de21 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -7,12 +7,12 @@ sort and a SHA256 pick. 
Score components (0-100 scale, higher is better): -* ``static_score_or`` – derived from the bulk ``/api/v1/models`` payload +* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload (provider prestige + ``created`` recency + pricing band + context window + capabilities + narrow tiny/legacy slug penalty). -* ``static_score_yaml`` – same shape for hand-curated YAML configs, plus +* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus an operator-trust bonus (the operator deliberately picked this model). -* ``aggregate_health`` – run on per-model ``/api/v1/models/{id}/endpoints`` +* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints`` responses; returns ``(gated, score_or_none)``. The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in @@ -281,9 +281,7 @@ def static_score_yaml(cfg: dict) -> int: model_name = cfg.get("model_name") or "" litellm_params = cfg.get("litellm_params") or {} lookup_name = ( - litellm_params.get("base_model") - or litellm_params.get("model") - or model_name + litellm_params.get("base_model") or litellm_params.get("model") or model_name ) ctx = 0 diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 53f237f06..dbfe9a67b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1814,7 +1814,9 @@ async def _stream_agent_events( resolved_path = _extract_resolved_file_path( tool_name=tool_name, tool_output=tool_output, - tool_input={"file_path": staged_file_path} if staged_file_path else None, + tool_input={"file_path": staged_file_path} + if staged_file_path + else None, ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): @@ -2441,8 +2443,7 @@ async def stream_new_chat( await _preflight_llm(llm) mark_healthy(llm_config_id) _perf_log.info( - "[stream_new_chat] auto_pin_preflight ok config_id=%s " - 
"took=%.3fs", + "[stream_new_chat] auto_pin_preflight ok config_id=%s took=%.3fs", llm_config_id, time.perf_counter() - _t_preflight, ) @@ -2891,7 +2892,11 @@ async def stream_new_chat( # Inject title update mid-stream as soon as the background # task finishes. - if title_task is not None and title_task.done() and not title_emitted: + if ( + title_task is not None + and title_task.done() + and not title_emitted + ): generated_title, title_usage = title_task.result() if title_usage: accumulator.add(**title_usage) @@ -2944,7 +2949,9 @@ async def stream_new_chat( ) ).resolved_llm_config_id - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) if llm_load_error: raise stream_exc @@ -3480,8 +3487,7 @@ async def stream_resume_chat( await _preflight_llm(llm) mark_healthy(llm_config_id) _perf_log.info( - "[stream_resume] auto_pin_preflight ok config_id=%s " - "took=%.3fs", + "[stream_resume] auto_pin_preflight ok config_id=%s took=%.3fs", llm_config_id, time.perf_counter() - _t_preflight, ) @@ -3684,7 +3690,9 @@ async def stream_resume_chat( ) ).resolved_llm_config_id - llm, agent_config, llm_load_error = await _load_llm_bundle(llm_config_id) + llm, agent_config, llm_load_error = await _load_llm_bundle( + llm_config_id + ) if llm_load_error: raise stream_exc diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d333f0b7a..49b3621c7 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -574,7 +574,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): monkeypatch.setattr( config, "GLOBAL_LLM_CONFIGS", - high_score_cfgs + [low_score_trap], + [*high_score_cfgs, low_score_trap], ) async def _allowed(*_args, **_kwargs): diff --git 
a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py index 0191025ec..c309ff881 100644 --- a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -96,9 +96,12 @@ def test_router_pool_includes_or_premium_excludes_or_free(): ), ] - with patch("app.services.llm_router_service.Router") as mock_router, patch( - "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" - ) as mock_ctx_fb: + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): mock_ctx_fb.side_effect = lambda ml: (ml, None) mock_router.return_value = object() LLMRouterService.initialize(configs) @@ -124,9 +127,10 @@ def test_router_pool_includes_or_premium_excludes_or_free(): assert "openrouter/openai/gpt-4o" in prem assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True # Dynamic OR free never enters the pool, so it's never counted as premium. 
- assert LLMRouterService.is_premium_model( - "openrouter/meta-llama/llama-3.3-70b:free" - ) is False + assert ( + LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free") + is False + ) def test_router_pool_filter_mechanics_respect_override(): @@ -147,9 +151,12 @@ def test_router_pool_filter_mechanics_respect_override(): ), ] - with patch("app.services.llm_router_service.Router") as mock_router, patch( - "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" - ) as mock_ctx_fb: + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): mock_ctx_fb.side_effect = lambda ml: (ml, None) mock_router.return_value = object() LLMRouterService.initialize(configs) @@ -167,13 +174,17 @@ def test_rebuild_refreshes_pool_after_configs_change(): configs_v1 = [ _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"), ] - configs_v2 = configs_v1 + [ + configs_v2 = [ + *configs_v1, _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"), ] - with patch("app.services.llm_router_service.Router") as mock_router, patch( - "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" - ) as mock_ctx_fb: + with ( + patch("app.services.llm_router_service.Router") as mock_router, + patch( + "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups" + ) as mock_ctx_fb, + ): mock_ctx_fb.side_effect = lambda ml: (ml, None) mock_router.return_value = object() diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index d3921729d..085740032 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -214,5 +214,3 @@ 
def test_generate_configs_drops_non_text_and_non_tool_models(): assert "openai/gpt-4o" in model_names assert "openai/dall-e" not in model_names assert "openai/completion-only" not in model_names - - diff --git a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py index b3dd2bf18..4eb1f2295 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_legacy_config.py @@ -68,9 +68,7 @@ openrouter_integration: assert "deprecated" in captured -def test_new_keys_take_priority_over_legacy_back_compat( - monkeypatch, tmp_path, capsys -): +def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys): """If both legacy and new keys are present, new keys win (setdefault).""" _write_yaml( tmp_path, diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index fbc91521d..6fbc8fd62 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -106,9 +106,12 @@ def test_context_signal_bands(ctx, expected): def test_capabilities_signal_caps_at_five(): - assert capabilities_signal( - ["tools", "structured_outputs", "reasoning", "include_reasoning"] - ) <= 5 + assert ( + capabilities_signal( + ["tools", "structured_outputs", "reasoning", "include_reasoning"] + ) + <= 5 + ) def test_capabilities_signal_tools_only(): diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index 3676601f4..910009667 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -13,8 +13,8 @@ from app.tasks.chat.stream_new_chat import ( StreamResult, _classify_stream_exception, 
_contract_enforcement_active, - _extract_resolved_file_path, _evaluate_file_contract_outcome, + _extract_resolved_file_path, _log_chat_stream_error, _tool_output_has_error, ) @@ -222,7 +222,7 @@ async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkey from app.tasks.chat.stream_new_chat import _preflight_llm - class _RateLimitedExc(Exception): + class _RateLimitedError(Exception): """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers.""" rate_calls: list[dict] = [] @@ -230,7 +230,7 @@ async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkey async def _fake_acompletion_429(**kwargs): rate_calls.append(kwargs) - raise _RateLimitedExc("simulated 429") + raise _RateLimitedError("simulated 429") async def _fake_acompletion_other(**kwargs): other_calls.append(kwargs) @@ -245,7 +245,7 @@ async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkey import litellm # type: ignore[import-not-found] monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429) - with pytest.raises(_RateLimitedExc): + with pytest.raises(_RateLimitedError): await _preflight_llm(fake_llm) assert len(rate_calls) == 1 assert rate_calls[0]["max_tokens"] == 1 diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index bfbc3a423..9fddec360 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -19,6 +19,7 @@ import remarkMath from "remark-math"; import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image"; import "katex/dist/katex.min.css"; +import { toast } from "sonner"; import { processChildrenWithCitations } from "@/components/citations/citation-renderer"; import { Skeleton } from "@/components/ui/skeleton"; import { @@ -33,7 +34,6 @@ import { useElectronAPI } from 
"@/hooks/use-platform"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser"; import { cn } from "@/lib/utils"; -import { toast } from "sonner"; function MarkdownCodeBlockSkeleton() { return ( @@ -207,13 +207,7 @@ function isStandaloneDocumentsPathText(node: ReactNode): string | null { return value; } -function FilePathLink({ - path, - className, -}: { - path: string; - className?: string; -}) { +function FilePathLink({ path, className }: { path: string; className?: string }) { const openEditorPanel = useSetAtom(openEditorPanelAtom); const params = useParams(); const electronAPI = useElectronAPI(); @@ -221,7 +215,9 @@ function FilePathLink({ const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam) ? Number(searchSpaceIdParam[0]) : Number(searchSpaceIdParam); - const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) ? parsedSearchSpaceId : undefined; + const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId) + ? 
parsedSearchSpaceId + : undefined; return ( <button diff --git a/surfsense_web/lib/apis/documents-api.service.ts b/surfsense_web/lib/apis/documents-api.service.ts index 949e3b29f..630c88d16 100644 --- a/surfsense_web/lib/apis/documents-api.service.ts +++ b/surfsense_web/lib/apis/documents-api.service.ts @@ -5,6 +5,7 @@ import { type DeleteDocumentRequest, deleteDocumentRequest, deleteDocumentResponse, + documentTitleRead, type GetDocumentByChunkRequest, type GetDocumentChunksRequest, type GetDocumentRequest, @@ -28,7 +29,6 @@ import { getSurfsenseDocsRequest, getSurfsenseDocsResponse, type SearchDocumentsRequest, - documentTitleRead, type SearchDocumentTitlesRequest, searchDocumentsRequest, searchDocumentsResponse, @@ -270,15 +270,15 @@ class DocumentsApiService { ); }; - getDocumentByVirtualPath = async (request: { - search_space_id: number; - virtual_path: string; - }) => { + getDocumentByVirtualPath = async (request: { search_space_id: number; virtual_path: string }) => { const params = new URLSearchParams({ search_space_id: String(request.search_space_id), virtual_path: request.virtual_path, }); - return baseApiService.get(`/api/v1/documents/by-virtual-path?${params.toString()}`, documentTitleRead); + return baseApiService.get( + `/api/v1/documents/by-virtual-path?${params.toString()}`, + documentTitleRead + ); }; /**