diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index c1bfcc538..a793f33d1 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -282,6 +282,14 @@ LANGSMITH_PROJECT=surfsense # SURFSENSE_ENABLE_ACTION_LOG=false # SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships +# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk +# content (typed reasoning blocks, tool-input deltas) and propagate the +# real tool_call_id to the SSE layer. When OFF, the stream falls back to +# the str-only text path and synthetic "call_" tool-call ids. +# Schema migrations 135/136 ship unconditionally because they are +# forward-compatible. +# SURFSENSE_ENABLE_STREAM_PARITY_V2=false + # Plugins # SURFSENSE_ENABLE_PLUGIN_LOADER=false # Comma-separated allowlist of plugin entry-point names diff --git a/surfsense_backend/alembic/versions/134_relax_revision_fks.py b/surfsense_backend/alembic/versions/134_relax_revision_fks.py new file mode 100644 index 000000000..99b665426 --- /dev/null +++ b/surfsense_backend/alembic/versions/134_relax_revision_fks.py @@ -0,0 +1,139 @@ +"""134_relax_revision_fks + +Revision ID: 134 +Revises: 133 +Create Date: 2026-04-29 + +Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so +revisions survive the deletes they describe. + +Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE +hard-deleting a document via the ``rm`` tool (and likewise a +``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE`` +the snapshot row is wiped at the exact moment we need it most — revert +then has nothing to read and the operation becomes irreversible. + +Migration: + +* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK + ``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``. +* ``folder_revisions.folder_id``: same treatment. + +The ``search_space_id`` FK on both tables is left unchanged (still +``ON DELETE CASCADE``). 
When a search space is deleted, all documents, +folders, AND their revisions go together — that's the correct teardown +story. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy import inspect + +from alembic import op + +revision: str = "134" +down_revision: str | None = "133" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def _fk_name(bind, table: str, column: str) -> str | None: + """Return the (single) FK constraint name on ``table.column``, if any.""" + inspector = inspect(bind) + for fk in inspector.get_foreign_keys(table): + cols = fk.get("constrained_columns") or [] + if cols == [column]: + return fk.get("name") + return None + + +def upgrade() -> None: + bind = op.get_bind() + + # --- document_revisions.document_id -> nullable + SET NULL --------------- + fk_name = _fk_name(bind, "document_revisions", "document_id") + if fk_name: + op.drop_constraint(fk_name, "document_revisions", type_="foreignkey") + op.alter_column( + "document_revisions", + "document_id", + existing_type=sa.Integer(), + nullable=True, + ) + op.create_foreign_key( + "document_revisions_document_id_fkey", + "document_revisions", + "documents", + ["document_id"], + ["id"], + ondelete="SET NULL", + ) + + # --- folder_revisions.folder_id -> nullable + SET NULL ------------------- + fk_name = _fk_name(bind, "folder_revisions", "folder_id") + if fk_name: + op.drop_constraint(fk_name, "folder_revisions", type_="foreignkey") + op.alter_column( + "folder_revisions", + "folder_id", + existing_type=sa.Integer(), + nullable=True, + ) + op.create_foreign_key( + "folder_revisions_folder_id_fkey", + "folder_revisions", + "folders", + ["folder_id"], + ["id"], + ondelete="SET NULL", + ) + + +def downgrade() -> None: + bind = op.get_bind() + + # Reinstating NOT NULL + CASCADE requires draining orphan rows first + # (any revision whose parent doc/folder has already been 
deleted). + op.execute("DELETE FROM document_revisions WHERE document_id IS NULL") + op.execute("DELETE FROM folder_revisions WHERE folder_id IS NULL") + + # --- document_revisions.document_id -> NOT NULL + CASCADE --------------- + fk_name = _fk_name(bind, "document_revisions", "document_id") + if fk_name: + op.drop_constraint(fk_name, "document_revisions", type_="foreignkey") + op.alter_column( + "document_revisions", + "document_id", + existing_type=sa.Integer(), + nullable=False, + ) + op.create_foreign_key( + "document_revisions_document_id_fkey", + "document_revisions", + "documents", + ["document_id"], + ["id"], + ondelete="CASCADE", + ) + + # --- folder_revisions.folder_id -> NOT NULL + CASCADE ------------------- + fk_name = _fk_name(bind, "folder_revisions", "folder_id") + if fk_name: + op.drop_constraint(fk_name, "folder_revisions", type_="foreignkey") + op.alter_column( + "folder_revisions", + "folder_id", + existing_type=sa.Integer(), + nullable=False, + ) + op.create_foreign_key( + "folder_revisions_folder_id_fkey", + "folder_revisions", + "folders", + ["folder_id"], + ["id"], + ondelete="CASCADE", + ) diff --git a/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py b/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py new file mode 100644 index 000000000..9ae368b81 --- /dev/null +++ b/surfsense_backend/alembic/versions/135_action_log_correlation_ids.py @@ -0,0 +1,82 @@ +"""135_action_log_correlation_ids + +Revision ID: 135 +Revises: 134 +Create Date: 2026-04-29 + +Action-log correlation-id cleanup. + +Background +---------- +``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes +the LangChain ``tool_call.id`` into that column today (see +``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch`` +joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from +``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was +never persisted. 
+ +This migration introduces two new, correctly-named columns: + +* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held) +* ``chat_turn_id`` (the per-turn correlation id from + ``configurable.turn_id`` — used by the per-turn ``revert-turn`` route). + +Backfill copies the current ``turn_id`` values into ``tool_call_id`` so +existing joins keep working. The old ``turn_id`` column is left in place +for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware`` +keeps writing it (= ``tool_call_id``) for the same reason. + +Indexes +------- + +* ``ix_agent_action_log_tool_call_id`` — required by + ``_find_action_ids_batch`` (was on ``turn_id``). +* ``ix_agent_action_log_chat_turn_id`` — required by the + ``revert-turn/{chat_turn_id}`` query. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "135" +down_revision: str | None = "134" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "agent_action_log", + sa.Column("tool_call_id", sa.String(length=64), nullable=True), + ) + op.add_column( + "agent_action_log", + sa.Column("chat_turn_id", sa.String(length=64), nullable=True), + ) + + op.create_index( + "ix_agent_action_log_tool_call_id", + "agent_action_log", + ["tool_call_id"], + ) + op.create_index( + "ix_agent_action_log_chat_turn_id", + "agent_action_log", + ["chat_turn_id"], + ) + + op.execute( + "UPDATE agent_action_log SET tool_call_id = turn_id WHERE tool_call_id IS NULL" + ) + + +def downgrade() -> None: + op.drop_index("ix_agent_action_log_chat_turn_id", table_name="agent_action_log") + op.drop_index("ix_agent_action_log_tool_call_id", table_name="agent_action_log") + op.drop_column("agent_action_log", "chat_turn_id") + op.drop_column("agent_action_log", "tool_call_id") diff --git 
a/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py b/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py new file mode 100644 index 000000000..8d4350424 --- /dev/null +++ b/surfsense_backend/alembic/versions/136_new_chat_message_turn_id.py @@ -0,0 +1,52 @@ +"""136_new_chat_message_turn_id + +Revision ID: 136 +Revises: 135 +Create Date: 2026-04-29 + +Persist the per-turn correlation id on each chat message. + +Background +---------- +LangGraph's checkpointer stores user-provided ``configurable.turn_id`` +in checkpoint metadata (see +``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To +support edit-from-arbitrary-position, the regenerate route needs to map +a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without +this column the mapping doesn't exist anywhere, so regenerate would +have to hardcode the "last 2 messages" rewind heuristic. + +This migration adds a nullable ``turn_id`` column to ``new_chat_messages`` +plus an index. Legacy rows have NULL — the regenerate route degrades +gracefully to the reload-last-two heuristic for those. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "136" +down_revision: str | None = "135" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "new_chat_messages", + sa.Column("turn_id", sa.String(length=64), nullable=True), + ) + op.create_index( + "ix_new_chat_messages_turn_id", + "new_chat_messages", + ["turn_id"], + ) + + +def downgrade() -> None: + op.drop_index("ix_new_chat_messages_turn_id", table_name="new_chat_messages") + op.drop_column("new_chat_messages", "turn_id") diff --git a/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py b/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py new file mode 100644 index 000000000..d606a00f9 --- /dev/null +++ b/surfsense_backend/alembic/versions/137_unique_reverse_of_in_action_log.py @@ -0,0 +1,74 @@ +"""137_unique_reverse_of_in_action_log + +Revision ID: 137 +Revises: 136 +Create Date: 2026-04-29 + +Protect ``agent_action_log.reverse_of`` against double inserts. Two +concurrent revert calls (single-action route + the per-turn batch +route, or two batch routes racing) both pass the +``_was_already_reverted`` SELECT and both insert their own +``_revert:*`` rows. The application-level idempotency check is racy +because there's no DB constraint backing it. + +This migration adds a partial unique index on ``reverse_of`` (PostgreSQL +``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises +``IntegrityError`` and the route can translate it to ``"already_reverted"`` +deterministically. 
+ +A plain ``UniqueConstraint`` would also tolerate the many existing +rows with ``reverse_of = NULL`` (only revert rows fill it), because +Postgres treats NULLs as distinct in unique indexes by default — but a +partial index is the cleanest expression of intent and does not depend +on how a given Postgres release handles NULLs in unique indexes. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op + +revision: str = "137" +down_revision: str | None = "136" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +_INDEX_NAME = "ux_agent_action_log_reverse_of" + + +def upgrade() -> None: + # Defensively de-dup any pre-existing double-revert rows before + # adding the unique index. Keeps the OLDEST row (smallest id) and + # NULLs out the duplicates' ``reverse_of`` so they survive as audit + # trail but no longer claim to be the canonical revert. We do NOT + # delete them — operators can still inspect them via /actions. 
+ op.execute( + """ + WITH dups AS ( + SELECT id, + reverse_of, + ROW_NUMBER() OVER ( + PARTITION BY reverse_of ORDER BY id ASC + ) AS rn + FROM agent_action_log + WHERE reverse_of IS NOT NULL + ) + UPDATE agent_action_log + SET reverse_of = NULL + WHERE id IN (SELECT id FROM dups WHERE rn > 1) + """ + ) + + op.create_index( + _INDEX_NAME, + "agent_action_log", + ["reverse_of"], + unique=True, + postgresql_where="reverse_of IS NOT NULL", + ) + + +def downgrade() -> None: + op.drop_index(_INDEX_NAME, table_name="agent_action_log") diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index bfb94ba2d..fdd72ea92 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -724,7 +724,8 @@ def _build_compiled_agent_blocking( repair_mw = None if flags.enable_tool_call_repair and not flags.disable_new_agent_stack: registered_names: set[str] = {t.name for t in tools} - # Tools owned by the standard deepagents middleware stack. + # Tools owned by the standard deepagents middleware stack and the + # SurfSense filesystem extension. registered_names |= { "write_todos", "ls", @@ -735,6 +736,14 @@ def _build_compiled_agent_blocking( "grep", "execute", "task", + "mkdir", + "cd", + "pwd", + "move_file", + "rm", + "rmdir", + "list_tree", + "execute_code", } repair_mw = ToolCallNameRepairMiddleware( registered_tool_names=registered_names, @@ -763,25 +772,51 @@ def _build_compiled_agent_blocking( # on every safe read-only call (``ls``, ``read_file``, ``grep``, # ``glob``, ``web_search`` …) and, on resume, replay the previous # reject decision into innocent calls. - # 2. ``connector_synthesized`` — deny rules for tools whose required - # connector is not connected to this space. Overrides #1. - # 3. (future) user-defined rules from ``agent_permission_rules`` table - # via the Agent Permissions UI. Loaded last so they override both. + # 2. 
``desktop_safety`` — ``ask`` for destructive filesystem ops when + # the agent is operating against the user's real disk. Cloud mode + # has full revision-based revert via ``revert_service``, but + # desktop mode hits disk immediately with no undo, so an + # accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` / + # ``write_file`` is unrecoverable. This layer is forced on in + # desktop mode regardless of ``enable_permission`` because the + # safety net is non-negotiable. + # 3. ``connector_synthesized`` — deny rules for tools whose required + # connector is not connected to this space. Overrides #1/#2. + # 4. (future) user-defined rules from ``agent_permission_rules`` table + # via the Agent Permissions UI. Loaded last so they override all. permission_mw: PermissionMiddleware | None = None - if flags.enable_permission and not flags.disable_new_agent_stack: - synthesized = _synthesize_connector_deny_rules( - available_connectors=available_connectors, - enabled_tool_names={t.name for t in tools}, - ) - permission_mw = PermissionMiddleware( - rulesets=[ + is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER + permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack + # Build the middleware whenever it has work to do: either the user + # opted into the rule engine, OR we're in desktop mode and need the + # safety rules unconditionally. 
+ if permission_enabled or is_desktop_fs: + rulesets: list[Ruleset] = [ + Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ), + ] + if is_desktop_fs: + rulesets.append( Ruleset( - rules=[Rule(permission="*", pattern="*", action="allow")], - origin="surfsense_defaults", - ), - Ruleset(rules=synthesized, origin="connector_synthesized"), - ], - ) + rules=[ + Rule(permission="rm", pattern="*", action="ask"), + Rule(permission="rmdir", pattern="*", action="ask"), + Rule(permission="move_file", pattern="*", action="ask"), + Rule(permission="edit_file", pattern="*", action="ask"), + Rule(permission="write_file", pattern="*", action="ask"), + ], + origin="desktop_safety", + ) + ) + if permission_enabled: + synthesized = _synthesize_connector_deny_rules( + available_connectors=available_connectors, + enabled_tool_names={t.name for t in tools}, + ) + rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized")) + permission_mw = PermissionMiddleware(rulesets=rulesets) # ActionLogMiddleware. Off by default until the ``agent_action_log`` # table is migrated. 
When enabled, persists one row per tool call @@ -938,6 +973,7 @@ def _build_compiled_agent_blocking( search_space_id=search_space_id, created_by_id=user_id, filesystem_mode=filesystem_mode, + thread_id=thread_id, ) if filesystem_mode == FilesystemMode.CLOUD else None, diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py index 55525abc5..f58bf0dd7 100644 --- a/surfsense_backend/app/agents/new_chat/feature_flags.py +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false + SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events Master kill-switch (overrides everything else): @@ -86,6 +87,15 @@ class AgentFeatureFlags: False # Backend ships before UI; route returns 503 until this flips ) + # Streaming parity v2 — opt in to LangChain's structured + # ``AIMessageChunk`` content (typed reasoning blocks, tool-input + # deltas) and propagate the real ``tool_call_id`` to the SSE layer. + # When OFF the ``stream_new_chat`` task falls back to the str-only + # text path and the synthetic ``call_`` tool-call id (no + # ``langchainToolCallId`` propagation). Schema migrations 135/136 + # ship unconditionally because they're forward-compatible. 
+ enable_stream_parity_v2: bool = False + # Plugins enable_plugin_loader: bool = False @@ -139,6 +149,10 @@ class AgentFeatureFlags: # Snapshot / revert enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), + # Streaming parity v2 + enable_stream_parity_v2=_env_bool( + "SURFSENSE_ENABLE_STREAM_PARITY_V2", False + ), # Plugins enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), # Observability diff --git a/surfsense_backend/app/agents/new_chat/filesystem_state.py b/surfsense_backend/app/agents/new_chat/filesystem_state.py index 18952ed6f..f54ada76e 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_state.py +++ b/surfsense_backend/app/agents/new_chat/filesystem_state.py @@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics: * ``cwd`` — current working directory (per-thread checkpointed). * ``staged_dirs`` — pending mkdir requests (cloud only). +* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs. * ``pending_moves`` — pending move_file requests (cloud only). +* ``pending_deletes`` — pending ``rm`` requests (cloud only). +* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only). * ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads. * ``dirty_paths`` — paths whose state file content differs from DB. +* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for + dirty paths; used to bind the per-path snapshot to an action_id. * ``kb_priority`` — top-K priority hints rendered into a system message. * ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting. * ``kb_anon_doc`` — Redis-loaded anonymous document (if any). 
@@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import ( ) -class PendingMove(TypedDict): - """A staged move_file operation pending end-of-turn commit.""" +class PendingMove(TypedDict, total=False): + """A staged move_file operation pending end-of-turn commit. + + ``tool_call_id`` is optional for backward compatibility with checkpoints + written before the snapshot/revert pipeline was wired up; new entries + always include it so the persistence body can resolve an action_id. + """ source: str dest: str overwrite: bool + tool_call_id: str + + +class PendingDelete(TypedDict, total=False): + """A staged ``rm`` or ``rmdir`` operation pending end-of-turn commit. + + ``tool_call_id`` is required for new entries (it's the binding key used + by :class:`KnowledgeBasePersistenceMiddleware` to find the matching + :class:`AgentActionLog` row and bind the snapshot to it). Marked + ``total=False`` only to tolerate older checkpoint payloads. + """ + + path: str + tool_call_id: str class KbPriorityEntry(TypedDict, total=False): @@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState): staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]] """mkdir paths staged for end-of-turn folder creation (cloud only).""" + staged_dir_tool_calls: NotRequired[ + Annotated[dict[str, str], _dict_merge_with_tombstones_reducer] + ] + """``path -> tool_call_id`` sidecar for ``staged_dirs``. + + Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the + :class:`FolderRevision` snapshot to the originating ``mkdir`` action. + Kept separate from ``staged_dirs`` (which stays a unique-string list) + to avoid breaking ``_add_unique_reducer`` semantics. + """ + pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]] """move_file ops staged for end-of-turn commit (cloud only).""" + pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]] + """``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only). 
+ + Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path + uniqueness is enforced inside the commit body, not the reducer (we keep + ``tool_call_id`` per occurrence so snapshot binding works). + """ + + pending_dir_deletes: NotRequired[ + Annotated[list[PendingDelete], _list_append_reducer] + ] + """``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only). + + Same shape as :data:`pending_deletes`. Commit body re-verifies the + folder is empty (in-DB AND with this turn's pending changes accounted + for) before issuing the DELETE. + """ + doc_id_by_path: NotRequired[ Annotated[dict[str, int], _dict_merge_with_tombstones_reducer] ] @@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState): dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]] """Paths whose ``state["files"]`` content has been modified this turn.""" + dirty_path_tool_calls: NotRequired[ + Annotated[dict[str, str], _dict_merge_with_tombstones_reducer] + ] + """``path -> latest tool_call_id`` sidecar for ``dirty_paths``. + + The persistence body coalesces multiple writes/edits to the same path + into one snapshot per turn. This map captures the most-recent + ``tool_call_id`` so the resulting :class:`DocumentRevision` is bound + to the latest action_id (the one the user is most likely to revert). 
+ """ + kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]] """Top-K priority hints rendered as a system message before the user turn.""" @@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState): __all__ = [ "KbAnonDoc", "KbPriorityEntry", + "PendingDelete", "PendingMove", "SurfSenseFilesystemState", ] diff --git a/surfsense_backend/app/agents/new_chat/middleware/action_log.py b/surfsense_backend/app/agents/new_chat/middleware/action_log.py index 3675064e8..716a1616c 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/action_log.py +++ b/surfsense_backend/app/agents/new_chat/middleware/action_log.py @@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any from langchain.agents.middleware import AgentMiddleware +from langchain_core.callbacks import adispatch_custom_event from langchain_core.messages import ToolMessage from app.agents.new_chat.feature_flags import get_flags @@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware): result=result, ) + tool_call_id = _resolve_tool_call_id(request) + chat_turn_id = _resolve_chat_turn_id(request) + row = AgentActionLog( thread_id=self._thread_id, user_id=self._user_id, search_space_id=self._search_space_id, - turn_id=_resolve_turn_id(request), + # ``turn_id`` is the deprecated alias of ``tool_call_id`` + # kept for one release for safe rollback. New consumers + # should read ``tool_call_id`` directly. 
+ turn_id=tool_call_id, + tool_call_id=tool_call_id, + chat_turn_id=chat_turn_id, message_id=_resolve_message_id(request), tool_name=tool_name, args=args_payload, @@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware): async with shielded_async_session() as session: session.add(row) await session.commit() + row_id = int(row.id) if row.id is not None else None + row_created_at = row.created_at except Exception: logger.warning( "ActionLogMiddleware failed to persist action log row", exc_info=True, ) + return + + # Surface a side-channel SSE event so the chat tool card can + # render a Revert button immediately after the row is durable. + # ``stream_new_chat`` translates this into a + # ``data-action-log`` SSE event. We DO NOT include the + # ``reverse_descriptor`` payload here; only a presence flag. + try: + await adispatch_custom_event( + "action_log", + { + "id": row_id, + "lc_tool_call_id": tool_call_id, + "chat_turn_id": chat_turn_id, + "tool_name": tool_name, + "reversible": bool(reversible), + "reverse_descriptor_present": reverse_descriptor is not None, + "created_at": row_created_at.isoformat() + if row_created_at + else None, + "error": error_payload is not None, + }, + ) + except Exception: + logger.debug( + "ActionLogMiddleware failed to dispatch action_log event", + exc_info=True, + ) def _render_reverse( self, @@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None: } -def _resolve_turn_id(request: Any) -> str | None: +def _resolve_tool_call_id(request: Any) -> str | None: + """Return the LangChain ``tool_call.id`` for this request, if any.""" try: call = getattr(request, "tool_call", None) or {} if isinstance(call, dict): @@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None: return None +# Deprecated alias kept for one release. Old callers and tests treated +# ``turn_id`` as if it carried the LangChain tool_call id; the new column +# lives under ``tool_call_id``. 
Both resolve to the same value today. +_resolve_turn_id = _resolve_tool_call_id + + +def _resolve_chat_turn_id(request: Any) -> str | None: + """Return ``configurable.turn_id`` for this request, if accessible. + + ``ToolRuntime.config`` is exposed by LangGraph (see + ``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id + lives at ``runtime.config["configurable"]["turn_id"]``. + """ + try: + runtime = getattr(request, "runtime", None) + if runtime is None: + return None + config = getattr(runtime, "config", None) + if not isinstance(config, dict): + return None + configurable = config.get("configurable") + if not isinstance(configurable, dict): + return None + value = configurable.get("turn_id") + if isinstance(value, str) and value: + return value + except Exception: # pragma: no cover - defensive + pass + return None + + def _resolve_message_id(request: Any) -> str | None: """Tool-call IDs serve as best-available message correlator at this layer.""" - return _resolve_turn_id(request) + return _resolve_tool_call_id(request) def _resolve_result_id(result: Any) -> str | None: diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 62316d69e..c46eb98a5 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`). - cd(path): change the current working directory. - pwd(): print the current working directory. - move_file(source, dest): move/rename a file under `/documents/`. +- rm(path): delete a single file under `/documents/` (no `-r`). +- rmdir(path): delete an empty directory under `/documents/`. - list_tree(path, max_depth, page_size): recursively list files/folders. ## Persistence Rules @@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`). 
`/documents/temp_scratch.md`) are **discarded** at end of turn — use this prefix for any scratch/working content you do NOT want saved. - All other paths (outside `/documents/` and not `temp_*`) are rejected. -- mkdir/move_file are staged this turn and committed at end of turn alongside - any new/edited documents. +- mkdir/move_file/rm/rmdir are staged this turn and committed at end of + turn alongside any new/edited documents. Snapshot/revert is enabled + for every destructive operation when action logging is on. ## Reading Documents Efficiently @@ -176,6 +179,8 @@ directory (`cwd`). - cd(path): change the current working directory. - pwd(): print the current working directory. - move_file(source, dest): move/rename a file. +- rm(path): delete a single file from disk (no `-r`). NOT reversible. +- rmdir(path): delete an empty directory from disk. NOT reversible. - list_tree(path, max_depth, page_size): recursively list files/folders. ## Workflow Tips @@ -184,6 +189,8 @@ directory (`cwd`). - For large trees, prefer `list_tree` then `grep` then `read_file` over brute-force directory traversal. - Cross-mount moves are not supported. +- Desktop deletes hit disk immediately and cannot be undone via the + agent's revert flow — confirm before calling `rm`/`rmdir`. """ ) @@ -355,6 +362,42 @@ Notes: - Parent folders are created as needed. """ +_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`. + +Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion +for end-of-turn commit; the row is removed only after the agent's turn +finishes successfully. + +Args: +- path: absolute or relative file path. Cannot point at a directory — use + `rmdir` for empty folders. Cannot target the root or `/documents`. + +Notes: +- The action is reversible via the per-action revert flow when action + logging is enabled. +- The anonymous uploaded document is read-only and cannot be deleted. 
+""" + +_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`. + +Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive +deletion (`rm -r`) is intentionally NOT supported — clear contents with +`rm` first. + +Args: +- path: absolute or relative directory path. Cannot target the root, + `/documents`, the current cwd, or any ancestor of cwd (use `cd` to + move out first). + +Notes: +- Emptiness is evaluated against the post-staged view, so a same-turn + `rm /a/x.md` followed by `rmdir /a` is fine. +- If the directory was added in this same turn via `mkdir` and never + committed, the staged mkdir is dropped instead of issuing a delete. +- The action is reversible via the per-action revert flow when action + logging is enabled. +""" + # --- desktop-only ---------------------------------------------------------- _DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. @@ -421,6 +464,28 @@ Notes: - Parent folders are created as needed. """ +_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk. + +Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits +disk immediately. Desktop deletes are NOT reversible via the agent's +revert flow. + +Args: +- path: absolute mount-prefixed file path. Cannot point at a directory — + use `rmdir` for empty folders. +""" + +_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk. + +Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive +deletion is NOT supported. The deletion hits disk immediately and is +NOT reversible via the agent's revert flow. + +Args: +- path: absolute mount-prefixed directory path. Cannot target the mount + root or any directory containing files/subfolders. 
+""" + def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: """Pick the active-mode description for every filesystem tool.""" @@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: "mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION, "cd": SURFSENSE_CD_TOOL_DESCRIPTION, "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + "rm": _CLOUD_RM_TOOL_DESCRIPTION, + "rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION, } return { "ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION, @@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: "mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION, "cd": SURFSENSE_CD_TOOL_DESCRIPTION, "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + "rm": _DESKTOP_RM_TOOL_DESCRIPTION, + "rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION, } @@ -476,6 +545,21 @@ def _basename(path: str) -> str: return path.rsplit("/", 1)[-1] +def _is_ancestor_of(candidate: str, target: str) -> bool: + """True iff ``candidate`` is a strict ancestor directory of ``target``. + + ``target`` itself is NOT considered an ancestor (use equality for that). + Both paths are assumed to be canonicalised, absolute, and free of + trailing slashes (except the root ``/``). 
+ """ + if not candidate.startswith("/") or not target.startswith("/"): + return False + if candidate == target: + return False + prefix = candidate.rstrip("/") + "/" + return target.startswith(prefix) + + class SurfSenseFilesystemMiddleware(FilesystemMiddleware): """SurfSense-specific filesystem middleware (cloud + desktop).""" @@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): self.tools.append(self._create_cd_tool()) self.tools.append(self._create_pwd_tool()) self.tools.append(self._create_move_file_tool()) + self.tools.append(self._create_rm_tool()) + self.tools.append(self._create_rmdir_tool()) self.tools.append(self._create_list_tree_tool()) if self._sandbox_available: self.tools.append(self._create_execute_code_tool()) @@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): } if self._is_cloud(): update["dirty_paths"] = [path] + update["dirty_path_tool_calls"] = {path: runtime.tool_call_id} return Command(update=update) def sync_write_file( @@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): } if self._is_cloud(): update["dirty_paths"] = [path] + update["dirty_path_tool_calls"] = {path: runtime.tool_call_id} if doc_id_to_attach is not None: update["doc_id_by_path"] = {path: doc_id_to_attach} return Command(update=update) @@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): return Command( update={ "staged_dirs": [validated], + "staged_dir_tool_calls": { + validated: runtime.tool_call_id, + }, "messages": [ ToolMessage( content=( @@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): files_update: dict[str, Any] = {source: None, dest: source_file_data} update: dict[str, Any] = { "files": files_update, - "pending_moves": [{"source": source, "dest": dest, "overwrite": False}], + "pending_moves": [ + { + "source": source, + "dest": dest, + "overwrite": False, + "tool_call_id": runtime.tool_call_id, + } + ], "messages": [ 
ToolMessage( content=( @@ -1396,6 +1494,323 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): update["dirty_paths"] = new_dirty return Command(update=update) + # ------------------------------------------------------------------ tool: rm + + def _create_rm_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION + ) + + async def async_rm( + path: Annotated[ + str, + "Absolute or relative path to the file to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + if not path or not path.strip(): + return "Error: path is required." + + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + if self._is_cloud(): + if validated in ("/", DOCUMENTS_ROOT): + return f"Error: refusing to rm '{validated}'." + if not validated.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud rm must target a path under /documents/ " + f"(got '{validated}')." + ) + + anon = runtime.state.get("kb_anon_doc") or {} + if isinstance(anon, dict) and str(anon.get("path") or "") == validated: + return "Error: the anonymous uploaded document is read-only." + + # Refuse if the path looks like a directory. + staged_dirs = list(runtime.state.get("staged_dirs") or []) + if validated in staged_dirs: + return ( + f"Error: '{validated}' is a directory. Use rmdir for " + "empty directories." + ) + pending_dir_deletes = list( + runtime.state.get("pending_dir_deletes") or [] + ) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_dir_deletes + ): + return f"Error: '{validated}' is already queued for rmdir." + + backend = self._get_backend(runtime) + if isinstance(backend, KBPostgresBackend): + # Detect "is a directory" via `ls`: if the path lists + # children we know it's a folder. Otherwise we still + # need to confirm it's a real file before staging. 
+ children = await backend.als_info(validated) + if children: + return ( + f"Error: '{validated}' is a directory. Use rmdir for " + "empty directories." + ) + + # Already queued for delete this turn? + pending_deletes = list(runtime.state.get("pending_deletes") or []) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_deletes + ): + return f"'{validated}' is already queued for deletion." + + # Resolve doc_id (best-effort): file in state or DB. + files_state = runtime.state.get("files") or {} + doc_id_by_path = runtime.state.get("doc_id_by_path") or {} + resolved_doc_id: int | None = doc_id_by_path.get(validated) + if ( + validated not in files_state + and resolved_doc_id is None + and isinstance(backend, KBPostgresBackend) + ): + loaded = await backend._load_file_data(validated) + if loaded is None: + return f"Error: file '{validated}' not found." + _, resolved_doc_id = loaded + + files_update: dict[str, Any] = {validated: None} + update: dict[str, Any] = { + "pending_deletes": [ + { + "path": validated, + "tool_call_id": runtime.tool_call_id, + } + ], + "files": files_update, + "doc_id_by_path": {validated: None}, + "messages": [ + ToolMessage( + content=( + f"Staged delete of '{validated}' (will commit at " + "end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + + # Drop the path from dirty_paths so a same-turn write+rm + # doesn't recreate the doc at commit time. + dirty_paths = list(runtime.state.get("dirty_paths") or []) + if validated in dirty_paths: + new_dirty: list[Any] = [_CLEAR] + for entry in dirty_paths: + if entry != validated: + new_dirty.append(entry) + update["dirty_paths"] = new_dirty + update["dirty_path_tool_calls"] = {validated: None} + + return Command(update=update) + + # Desktop mode — hit disk immediately. + backend = self._get_backend(runtime) + adelete = getattr(backend, "adelete_file", None) + if not callable(adelete): + return "Error: rm is not supported by the active backend." 
+ res: WriteResult = await adelete(validated) + if res.error: + return res.error + update_desktop: dict[str, Any] = { + "files": {validated: None}, + "messages": [ + ToolMessage( + content=f"Deleted file '{res.path or validated}'", + tool_call_id=runtime.tool_call_id, + ) + ], + } + return Command(update=update_desktop) + + def sync_rm( + path: Annotated[ + str, + "Absolute or relative path to the file to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_rm(path, runtime)) + + return StructuredTool.from_function( + name="rm", + description=tool_description, + func=sync_rm, + coroutine=async_rm, + ) + + # ------------------------------------------------------------------ tool: rmdir + + def _create_rmdir_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION + ) + + async def async_rmdir( + path: Annotated[ + str, + "Absolute or relative path of the empty directory to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + if not path or not path.strip(): + return "Error: path is required." + + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + if self._is_cloud(): + if validated in ("/", DOCUMENTS_ROOT): + return f"Error: refusing to rmdir '{validated}'." + if not validated.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud rmdir must target a path under /documents/ " + f"(got '{validated}')." + ) + + cwd = self._current_cwd(runtime) + if validated == cwd or _is_ancestor_of(validated, cwd): + return ( + f"Error: cannot rmdir '{validated}' because the current " + "cwd is at or under it. cd out first." 
+ ) + + staged_dirs = list(runtime.state.get("staged_dirs") or []) + pending_dir_deletes = list( + runtime.state.get("pending_dir_deletes") or [] + ) + if any( + isinstance(d, dict) and d.get("path") == validated + for d in pending_dir_deletes + ): + return f"'{validated}' is already queued for deletion." + + backend = self._get_backend(runtime) + + # The path must currently exist either in DB folder paths or + # in staged_dirs. We rely on KBPostgresBackend.als_info (which + # already accounts for pending deletes/moves) to evaluate + # both existence and emptiness against the post-staged view. + exists_in_staged = validated in staged_dirs + children: list[Any] = [] + if isinstance(backend, KBPostgresBackend): + children = list(await backend.als_info(validated)) + + # Detect "is a file" — if als_info returns no children but + # the path is actually a file, we should reject. We use + # _load_file_data to disambiguate file vs missing folder. + if ( + isinstance(backend, KBPostgresBackend) + and not children + and not exists_in_staged + ): + loaded = await backend._load_file_data(validated) + if loaded is not None: + return ( + f"Error: '{validated}' is a file. Use rm to delete files." + ) + # Confirm folder exists in DB by checking the parent listing. + parent = posixpath.dirname(validated) or "/" + parent_listing = await backend.als_info(parent) + parent_has_dir = any( + info.get("path") == validated and info.get("is_dir") + for info in parent_listing + ) + if not parent_has_dir: + return f"Error: directory '{validated}' not found." + + if children: + return ( + f"Error: directory '{validated}' is not empty. " + "Remove contents first." + ) + + # Same-turn mkdir un-stage: drop the staged_dirs entry + # entirely and skip queuing a DB delete (nothing was ever + # committed). 
+ if exists_in_staged: + rest = [d for d in staged_dirs if d != validated] + return Command( + update={ + "staged_dirs": [_CLEAR, *rest], + "staged_dir_tool_calls": {validated: None}, + "messages": [ + ToolMessage( + content=(f"Un-staged directory '{validated}'."), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + return Command( + update={ + "pending_dir_deletes": [ + { + "path": validated, + "tool_call_id": runtime.tool_call_id, + } + ], + "messages": [ + ToolMessage( + content=( + f"Staged rmdir of '{validated}' (will commit " + "at end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + # Desktop mode — hit disk immediately. + backend = self._get_backend(runtime) + armdir = getattr(backend, "armdir", None) + if not callable(armdir): + return "Error: rmdir is not supported by the active backend." + res: WriteResult = await armdir(validated) + if res.error: + return res.error + return Command( + update={ + "messages": [ + ToolMessage( + content=f"Deleted directory '{res.path or validated}'", + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + def sync_rmdir( + path: Annotated[ + str, + "Absolute or relative path of the empty directory to delete.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_rmdir(path, runtime)) + + return StructuredTool.from_function( + name="rmdir", + description=tool_description, + func=sync_rmdir, + coroutine=async_rmdir, + ) + # ------------------------------------------------------------------ tool: list_tree def _create_list_tree_tool(self) -> BaseTool: diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py index 378b83950..d577441dd 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py @@ -1,16 +1,29 @@ """End-of-turn persistence for 
the cloud-mode SurfSense filesystem. This middleware runs ``aafter_agent`` once per turn (cloud only). It commits -all staged folder creations, file moves, and content writes/edits to -Postgres in a single ordered pass: +all staged folder creations, file moves, content writes/edits, file deletes +(``rm``), and directory deletes (``rmdir``) to Postgres in a single ordered +pass: 1. Materialize ``staged_dirs`` into ``Folder`` rows. 2. Apply ``pending_moves`` in order (chained moves resolved via ``doc_id_by_path``). 3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move - sequences commit at the final path. + sequences commit at the final path. Paths queued for ``rm`` this turn + are dropped here so a write+rm sequence doesn't recreate the doc. 4. Commit content writes / edits for ``/documents/*`` paths, skipping ``temp_*`` basenames. +5. Apply ``pending_deletes`` (``rm``) — file deletes run BEFORE directory + deletes so a same-turn ``rm /a/x.md`` + ``rmdir /a`` sequence works. +6. Apply ``pending_dir_deletes`` (``rmdir``); re-verifies emptiness against + the post-step-5 DB state. + +When ``flags.enable_action_log`` is on, every mutating op also writes a +``DocumentRevision`` / ``FolderRevision`` snapshot bound to the +originating ``AgentActionLog`` row via ``tool_call_id``. ``rm``/``rmdir`` +share a single ``SAVEPOINT`` with their snapshot — if the snapshot fails +the DELETE rolls back and we surface the error rather than silently +making the delete irreversible.
The commit body is exposed as a free function ``commit_staged_filesystem_state`` so the optional stream-task fallback (``stream_new_chat.py``) can call the @@ -25,12 +38,13 @@ from typing import Any from fractional_indexing import generate_key_between from langchain.agents.middleware import AgentMiddleware, AgentState -from langchain_core.callbacks import dispatch_custom_event +from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event from langgraph.runtime import Runtime -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState from app.agents.new_chat.path_resolver import ( @@ -41,10 +55,13 @@ from app.agents.new_chat.path_resolver import ( ) from app.agents.new_chat.state_reducers import _CLEAR from app.db import ( + AgentActionLog, Chunk, Document, + DocumentRevision, DocumentType, Folder, + FolderRevision, shielded_async_session, ) from app.indexing_pipeline.document_chunker import chunk_text @@ -123,6 +140,47 @@ async def _ensure_folder_hierarchy( return parent_id +async def _resolve_folder_id( + session: AsyncSession, + *, + search_space_id: int, + folder_parts: list[str], +) -> int | None: + """Look up an existing folder chain without creating anything. + + Returns ``None`` if any segment is missing. Used by ``rmdir`` snapshot + capture and by parent-folder lookup at ``rmdir`` commit time. 
+ """ + if not folder_parts: + return None + parent_id: int | None = None + for raw in folder_parts: + name = safe_folder_segment(str(raw)) + query = select(Folder).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + ) + query = ( + query.where(Folder.parent_id.is_(None)) + if parent_id is None + else query.where(Folder.parent_id == parent_id) + ) + result = await session.execute(query) + folder = result.scalar_one_or_none() + if folder is None: + return None + parent_id = folder.id + return parent_id + + +def _split_folder_path(folder_path: str) -> list[str]: + """Return the folder segments under ``/documents/`` for a path.""" + if not folder_path.startswith(DOCUMENTS_ROOT): + return [] + rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/") + return [p for p in rel.split("/") if p] + + # --------------------------------------------------------------------------- # Document helpers # --------------------------------------------------------------------------- @@ -331,6 +389,298 @@ async def _apply_move( return {"id": document.id, "source": source, "dest": dest, "title": new_title} +# --------------------------------------------------------------------------- +# Action log binding helpers +# --------------------------------------------------------------------------- + + +async def _find_action_ids_batch( + session: AsyncSession, + *, + thread_id: int | None, + tool_call_ids: set[str], +) -> dict[str, int]: + """Resolve ``tool_call_id -> AgentActionLog.id`` in a single query. + + Returns an empty dict when ``thread_id`` or ``tool_call_ids`` are + missing — callers treat that as "no binding available" and write the + revision with ``agent_action_id = NULL``. 
+ """ + if thread_id is None or not tool_call_ids: + return {} + rows = await session.execute( + select(AgentActionLog.id, AgentActionLog.tool_call_id).where( + AgentActionLog.thread_id == thread_id, + AgentActionLog.tool_call_id.in_(list(tool_call_ids)), + ) + ) + mapping: dict[str, int] = {} + for row in rows.all(): + if row.tool_call_id and row.id: + mapping[str(row.tool_call_id)] = int(row.id) + return mapping + + +async def _mark_action_reversible( + session: AsyncSession, + *, + action_id: int | None, +) -> None: + """Flip ``agent_action_log.reversible = TRUE`` for ``action_id``. + + Best-effort: caller may invoke from inside a SAVEPOINT and treat + failure as a soft demotion (snapshot persists, just no Revert button). + + Callers should also call ``_dispatch_reversibility_update`` (defined + below) AFTER the enclosing SAVEPOINT block exits successfully so the + chat tool card can light up its Revert button without + re-fetching ``GET /threads/.../actions``. Dispatching from inside the + SAVEPOINT would risk emitting "reversible=true" for rows whose + update gets rolled back if the surrounding destructive op fails. + """ + if action_id is None: + return + await session.execute( + update(AgentActionLog) + .where(AgentActionLog.id == action_id) + .values(reversible=True) + ) + + +async def _dispatch_reversibility_update(action_id: int | None) -> None: + """Best-effort dispatch of an ``action_log_updated`` custom event. + + Surfaces the post-SAVEPOINT reversibility flip to the SSE layer so + the chat tool card can flip its Revert button live. Defensive: + failures are logged at debug level and swallowed; the + REST endpoint ``GET /threads/.../actions`` is still authoritative. + + .. warning:: + Inside :func:`commit_staged_filesystem_state` we DEFER all + dispatches until the outer ``session.commit()`` succeeds — see + the ``deferred_dispatches`` queue in that function. 
Dispatching + from inside a SAVEPOINT block while the outer transaction is + still pending would emit ``reversible=true`` for rows whose + snapshots get rolled back if the outer commit fails. Direct + callers (e.g. the optional stream-task fallback) that own the + full session lifetime can still call this helper inline. + """ + if action_id is None: + return + try: + await adispatch_custom_event( + "action_log_updated", + {"id": int(action_id), "reversible": True}, + ) + except Exception: + logger.debug( + "kb_persistence.aafter_agent failed to dispatch action_log_updated", + exc_info=True, + ) + + +# --------------------------------------------------------------------------- +# Snapshot helpers +# --------------------------------------------------------------------------- +# +# Best-effort helpers swallow + log so a snapshot failure can never break +# the commit of non-destructive tools (write/edit/move/mkdir). +# Strict helpers run inside the SAME ``begin_nested()`` SAVEPOINT as the +# destructive DELETE — failure aborts the savepoint and leaves the doc / +# folder intact, so revertable ops never become irreversible silently.
+ + +def _doc_revision_payload( + doc: Document, + *, + chunks_before: list[dict[str, str]] | None = None, +) -> dict[str, Any]: + """Pre-mutation field map for ``DocumentRevision``.""" + metadata = dict(doc.document_metadata or {}) + return { + "content_before": doc.content, + "title_before": doc.title, + "folder_id_before": doc.folder_id, + "chunks_before": chunks_before, + "metadata_before": metadata or None, + } + + +async def _load_chunks_for_snapshot( + session: AsyncSession, *, doc_id: int +) -> list[dict[str, str]]: + rows = await session.execute( + select(Chunk.content).where(Chunk.document_id == doc_id).order_by(Chunk.id) + ) + return [{"content": row.content} for row in rows.all() if row.content is not None] + + +async def _snapshot_document_pre_write( + session: AsyncSession, + *, + doc: Document, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort snapshot ahead of an in-place ``write_file``/``edit_file``. + + When ``deferred_dispatches`` is provided, on success the action id + is APPENDED to it and the SSE dispatch is left to the caller (so it + can be flushed only after the outer ``session.commit()`` succeeds). 
+ """ + try: + async with session.begin_nested(): + chunks = await _load_chunks_for_snapshot(session, doc_id=doc.id) + payload = _doc_revision_payload(doc, chunks_before=chunks) + rev = DocumentRevision( + document_id=doc.id, + search_space_id=search_space_id, + created_by_turn_id=turn_id, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-write snapshot for doc=%s failed: %s", + doc.id, + exc, + ) + return None + + +async def _snapshot_document_pre_create( + session: AsyncSession, + *, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort placeholder revision for a fresh ``write_file`` create. + + ``document_id`` is patched in by the caller after the new doc is + flushed and gets an ID; the placeholder lets us bind the action_id + even though no parent row exists yet. 
+ """ + try: + async with session.begin_nested(): + rev = DocumentRevision( + document_id=None, + search_space_id=search_space_id, + content_before=None, + title_before=None, + folder_id_before=None, + chunks_before=None, + metadata_before=None, + created_by_turn_id=turn_id, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning("kb_persistence: pre-create snapshot failed: %s", exc) + return None + + +async def _snapshot_document_pre_move( + session: AsyncSession, + *, + doc: Document, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> int | None: + """Best-effort snapshot ahead of a ``move_file``.""" + try: + async with session.begin_nested(): + payload = _doc_revision_payload(doc, chunks_before=None) + rev = DocumentRevision( + document_id=doc.id, + search_space_id=search_space_id, + created_by_turn_id=turn_id, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-move snapshot for doc=%s failed: %s", + doc.id, + exc, + ) + return None + + +async def _snapshot_folder_pre_mkdir( + session: AsyncSession, + *, + folder: Folder, + action_id: int | None, + search_space_id: int, + turn_id: str | None = None, + deferred_dispatches: list[int] | None = None, +) -> 
int | None: + """Best-effort placeholder for an ``mkdir`` (revert deletes the folder). + + The "before" state is "did not exist", so all ``*_before`` fields are + NULL — revert routes by ``tool_name == "mkdir"`` and DELETEs. + """ + try: + async with session.begin_nested(): + rev = FolderRevision( + folder_id=folder.id, + search_space_id=search_space_id, + name_before=None, + parent_id_before=None, + position_before=None, + created_by_turn_id=turn_id, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + rev_id = rev.id + if deferred_dispatches is None: + await _dispatch_reversibility_update(action_id) + elif action_id is not None: + deferred_dispatches.append(int(action_id)) + return rev_id + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb_persistence: pre-mkdir snapshot for folder=%s failed: %s", + folder.id, + exc, + ) + return None + + # --------------------------------------------------------------------------- # Commit body # --------------------------------------------------------------------------- @@ -342,12 +692,20 @@ async def commit_staged_filesystem_state( search_space_id: int, created_by_id: str | None, filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, + thread_id: int | None = None, dispatch_events: bool = True, ) -> dict[str, Any] | None: """Commit all staged filesystem changes; return the state delta for reducers. Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` and the optional stream-task fallback. + + When ``flags.enable_action_log`` is on, every mutating op also writes + a ``DocumentRevision`` / ``FolderRevision`` snapshot bound to the + originating ``AgentActionLog`` row via ``tool_call_id``. Snapshot + durability is best-effort for non-destructive ops and STRICT for + ``rm``/``rmdir`` (snapshot + DELETE share a SAVEPOINT — snapshot + failure aborts the delete).
""" if filesystem_mode != FilesystemMode.CLOUD: return None @@ -360,8 +718,20 @@ async def commit_staged_filesystem_state( files: dict[str, Any] = state_dict.get("files") or {} staged_dirs: list[str] = list(state_dict.get("staged_dirs") or []) + staged_dir_tool_calls: dict[str, str] = dict( + state_dict.get("staged_dir_tool_calls") or {} + ) pending_moves: list[dict[str, Any]] = list(state_dict.get("pending_moves") or []) + pending_deletes: list[dict[str, Any]] = list( + state_dict.get("pending_deletes") or [] + ) + pending_dir_deletes: list[dict[str, Any]] = list( + state_dict.get("pending_dir_deletes") or [] + ) dirty_paths: list[str] = list(state_dict.get("dirty_paths") or []) + dirty_path_tool_calls: dict[str, str] = dict( + state_dict.get("dirty_path_tool_calls") or {} + ) doc_id_by_path: dict[str, int] = dict(state_dict.get("doc_id_by_path") or {}) kb_anon_doc = state_dict.get("kb_anon_doc") @@ -374,32 +744,112 @@ async def commit_staged_filesystem_state( return { "dirty_paths": [_CLEAR], "staged_dirs": [_CLEAR], + "staged_dir_tool_calls": {_CLEAR: True}, "pending_moves": [_CLEAR], + "pending_deletes": [_CLEAR], + "pending_dir_deletes": [_CLEAR], + "dirty_path_tool_calls": {_CLEAR: True}, "files": dict.fromkeys(temp_paths), } - if not (staged_dirs or pending_moves or dirty_paths): + if not ( + staged_dirs + or pending_moves + or dirty_paths + or pending_deletes + or pending_dir_deletes + ): return None + flags = get_flags() + snapshot_enabled = flags.enable_action_log + + # De-duplicate pending deletes per-path while preserving the latest + # tool_call_id (the one the user is most likely to revert via the UI). 
+ file_delete_paths: dict[str, str] = {} + for entry in pending_deletes: + if not isinstance(entry, dict): + continue + path = str(entry.get("path") or "") + if path: + file_delete_paths[path] = str(entry.get("tool_call_id") or "") + dir_delete_paths: dict[str, str] = {} + for entry in pending_dir_deletes: + if not isinstance(entry, dict): + continue + path = str(entry.get("path") or "") + if path: + dir_delete_paths[path] = str(entry.get("tool_call_id") or "") + committed_creates: list[dict[str, Any]] = [] committed_updates: list[dict[str, Any]] = [] + committed_deletes: list[dict[str, Any]] = [] + committed_folder_deletes: list[dict[str, Any]] = [] discarded: list[str] = [] applied_moves: list[dict[str, Any]] = [] doc_id_path_tombstones: dict[str, int | None] = {} tree_changed = False + # Reversibility-flip dispatches are deferred until AFTER the outer + # ``session.commit()`` succeeds. Dispatching from inside the + # SAVEPOINT chain while the outer transaction is still pending + # would emit ``reversible=true`` for rows whose snapshots get rolled + # back if the final commit raises. Snapshot helpers append on + # success; we drain this list after commit and silently abandon it + # on rollback so the UI stays consistent with durable state. + deferred_dispatches: list[int] = [] try: async with shielded_async_session() as session: + # ------------------------------------------------------------------ + # Resolve action-id bindings up front. One SELECT per turn for all + # tool_call_ids, NOT one per op — important because a turn that + # touches 50 paths would otherwise issue 50 lookups. 
+ # ------------------------------------------------------------------ + action_id_by_call: dict[str, int] = {} + if snapshot_enabled and thread_id is not None: + tool_call_ids: set[str] = set() + tool_call_ids.update( + tcid for tcid in staged_dir_tool_calls.values() if tcid + ) + for move in pending_moves: + tcid = str(move.get("tool_call_id") or "") + if tcid: + tool_call_ids.add(tcid) + tool_call_ids.update( + tcid for tcid in dirty_path_tool_calls.values() if tcid + ) + tool_call_ids.update( + tcid for tcid in file_delete_paths.values() if tcid + ) + tool_call_ids.update(tcid for tcid in dir_delete_paths.values() if tcid) + action_id_by_call = await _find_action_ids_batch( + session, + thread_id=thread_id, + tool_call_ids=tool_call_ids, + ) + + def _action_id_for(tool_call_id: str | None) -> int | None: + if not snapshot_enabled or not tool_call_id: + return None + return action_id_by_call.get(str(tool_call_id)) + + turn_id_for_revision = ( + next(iter(action_id_by_call), None) if action_id_by_call else None + ) + + # ------------------------------------------------------------------ + # 1. staged_dirs -> Folder rows. Snapshot post-flush so the new + # folder_id is available for the FK. 
+ # ------------------------------------------------------------------ for folder_path in staged_dirs: if not isinstance(folder_path, str): continue if not folder_path.startswith(DOCUMENTS_ROOT): continue - rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/") - folder_parts_full = [p for p in rel.split("/") if p] + folder_parts_full = _split_folder_path(folder_path) if not folder_parts_full: continue - await _ensure_folder_hierarchy( + folder_id = await _ensure_folder_hierarchy( session, search_space_id=search_space_id, created_by_id=created_by_id, @@ -407,7 +857,61 @@ async def commit_staged_filesystem_state( ) tree_changed = True + if snapshot_enabled and folder_id is not None: + tcid = staged_dir_tool_calls.get(folder_path) + action_id = _action_id_for(tcid) + if action_id is not None: + # Re-read the folder for the snapshot. + result = await session.execute( + select(Folder).where(Folder.id == folder_id) + ) + folder_row = result.scalar_one_or_none() + if folder_row is not None: + await _snapshot_folder_pre_mkdir( + session, + folder=folder_row, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) + + # ------------------------------------------------------------------ + # 2. pending_moves. Snapshot pre-move (in-place restore on revert). + # ------------------------------------------------------------------ for move in pending_moves: + source = str(move.get("source") or "") + if snapshot_enabled and source: + tcid = str(move.get("tool_call_id") or "") + action_id = _action_id_for(tcid) + if action_id is not None: + # Resolve the doc to snapshot BEFORE we mutate it. 
+ doc_id_pre = doc_id_by_path.get(source) + document_pre: Document | None = None + if doc_id_pre is not None: + res_pre = await session.execute( + select(Document).where( + Document.id == doc_id_pre, + Document.search_space_id == search_space_id, + ) + ) + document_pre = res_pre.scalar_one_or_none() + if document_pre is None: + document_pre = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=source, + ) + if document_pre is not None: + await _snapshot_document_pre_move( + session, + doc=document_pre, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) + applied = await _apply_move( session, search_space_id=search_space_id, @@ -431,8 +935,13 @@ async def commit_staged_filesystem_state( path = move_alias[path] return path + # ------------------------------------------------------------------ + # 3. dirty_paths -> writes/edits. Skip any path queued for ``rm`` + # this turn so a write+rm sequence doesn't recreate the doc. + # ------------------------------------------------------------------ kb_dirty_seen: set[str] = set() kb_dirty: list[str] = [] + kb_dirty_origin: dict[str, str] = {} for raw in dirty_paths: if not isinstance(raw, str): continue @@ -441,8 +950,12 @@ async def commit_staged_filesystem_state( continue if final in kb_dirty_seen: continue + if final in file_delete_paths: + discarded.append(final) + continue kb_dirty_seen.add(final) kb_dirty.append(final) + kb_dirty_origin[final] = raw for path in kb_dirty: basename = _basename(path) @@ -454,6 +967,15 @@ async def commit_staged_filesystem_state( continue content = "\n".join(file_data.get("content") or []) doc_id = doc_id_by_path.get(path) + # Path ↔ tool_call_id binding: the dirty_paths list dedupes via + # _add_unique_reducer, so we look up the latest tool_call_id by + # path (or by the un-renamed origin). 
+ origin = kb_dirty_origin.get(path, path) + tcid = dirty_path_tool_calls.get(path) or dirty_path_tool_calls.get( + origin + ) + action_id = _action_id_for(tcid) + if doc_id is None: # The in-memory ``doc_id_by_path`` is per-thread and starts # empty in every new chat. If the agent writes to a path @@ -470,6 +992,23 @@ async def commit_staged_filesystem_state( doc_id = existing.id doc_id_by_path[path] = existing.id if doc_id is not None: + if snapshot_enabled and action_id is not None: + result_doc = await session.execute( + select(Document).where( + Document.id == doc_id, + Document.search_space_id == search_space_id, + ) + ) + existing_doc = result_doc.scalar_one_or_none() + if existing_doc is not None: + await _snapshot_document_pre_write( + session, + doc=existing_doc, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) updated = await _update_document( session, doc_id=doc_id, @@ -492,12 +1031,21 @@ async def commit_staged_filesystem_state( } ) else: - # Wrap each create in a SAVEPOINT so a residual - # ``IntegrityError`` (e.g. a deployment that hasn't run - # migration 133 yet, where ``documents.content_hash`` - # still carries its legacy global UNIQUE constraint) - # rolls back only this one create instead of poisoning - # the whole turn's transaction. + # Fresh create. Wrap each create in a SAVEPOINT so a + # residual ``IntegrityError`` (e.g. a deployment that + # hasn't run migration 133 yet, where + # ``documents.content_hash`` still carries its legacy + # global UNIQUE constraint) rolls back only this one + # create instead of poisoning the whole turn. 
+ placeholder_revision_id: int | None = None + if snapshot_enabled and action_id is not None: + placeholder_revision_id = await _snapshot_document_pre_create( + session, + action_id=action_id, + search_space_id=search_space_id, + turn_id=tcid, + deferred_dispatches=deferred_dispatches, + ) try: async with session.begin_nested(): new_doc = await _create_document( @@ -511,14 +1059,16 @@ async def commit_staged_filesystem_state( logger.warning( "kb_persistence: skipping %s create: %s", path, exc ) + # Roll back the placeholder revision since the create + # never happened. + if placeholder_revision_id is not None: + await session.execute( + delete(DocumentRevision).where( + DocumentRevision.id == placeholder_revision_id + ) + ) continue except IntegrityError as exc: - # The path-uniqueness check above already protected - # against ``unique_identifier_hash`` collisions, so - # the most likely culprit is the legacy - # ``ix_documents_content_hash`` UNIQUE constraint - # that migration 133 drops. Log loudly so operators - # know to run the migration; do NOT silently swallow. msg = str(exc.orig) if exc.orig is not None else str(exc) logger.error( "kb_persistence: IntegrityError creating %s: %s. " @@ -528,8 +1078,20 @@ async def commit_staged_filesystem_state( path, msg, ) + if placeholder_revision_id is not None: + await session.execute( + delete(DocumentRevision).where( + DocumentRevision.id == placeholder_revision_id + ) + ) continue doc_id_by_path[path] = new_doc.id + if placeholder_revision_id is not None: + await session.execute( + update(DocumentRevision) + .where(DocumentRevision.id == placeholder_revision_id) + .values(document_id=new_doc.id) + ) committed_creates.append( { "id": new_doc.id, @@ -545,13 +1107,234 @@ async def commit_staged_filesystem_state( ) tree_changed = True + # ------------------------------------------------------------------ + # 4. pending_deletes -> ``rm``. STRICT durability: snapshot + DELETE + # share a SAVEPOINT. 
If the snapshot insert fails, the DELETE + # rolls back too and we surface the error rather than silently + # making the data irreversible. + # ------------------------------------------------------------------ + for raw_path, tcid in file_delete_paths.items(): + final = _final_path(raw_path) + if not final.startswith(DOCUMENTS_ROOT + "/"): + continue + action_id = _action_id_for(tcid) + + # Resolve the doc. + doc_id_for_delete = doc_id_by_path.get(final) + document_to_delete: Document | None = None + if doc_id_for_delete is not None: + result = await session.execute( + select(Document).where( + Document.id == doc_id_for_delete, + Document.search_space_id == search_space_id, + ) + ) + document_to_delete = result.scalar_one_or_none() + if document_to_delete is None: + document_to_delete = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=final, + ) + if document_to_delete is None: + logger.info( + "kb_persistence: skipping rm %s (target not found)", final + ) + continue + + doc_pk = document_to_delete.id + doc_title = document_to_delete.title + doc_folder_id = document_to_delete.folder_id + + try: + async with session.begin_nested(): + # Strict: snapshot first; failure aborts the delete. + if snapshot_enabled and action_id is not None: + chunks = await _load_chunks_for_snapshot( + session, doc_id=doc_pk + ) + payload = _doc_revision_payload( + document_to_delete, chunks_before=chunks + ) + rev = DocumentRevision( + document_id=doc_pk, + search_space_id=search_space_id, + created_by_turn_id=tcid, + agent_action_id=action_id, + **payload, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + await session.execute( + delete(Document).where(Document.id == doc_pk) + ) + except Exception as exc: + logger.exception( + "kb_persistence: strict rm SAVEPOINT for path=%s failed: %s", + final, + exc, + ) + continue + + # B1 — SAVEPOINT released. 
Defer the reversibility-flip + # dispatch until AFTER the outer commit succeeds so we + # never tell the UI a row is reversible if its snapshot + # gets rolled back. + if snapshot_enabled and action_id is not None: + deferred_dispatches.append(int(action_id)) + + doc_id_by_path.pop(final, None) + doc_id_path_tombstones[final] = None + committed_deletes.append( + { + "id": doc_pk, + "title": doc_title, + "documentType": DocumentType.NOTE.value, + "searchSpaceId": search_space_id, + "folderId": doc_folder_id, + "createdById": str(created_by_id) if created_by_id else None, + "virtualPath": final, + } + ) + tree_changed = True + + # ------------------------------------------------------------------ + # 5. pending_dir_deletes -> ``rmdir``. STRICT durability + final + # emptiness check (after step 4's deletes have run, an "empty + # mid-turn" directory really IS empty in DB now). + # ------------------------------------------------------------------ + for raw_path, tcid in dir_delete_paths.items(): + final = _final_path(raw_path) + if not final.startswith(DOCUMENTS_ROOT + "/"): + continue + action_id = _action_id_for(tcid) + + folder_parts = _split_folder_path(final) + if not folder_parts: + continue + folder_id = await _resolve_folder_id( + session, + search_space_id=search_space_id, + folder_parts=folder_parts, + ) + if folder_id is None: + logger.info( + "kb_persistence: skipping rmdir %s (folder not found)", final + ) + continue + + # Re-check emptiness against in-DB state. 
+ docs_in_folder = await session.execute( + select(Document.id) + .where(Document.folder_id == folder_id) + .where(Document.search_space_id == search_space_id) + .limit(1) + ) + if docs_in_folder.scalar_one_or_none() is not None: + logger.warning( + "kb_persistence: refusing rmdir %s — non-empty at commit time", + final, + ) + continue + child_folders = await session.execute( + select(Folder.id) + .where(Folder.parent_id == folder_id) + .where(Folder.search_space_id == search_space_id) + .limit(1) + ) + if child_folders.scalar_one_or_none() is not None: + logger.warning( + "kb_persistence: refusing rmdir %s — has child folders " + "at commit time", + final, + ) + continue + + folder_to_delete_res = await session.execute( + select(Folder).where(Folder.id == folder_id) + ) + folder_to_delete = folder_to_delete_res.scalar_one_or_none() + if folder_to_delete is None: + continue + + folder_pk = folder_to_delete.id + folder_name = folder_to_delete.name + folder_parent_id = folder_to_delete.parent_id + folder_position = folder_to_delete.position + + try: + async with session.begin_nested(): + if snapshot_enabled and action_id is not None: + rev = FolderRevision( + folder_id=folder_pk, + search_space_id=search_space_id, + name_before=folder_name, + parent_id_before=folder_parent_id, + position_before=folder_position, + created_by_turn_id=tcid, + agent_action_id=action_id, + ) + session.add(rev) + await session.flush() + await _mark_action_reversible(session, action_id=action_id) + await session.execute( + delete(Folder).where(Folder.id == folder_pk) + ) + except Exception as exc: + logger.exception( + "kb_persistence: strict rmdir SAVEPOINT for path=%s failed: %s", + final, + exc, + ) + continue + + # B1 — SAVEPOINT released. Defer the reversibility-flip + # dispatch until AFTER the outer commit succeeds so we + # never tell the UI a row is reversible if its snapshot + # gets rolled back. 
+ if snapshot_enabled and action_id is not None: + deferred_dispatches.append(int(action_id)) + + committed_folder_deletes.append( + { + "id": folder_pk, + "name": folder_name, + "searchSpaceId": search_space_id, + "parentId": folder_parent_id, + "virtualPath": final, + } + ) + tree_changed = True + await session.commit() except Exception: # pragma: no cover - rollback safety net logger.exception( "kb_persistence: commit failed (search_space=%s)", search_space_id ) + # Outer commit raised — every SAVEPOINT-released change above + # (snapshots + reversibility flips) is now rolled back. Drop + # the deferred SSE dispatches so the UI stays consistent with + # durable state. + deferred_dispatches.clear() return None + # Outer commit succeeded; flush deferred reversibility-flip + # dispatches now so the chat tool card can light up its Revert + # button without re-fetching ``GET /threads/.../actions``. De-dup + # to avoid emitting the same id twice (e.g. write-then-rm in the + # same turn dispatches once for each snapshot site). 
+ if deferred_dispatches and dispatch_events: + for action_id in dict.fromkeys(deferred_dispatches): + try: + await _dispatch_reversibility_update(action_id) + except Exception: + logger.debug( + "kb_persistence: deferred reversibility dispatch failed for action_id=%s", + action_id, + exc_info=True, + ) + if dispatch_events: for payload in committed_creates: try: @@ -567,11 +1350,34 @@ async def commit_staged_filesystem_state( logger.exception( "kb_persistence: failed to dispatch document_updated event" ) + for payload in committed_deletes: + try: + dispatch_custom_event("document_deleted", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch document_deleted event" + ) + for payload in committed_folder_deletes: + try: + dispatch_custom_event("folder_deleted", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch folder_deleted event" + ) temp_paths = [ p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) ] + # Tombstone every committed-delete path so a stale ``state["files"]`` entry + # (which als_info would otherwise interpret as content) cannot survive into + # the next turn and make a now-empty folder look non-empty. 
+ deleted_file_paths = [ + str(payload.get("virtualPath") or "") + for payload in committed_deletes + if payload.get("virtualPath") + ] + doc_id_update: dict[str, int | None] = {**doc_id_path_tombstones} for payload in committed_creates: doc_id_update[str(payload.get("virtualPath") or "")] = int(payload["id"]) @@ -579,23 +1385,38 @@ async def commit_staged_filesystem_state( delta: dict[str, Any] = { "dirty_paths": [_CLEAR], "staged_dirs": [_CLEAR], + "staged_dir_tool_calls": {_CLEAR: True}, "pending_moves": [_CLEAR], + "pending_deletes": [_CLEAR], + "pending_dir_deletes": [_CLEAR], + "dirty_path_tool_calls": {_CLEAR: True}, } + files_delta: dict[str, Any] = {} if temp_paths: - delta["files"] = dict.fromkeys(temp_paths) + files_delta.update(dict.fromkeys(temp_paths)) + for path in deleted_file_paths: + files_delta[path] = None + if files_delta: + delta["files"] = files_delta if doc_id_update: delta["doc_id_by_path"] = doc_id_update if tree_changed: delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1 + # Avoid 'unused' lint when turn_id_for_revision was only useful for + # diagnostic purposes inside the SAVEPOINT chain above. 
+ _ = turn_id_for_revision + logger.info( "kb_persistence: commit (search_space=%s) creates=%d updates=%d " - "moves=%d staged_dirs=%d discarded=%d", + "moves=%d staged_dirs=%d deletes=%d folder_deletes=%d discarded=%d", search_space_id, len(committed_creates), len(committed_updates), len(applied_moves), len(staged_dirs), + len(committed_deletes), + len(committed_folder_deletes), len(discarded), ) return delta @@ -618,10 +1439,12 @@ class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type- search_space_id: int, created_by_id: str | None, filesystem_mode: FilesystemMode, + thread_id: int | None = None, ) -> None: self.search_space_id = search_space_id self.created_by_id = created_by_id self.filesystem_mode = filesystem_mode + self.thread_id = thread_id async def aafter_agent( # type: ignore[override] self, @@ -636,6 +1459,7 @@ class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type- search_space_id=self.search_space_id, created_by_id=self.created_by_id, filesystem_mode=self.filesystem_mode, + thread_id=self.thread_id, ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py index ddb2d4af1..7cf3bf8cd 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py @@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol): def _pending_moves(self) -> list[dict[str, Any]]: return list(self.state.get("pending_moves") or []) + def _pending_deletes(self) -> list[dict[str, Any]]: + return list(self.state.get("pending_deletes") or []) + + def _pending_dir_deletes(self) -> list[dict[str, Any]]: + return list(self.state.get("pending_dir_deletes") or []) + def _kb_anon_doc(self) -> dict[str, Any] | None: anon = self.state.get("kb_anon_doc") return anon if isinstance(anon, dict) else None @@ -140,18 +146,28 @@ class 
KBPostgresBackend(BackendProtocol): return path return path.rstrip("/") if path != "/" else path - def _moved_view_paths( + def _pending_filesystem_view( self, existing: dict[str, dict[str, Any]], - ) -> tuple[set[str], dict[str, str]]: - """Apply ``pending_moves`` to a path set and return ``(removed, alias)``. + ) -> tuple[set[str], dict[str, str], set[str]]: + """Compute removed/aliased/dir-suppressed paths from staged ops. - Removed paths should disappear from listings; ``alias[source] = dest`` - means a virtual entry should appear at ``dest`` even if no DB row is - yet there. + Returns ``(removed, alias, deleted_dirs)`` where: + + * ``removed`` — paths to drop from listings (sources of pending moves + AND paths queued for ``rm``). + * ``alias`` — ``{source: dest}`` for pending moves; the dest should + appear as a virtual entry even when no DB row is at that path yet. + * ``deleted_dirs`` — folder paths queued for ``rmdir``; their entire + subtree (descendants) is suppressed from listings/glob/grep. + + Entries in ``existing`` (the ``files`` state cache) keyed by a + removed path are popped so a same-turn delete-after-write doesn't + leave a stale virtual file in listings. 
""" removed: set[str] = set() alias: dict[str, str] = {} + deleted_dirs: set[str] = set() for move in self._pending_moves(): src = move.get("source") dst = move.get("dest") @@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol): removed.add(src) alias[src] = dst existing.pop(src, None) - return removed, alias + for entry in self._pending_deletes(): + path = entry.get("path") if isinstance(entry, dict) else None + if not path: + continue + removed.add(path) + existing.pop(path, None) + for entry in self._pending_dir_deletes(): + path = entry.get("path") if isinstance(entry, dict) else None + if not path: + continue + deleted_dirs.add(path) + return removed, alias, deleted_dirs + + @staticmethod + def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool: + """Return True iff ``path`` is at-or-under any directory in ``deleted_dirs``.""" + return any(path == d or _is_under(path, d) for d in deleted_dirs) # ------------------------------------------------------------------ ls/read @@ -189,7 +221,7 @@ class KBPostgresBackend(BackendProtocol): seen.add(anon_path) files = self._state_files() - moved_removed, moved_alias = self._moved_view_paths(files) + moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files) if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": try: @@ -203,7 +235,12 @@ class KBPostgresBackend(BackendProtocol): for info in db_infos: p = info.get("path", "") - if not p or p in seen or p in moved_removed: + if ( + not p + or p in seen + or p in moved_removed + or self._is_dir_suppressed(p, deleted_dirs) + ): continue infos.append(info) seen.add(p) @@ -212,6 +249,8 @@ class KBPostgresBackend(BackendProtocol): if src not in seen: if not _is_under(dst, normalized): continue + if self._is_dir_suppressed(dst, deleted_dirs): + continue rel = ( dst[len(normalized) :].lstrip("/") if normalized != "/" @@ -247,6 +286,8 @@ class KBPostgresBackend(BackendProtocol): continue if not _is_under(staged, normalized): continue + 
if self._is_dir_suppressed(staged, deleted_dirs): + continue rel = ( staged[len(normalized) :].lstrip("/") if normalized != "/" @@ -265,14 +306,26 @@ class KBPostgresBackend(BackendProtocol): for sub in sorted(subdir_paths): if sub in seen: continue + if self._is_dir_suppressed(sub, deleted_dirs): + continue infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at="")) seen.add(sub) for path_key, fd in files.items(): if not isinstance(path_key, str) or path_key in seen: continue + # Tombstones (None values) are deletion markers from `rm`. The + # deepagents reducer normally pops them, but a stale tombstone + # surviving a checkpoint must NOT be reported as a child here — + # otherwise rmdir mistakenly sees the deleted file as content. + if fd is None: + continue if not _is_under(path_key, normalized) or path_key == normalized: continue + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): + continue if normalized == "/": rel = path_key.lstrip("/") else: @@ -550,10 +603,12 @@ class KBPostgresBackend(BackendProtocol): seen: set[str] = set() files = self._state_files() - moved_removed, _ = self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) regex = re.compile(fnmatch.translate(pattern)) for path_key, fd in files.items(): - if path_key in moved_removed: + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): continue if not _is_under(path_key, normalized): continue @@ -595,7 +650,11 @@ class KBPostgresBackend(BackendProtocol): folder_id=row.folder_id, index=index, ) - if candidate in seen or candidate in moved_removed: + if ( + candidate in seen + or candidate in moved_removed + or self._is_dir_suppressed(candidate, deleted_dirs) + ): continue if not _is_under(candidate, normalized): continue @@ -634,10 +693,12 @@ class KBPostgresBackend(BackendProtocol): matches: list[GrepMatch] = [] files = self._state_files() - moved_removed, _ = 
self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) glob_re = re.compile(fnmatch.translate(glob)) if glob else None for path_key, fd in files.items(): - if path_key in moved_removed: + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): continue if not _is_under(path_key, normalized): continue @@ -695,7 +756,11 @@ class KBPostgresBackend(BackendProtocol): ) for doc_id, chunk_id, content in chunk_buffer: candidate = doc_id_to_path.get(doc_id) - if not candidate or candidate in moved_removed: + if ( + not candidate + or candidate in moved_removed + or self._is_dir_suppressed(candidate, deleted_dirs) + ): continue if not _is_under(candidate, normalized): continue @@ -769,7 +834,7 @@ class KBPostgresBackend(BackendProtocol): return {"entries": [], "truncated": False} files = self._state_files() - moved_removed, _ = self._moved_view_paths(files) + moved_removed, _, deleted_dirs = self._pending_filesystem_view(files) anon = self._kb_anon_doc() anon_path = str(anon.get("path") or "") if anon else "" @@ -795,6 +860,8 @@ class KBPostgresBackend(BackendProtocol): for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]): if not _is_under(fpath, normalized): continue + if self._is_dir_suppressed(fpath, deleted_dirs): + continue depth = _depth_of(fpath) if max_depth is not None and depth > max_depth: continue @@ -811,6 +878,8 @@ class KBPostgresBackend(BackendProtocol): for staged in self._staged_dirs(): if not _is_under(staged, normalized): continue + if self._is_dir_suppressed(staged, deleted_dirs): + continue depth = _depth_of(staged) if max_depth is not None and depth > max_depth: continue @@ -835,7 +904,9 @@ class KBPostgresBackend(BackendProtocol): folder_id=row.folder_id, index=index, ) - if candidate in moved_removed: + if candidate in moved_removed or self._is_dir_suppressed( + candidate, deleted_dirs + ): continue if not _is_under(candidate, normalized): continue @@ 
-875,6 +946,10 @@ class KBPostgresBackend(BackendProtocol): continue if not _is_under(path_key, normalized): continue + if path_key in moved_removed or self._is_dir_suppressed( + path_key, deleted_dirs + ): + continue if any(e["path"] == path_key for e in entries): continue if not ( diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py index 467d19747..e67be8221 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py @@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] ) all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT])) + # Pre-compute which folders have at least one descendant (folder or doc). + # A folder is "empty" iff no path in `all_paths` is strictly under it. + # Used to emit an explicit "(empty)" marker so the LLM doesn't have to + # infer emptiness from indentation alone. + non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths) + lines: list[str] = [] for path in all_paths: depth = ( @@ -214,7 +220,10 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents" ) if is_dir: - lines.append(f"{indent}{display}/") + if path != DOCUMENTS_ROOT and path not in non_empty_folders: + lines.append(f"{indent}{display}/ (empty)") + else: + lines.append(f"{indent}{display}/") else: lines.append(f"{indent}{display}") if len(lines) >= self.max_entries: @@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] return self._format_root_summary(folder_paths, doc_paths) + @staticmethod + def _compute_non_empty_folders( + folder_paths: list[str], doc_paths: list[str] + ) -> set[str]: + """Return the set of folder paths that contain at least one descendant. 
+ + A folder is "non-empty" if any document path or any other folder path + is strictly under it. Documents propagate emptiness up to every + ancestor folder, while a sub-folder only marks its direct ancestors + non-empty (so a chain of empty folders all read ``(empty)``). + """ + non_empty: set[str] = set() + folder_set = set(folder_paths) + + for doc_path in doc_paths: + parent = doc_path.rsplit("/", 1)[0] + while parent and parent != DOCUMENTS_ROOT: + if parent in folder_set: + non_empty.add(parent) + parent = parent.rsplit("/", 1)[0] + + for child in folder_paths: + parent = child.rsplit("/", 1)[0] + while parent and parent != DOCUMENTS_ROOT and parent in folder_set: + non_empty.add(parent) + parent = parent.rsplit("/", 1)[0] + + return non_empty + def _format_root_summary( self, folder_paths: list[str], doc_paths: list[str] ) -> str: diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py index 565fcb48b..4db9943cb 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -360,6 +360,74 @@ class LocalFolderBackend: self.move, source_path, destination_path, overwrite ) + def delete_file(self, file_path: str) -> WriteResult: + """Hard-delete a single file under root. + + Refuses directories, root, and missing paths. Roughly mirrors POSIX + ``rm path``; ``-r`` recursion and glob expansion are explicitly + out of scope. + """ + try: + path = self._resolve_virtual(file_path) + except ValueError: + return WriteResult(error=f"Error: Invalid path '{file_path}'") + with self._lock_for(file_path): + if not path.exists(): + return WriteResult(error=f"Error: File '{file_path}' not found") + if path.is_dir(): + return WriteResult( + error=( + f"Error: '{file_path}' is a directory. " + "Use rmdir for empty directories." 
def rmdir(self, dir_path: str) -> WriteResult:
    """Hard-delete an empty directory under root.

    Refuses files, root, missing paths, and non-empty directories.
    ``os.rmdir`` is naturally empty-only; the emptiness pre-check exists
    so the agent gets a clearer error message.

    Args:
        dir_path: Virtual path of the directory to remove.

    Returns:
        ``WriteResult(path=dir_path, files_update=None)`` on success,
        otherwise a ``WriteResult`` whose ``error`` describes the refusal
        or the OS-level failure. Filesystem errors are reported, not raised.
    """
    try:
        path = self._resolve_virtual(dir_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{dir_path}'")

    # NOTE(review): assumes ``_resolve_virtual("/")`` yields the backend
    # root directory — confirm. If it raises instead, the guard below is
    # skipped and behavior matches the previous implementation.
    try:
        root_path = self._resolve_virtual("/")
    except ValueError:
        root_path = None

    with self._lock_for(dir_path):
        if not path.exists():
            return WriteResult(error=f"Error: Directory '{dir_path}' not found")
        if not path.is_dir():
            return WriteResult(error=f"Error: '{dir_path}' is not a directory")
        # An empty root would otherwise pass every check and be removed.
        if root_path is not None and path == root_path:
            return WriteResult(error=f"Error: cannot rmdir root '{dir_path}'")
        # Listing can fail (e.g. permissions); report it like any other
        # rmdir failure rather than letting the exception escape.
        try:
            has_entries = any(path.iterdir())
        except OSError as exc:
            return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
        if has_entries:
            return WriteResult(
                error=(
                    f"Error: directory '{dir_path}' is not empty. "
                    "Remove its contents first."
                )
            )
        try:
            os.rmdir(path)
        except OSError as exc:
            return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
    return WriteResult(path=dir_path, files_update=None)
# Tools whose safety profile matches ``write_file`` against the SurfSense
# KB: a call produces a single artifact inside the user's own workspace
# and nothing becomes externally visible (drafts are not sent, new files
# are not shared unless the user shares them later). They are approved
# automatically by default so the agent can compose drafts and seed
# scratch files without a popup per call.
#
# Call sites do not change: members still go through ``request_approval``,
# which returns at once with ``decision_type="auto_approved"`` and the
# params untouched. Logging, metadata fetching and account fallbacks at
# the call site therefore keep working; the only difference is that no
# interrupt fires.
#
# Prompting can be re-enabled through the planned per-search-space rules
# table (``agent_permission_rules``), which takes precedence — see the
# ``# (future)`` layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`.
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
    (
        # Gmail drafts
        "create_gmail_draft",
        "update_gmail_draft",
        # New pages
        "create_notion_page",
        "create_confluence_page",
        # New files
        "create_google_drive_file",
        "create_dropbox_file",
        "create_onedrive_file",
    )
)
+ turn_id = Column(String(64), nullable=True, index=True) + # Relationships thread = relationship("NewChatThread", back_populates="messages") author = relationship("User") @@ -2299,7 +2305,13 @@ class AgentActionLog(BaseModel): nullable=False, index=True, ) + # ``turn_id`` historically held the LangChain ``tool_call.id``. It has + # been renamed to ``tool_call_id`` (with a parallel column kept for one + # release for back-compat). The real chat-turn id lives in + # ``chat_turn_id`` and is sourced from ``configurable.turn_id``. turn_id = Column(String(64), nullable=True, index=True) + tool_call_id = Column(String(64), nullable=True, index=True) + chat_turn_id = Column(String(64), nullable=True, index=True) message_id = Column(String(128), nullable=True, index=True) tool_name = Column(String(255), nullable=False, index=True) args = Column(JSONB, nullable=True) @@ -2325,6 +2337,16 @@ class AgentActionLog(BaseModel): __table_args__ = ( Index("ix_agent_action_log_thread_created", "thread_id", "created_at"), + # Partial unique index enforces "at most one revert per + # original action". Created in migration 137 with + # ``WHERE reverse_of IS NOT NULL`` so non-revert rows + # (the vast majority) are unaffected and NULLs don't collide. + Index( + "ux_agent_action_log_reverse_of", + "reverse_of", + unique=True, + postgresql_where=text("reverse_of IS NOT NULL"), + ), ) @@ -2339,10 +2361,13 @@ class DocumentRevision(BaseModel): __tablename__ = "document_revisions" + # ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the + # hard-delete it describes — without that, ``rm`` would wipe the row + # we'd need to undo it. See migration ``134_relax_revision_fks``. 
document_id = Column( Integer, - ForeignKey("documents.id", ondelete="CASCADE"), - nullable=False, + ForeignKey("documents.id", ondelete="SET NULL"), + nullable=True, index=True, ) search_space_id = Column( @@ -2377,10 +2402,13 @@ class FolderRevision(BaseModel): __tablename__ = "folder_revisions" + # ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the + # hard-delete it describes — without that, ``rmdir`` would wipe the + # row we'd need to undo it. See migration ``134_relax_revision_fks``. folder_id = Column( Integer, - ForeignKey("folders.id", ondelete="CASCADE"), - nullable=False, + ForeignKey("folders.id", ondelete="SET NULL"), + nullable=True, index=True, ) search_space_id = Column( diff --git a/surfsense_backend/app/routes/agent_action_log_route.py b/surfsense_backend/app/routes/agent_action_log_route.py index 458635761..2608aa3b1 100644 --- a/surfsense_backend/app/routes/agent_action_log_route.py +++ b/surfsense_backend/app/routes/agent_action_log_route.py @@ -65,6 +65,13 @@ class AgentActionRead(BaseModel): reverse_of: int | None reverted_by_action_id: int | None is_revert_action: bool + # Correlation ids added in migration 135. ``tool_call_id`` is the + # LangChain tool-call id (joinable to ``data-action-log`` SSE events + # via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id + # from ``configurable.turn_id`` (used by the + # ``revert-turn/{chat_turn_id}`` endpoint). 
+ tool_call_id: str | None = None + chat_turn_id: str | None = None created_at: datetime @@ -172,6 +179,8 @@ async def list_thread_actions( reverse_of=row.reverse_of, reverted_by_action_id=revert_map.get(row.id), is_revert_action=row.reverse_of is not None, + tool_call_id=row.tool_call_id, + chat_turn_id=row.chat_turn_id, created_at=row.created_at, ) for row in rows diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py index 12484ff53..711081b15 100644 --- a/surfsense_backend/app/routes/agent_revert_route.py +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs: 4. Revert dispatch via :func:`app.services.revert_service.revert_action`. 5. Idempotent on retries: if the same action is reverted twice the second call returns 409 ``"already reverted"``. + +This module also hosts the per-turn batch endpoint +``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It +walks every reversible action emitted during a chat turn in reverse +``created_at`` order and reverts each independently. Partial success is the +common case — the response always contains a per-action result list and a +``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a +whole-batch 4xx. """ from __future__ import annotations import logging +from typing import Literal from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.feature_flags import get_flags @@ -97,6 +108,16 @@ async def revert_agent_action( action=action, requester_user_id=str(user.id) if user is not None else None, ) + except IntegrityError: + # Partial unique index ``ux_agent_action_log_reverse_of`` caught + # a concurrent revert. 
Translate to the existing 409 "already + # reverted" contract so racing clients see consistent + # behaviour with the pre-flight TOCTOU check above. + await session.rollback() + raise HTTPException( + status_code=409, + detail="This action has already been reverted.", + ) from None except Exception as err: logger.exception("Revert dispatch raised for action_id=%s", action_id) await session.rollback() @@ -105,7 +126,16 @@ async def revert_agent_action( ) from err if outcome.status == "ok": - await session.commit() + try: + await session.commit() + except IntegrityError: + # Race lost on commit (constraint enforced at flush in some + # configs but at commit in others — defensive). + await session.rollback() + raise HTTPException( + status_code=409, + detail="This action has already been reverted.", + ) from None return { "status": "ok", "message": outcome.message, @@ -122,3 +152,357 @@ async def revert_agent_action( raise HTTPException(status_code=501, detail=outcome.message) # not_reversible raise HTTPException(status_code=409, detail=outcome.message) + + +# --------------------------------------------------------------------------- +# Per-turn revert batch endpoint +# --------------------------------------------------------------------------- + + +PerActionStatus = Literal[ + "reverted", + "already_reverted", + "not_reversible", + "permission_denied", + "failed", + "skipped", +] + + +class RevertTurnActionResult(BaseModel): + """Per-action outcome inside a ``revert-turn`` batch response.""" + + action_id: int + tool_name: str + status: PerActionStatus + message: str | None = None + new_action_id: int | None = None + error: str | None = None + + +class RevertTurnResponse(BaseModel): + """Top-level response for ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + + ``status`` is ``"ok"`` only when every reversible row succeeded. Any + ``failed`` / ``not_reversible`` / ``permission_denied`` entry downgrades + it to ``"partial"``. 
Empty turns (no rows) return ``"ok"`` with an empty + ``results`` list — callers should treat that as a no-op. + + Counter invariant: + ``total == reverted + already_reverted + not_reversible + + permission_denied + failed + skipped`` + + Frontend toasts and the ``RevertTurnButton`` summary rely on this + invariant to display "X of Y reverted, Z could not be undone" without + silently dropping ``permission_denied`` or ``skipped`` rows. + """ + + status: Literal["ok", "partial"] + chat_turn_id: str + total: int + reverted: int + already_reverted: int + not_reversible: int + permission_denied: int = 0 + failed: int = 0 + skipped: int = 0 + results: list[RevertTurnActionResult] + + +def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus: + if outcome.status == "ok": + return "reverted" + if outcome.status == "permission_denied": + return "permission_denied" + # ``not_found`` / ``tool_unavailable`` / ``reverse_not_implemented`` / + # ``not_reversible`` are all surfaced to the caller as "not_reversible" + # — they share the same UX (this row cannot be undone) and only the + # ``message`` differs. + return "not_reversible" + + +async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None: + """Return the id of an existing successful revert row, if any. + + Single-action variant — kept for the post-IntegrityError lookup + path where we already know we lost a race for one specific id. + """ + stmt = select(AgentActionLog.id).where(AgentActionLog.reverse_of == action_id) + result = await session.execute(stmt) + return result.scalars().first() + + +async def _was_already_reverted_batch( + session: AsyncSession, *, action_ids: list[int] +) -> dict[int, int]: + """Batch idempotency probe for the revert-turn loop. + + Replaces N individual ``SELECT id WHERE reverse_of = :id`` queries + (one per row in the turn) with a single ``SELECT id, reverse_of + WHERE reverse_of IN (:ids)``. 
The route still iterates rows in + reverse-chronological order, but the membership check is O(1) per + iteration after this query. For a turn with 30 actions that's 30 + fewer round-trips through asyncpg + a smaller transaction footprint. + + Returns a ``{original_action_id -> revert_action_id}`` map. Missing + keys mean "not yet reverted" — callers should treat them as + eligible for revert. + """ + if not action_ids: + return {} + stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where( + AgentActionLog.reverse_of.in_(action_ids) + ) + result = await session.execute(stmt) + return { + original_id: revert_id + for revert_id, original_id in result.all() + if original_id is not None + } + + +@router.post( + "/threads/{thread_id}/revert-turn/{chat_turn_id}", + response_model=RevertTurnResponse, +) +async def revert_agent_turn( + thread_id: int, + chat_turn_id: str, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> RevertTurnResponse: + """Revert every reversible action emitted during ``chat_turn_id``. + + Walks ``AgentActionLog`` rows for the turn in reverse ``created_at`` + order so dependencies (e.g. ``mkdir`` -> ``write_file`` inside the new + folder) unwind in the right sequence. Each action is reverted in its + own SAVEPOINT so a single failure does not poison the batch. + + Partial success is intentional and returned with HTTP 200. Callers + must inspect ``results[*].status`` to find rows that need attention. + """ + + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_revert_route: + raise HTTPException( + status_code=503, + detail=( + "Revert is not available on this deployment yet. The route " + "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to " + "enable it." 
+ ), + ) + + thread = await load_thread(session, thread_id=thread_id) + if thread is None: + raise HTTPException(status_code=404, detail="Thread not found.") + + # Reverse-chronological so the latest mutation in the turn unwinds + # first. ``id.desc()`` is the deterministic tiebreaker for actions + # written in the same millisecond. + rows_stmt = ( + select(AgentActionLog) + .where( + AgentActionLog.thread_id == thread_id, + AgentActionLog.chat_turn_id == chat_turn_id, + ) + .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc()) + ) + rows = (await session.execute(rows_stmt)).scalars().all() + + requester_user_id = str(user.id) if user is not None else None + results: list[RevertTurnActionResult] = [] + # Counters MUST be exhaustive so the response invariant + # ``total == sum(counters)`` always holds. Frontend toasts and + # ``RevertTurnButton`` rely on this for "X of Y reverted" math. + counts: dict[str, int] = { + "reverted": 0, + "already_reverted": 0, + "not_reversible": 0, + "permission_denied": 0, + "failed": 0, + "skipped": 0, + } + + # Single batched idempotency probe replaces the previous per-row + # SELECT. ``rows`` are filtered in the loop so we pre-collect only + # the original-action ids (skip rows that are themselves + # reverts). + eligible_ids = [r.id for r in rows if r.reverse_of is None] + already_reverted_map = await _was_already_reverted_batch( + session, action_ids=eligible_ids + ) + + for action in rows: + # Skip rows that ARE reverts of an earlier action — reverting a + # revert is meaningless inside a batch (the user wants to wipe + # the original effects, not chase tail). + if action.reverse_of is not None: + counts["skipped"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="skipped", + message="Row is itself a revert action; skipped.", + ) + ) + continue + + # Idempotency: surface "already_reverted" instead of failing. 
+ existing_revert_id = already_reverted_map.get(action.id) + if existing_revert_id is not None: + counts["already_reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ) + ) + continue + + if not can_revert( + requester_user_id=requester_user_id, + action=action, + is_admin=False, + ): + counts["permission_denied"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="permission_denied", + message="You are not allowed to revert this action.", + ) + ) + continue + + # Per-row SAVEPOINT so one failed revert never poisons later + # successful ones. + try: + async with session.begin_nested(): + outcome = await revert_action( + session, + action=action, + requester_user_id=requester_user_id, + ) + if outcome.status != "ok": + raise _OutcomeRollbackError(outcome) + except _OutcomeRollbackError as rollback: + outcome = rollback.outcome + classified = _classify_outcome(outcome) + if classified == "permission_denied": + counts["permission_denied"] += 1 + else: + counts["not_reversible"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status=classified, + message=outcome.message, + ) + ) + continue + except IntegrityError: + # Partial unique index caught a concurrent revert that won + # the race against our pre-flight ``_was_already_reverted`` + # SELECT. Look up the winner so + # we can surface its ``new_action_id`` to the client. 
+ existing_revert_id = await _was_already_reverted( + session, action_id=action.id + ) + counts["already_reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ) + ) + continue + except Exception as err: # pragma: no cover — defensive, logged + logger.exception( + "Unexpected revert failure inside batch for action_id=%s", + action.id, + ) + counts["failed"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="failed", + error=str(err) or err.__class__.__name__, + ) + ) + continue + + counts["reverted"] += 1 + results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="reverted", + message=outcome.message, + new_action_id=outcome.new_action_id, + ) + ) + + # Single commit at the end — successful SAVEPOINTs above already + # released; failed ones rolled back to their savepoint. No row leaks + # across the boundary. 
+ try: + await session.commit() + except Exception as err: # pragma: no cover — defensive + logger.exception( + "Final commit for revert-turn failed (thread=%s turn=%s)", + thread_id, + chat_turn_id, + ) + await session.rollback() + raise HTTPException( + status_code=500, + detail="Internal error while finalising revert-turn batch.", + ) from err + + has_partial = ( + counts["failed"] > 0 + or counts["not_reversible"] > 0 + or counts["permission_denied"] > 0 + ) + overall_status: Literal["ok", "partial"] = "partial" if has_partial else "ok" + + return RevertTurnResponse( + status=overall_status, + chat_turn_id=chat_turn_id, + total=len(rows), + reverted=counts["reverted"], + already_reverted=counts["already_reverted"], + not_reversible=counts["not_reversible"], + permission_denied=counts["permission_denied"], + failed=counts["failed"], + skipped=counts["skipped"], + results=results, + ) + + +class _OutcomeRollbackError(Exception): + """Sentinel raised inside the SAVEPOINT to roll back a non-OK outcome. + + ``revert_action`` writes a new ``agent_action_log`` row only on the + happy path, but on the failure paths it sometimes mutates the + ``DocumentRevision``/``Document`` tables before deciding the action + is not reversible. Wrapping each call in ``begin_nested`` and raising + this from the failure branch ensures we always discard partial + writes for failed rows. 
+ """ + + def __init__(self, outcome: RevertOutcome) -> None: + self.outcome = outcome + super().__init__(outcome.message) + + +__all__ = ["router"] diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 0189dd139..e04cce1b5 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -11,6 +11,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui: """ import asyncio +import json import logging from datetime import UTC, datetime @@ -136,6 +137,260 @@ def _resolve_filesystem_selection( ) +def _find_pre_turn_checkpoint_id( + checkpoint_tuples: list, + *, + turn_id: str, +) -> str | None: + """Locate the LangGraph checkpoint immediately before ``turn_id`` started. + + ``checkpoint_tuples`` arrives newest-first from + ``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``) + and remember the most recent checkpoint that does NOT belong to the + edited turn. As soon as we cross into the edited turn (a checkpoint + whose ``turn_id`` matches), we return the previously-tracked + checkpoint — that's the state immediately before ``turn_id`` began. + + The naive "newest-first, return first non-matching" approach is + INCORRECT when later turns exist after ``turn_id``: their + checkpoints also satisfy ``cp_turn_id != turn_id`` and would be + returned before the real pre-turn boundary is reached. + + Reads from ``cp_tuple.metadata`` (the durable surface promoted from + ``configurable`` at write time) rather than ``config["configurable"]`` + so the lookup is portable across checkpointer implementations. + + Returns ``None`` when no eligible pre-turn checkpoint exists (e.g. + the edited turn is the very first turn of the thread). Callers fall + back to the oldest available checkpoint in that case. 
+ """ + + last_pre_turn_target: str | None = None + for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest + metadata = getattr(cp_tuple, "metadata", None) or {} + cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None + if cp_turn_id == turn_id: + # Crossed into the edited turn; the previous tracked + # checkpoint is the rewind target. May be ``None`` if we hit + # the edited turn on the very first iteration. + return last_pre_turn_target + try: + last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"] + except (KeyError, TypeError): + continue + return last_pre_turn_target + + +async def _revert_turns_for_regenerate( + *, + thread_id: int, + chat_turn_ids: list[str], + requester_user_id: str, +) -> dict: + """Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``. + + Runs BEFORE the regenerate stream so the frontend can surface + partial-rollback feedback alongside the new assistant turn. Each + turn's actions are reverted in their own SAVEPOINTs (handled + inside :mod:`app.routes.agent_revert_route`'s helpers) so a single + failure never poisons the batch. + + Sequencing inside the request: revert THEN regenerate. The + operation is NOT atomic and partial state IS surfaced — see the + plan's "Sequencing inside the request" note. + """ + + from app.routes.agent_revert_route import ( + RevertTurnActionResult, + _classify_outcome, + _OutcomeRollbackError, + _was_already_reverted, + _was_already_reverted_batch, + ) + from app.services.revert_service import ( + can_revert, + revert_action, + ) + + aggregated_results: list[dict] = [] + # Exhaustive counters keep the response invariant + # ``total == sum(counters)`` true for ``data-revert-results``. 
+ counts = { + "reverted": 0, + "already_reverted": 0, + "not_reversible": 0, + "permission_denied": 0, + "failed": 0, + "skipped": 0, + } + + # Local import keeps the route module's existing imports tidy and + # avoids a circular dependency at module-load time. + from app.db import AgentActionLog as _AgentActionLog + + async with shielded_async_session() as session: + for chat_turn_id in chat_turn_ids: + rows_stmt = ( + select(_AgentActionLog) + .where( + _AgentActionLog.thread_id == thread_id, + _AgentActionLog.chat_turn_id == chat_turn_id, + ) + .order_by( + _AgentActionLog.created_at.desc(), + _AgentActionLog.id.desc(), + ) + ) + rows = (await session.execute(rows_stmt)).scalars().all() + + # Batch idempotency probe across the turn (single SELECT + # instead of one per row). + eligible_ids = [r.id for r in rows if r.reverse_of is None] + already_reverted_map = await _was_already_reverted_batch( + session, action_ids=eligible_ids + ) + + for action in rows: + if action.reverse_of is not None: + counts["skipped"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="skipped", + message="Row is itself a revert action; skipped.", + ).model_dump() + ) + continue + + existing_revert_id = already_reverted_map.get(action.id) + if existing_revert_id is not None: + counts["already_reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ).model_dump() + ) + continue + + if not can_revert( + requester_user_id=requester_user_id, + action=action, + is_admin=False, + ): + counts["permission_denied"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="permission_denied", + message="You are not allowed to revert this action.", + ).model_dump() + ) + continue + + try: + async with session.begin_nested(): + 
outcome = await revert_action( + session, + action=action, + requester_user_id=requester_user_id, + ) + if outcome.status != "ok": + raise _OutcomeRollbackError(outcome) + except _OutcomeRollbackError as rollback: + outcome = rollback.outcome + classified = _classify_outcome(outcome) + if classified == "permission_denied": + counts["permission_denied"] += 1 + else: + counts["not_reversible"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status=classified, + message=outcome.message, + ).model_dump() + ) + continue + except IntegrityError: + # Concurrent revert won the race against the + # pre-flight ``_was_already_reverted`` SELECT. + # Surface the winning revert id so the client can + # treat this as a successful idempotent op. + existing_revert_id = await _was_already_reverted( + session, action_id=action.id + ) + counts["already_reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="already_reverted", + new_action_id=existing_revert_id, + ).model_dump() + ) + continue + except Exception as err: # pragma: no cover — defensive + _logger.exception( + "Unexpected revert failure during regenerate batch " + "for action_id=%s", + action.id, + ) + counts["failed"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="failed", + error=str(err) or err.__class__.__name__, + ).model_dump() + ) + continue + + counts["reverted"] += 1 + aggregated_results.append( + RevertTurnActionResult( + action_id=action.id, + tool_name=action.tool_name, + status="reverted", + message=outcome.message, + new_action_id=outcome.new_action_id, + ).model_dump() + ) + + try: + await session.commit() + except Exception: + _logger.exception( + "[regenerate-revert] Final commit failed; rolling back batch." 
+ ) + await session.rollback() + + has_partial = ( + counts["failed"] > 0 + or counts["not_reversible"] > 0 + or counts["permission_denied"] > 0 + ) + + return { + "status": "partial" if has_partial else "ok", + "chat_turn_ids": chat_turn_ids, + "total": len(aggregated_results), + "reverted": counts["reverted"], + "already_reverted": counts["already_reverted"], + "not_reversible": counts["not_reversible"], + "permission_denied": counts["permission_denied"], + "failed": counts["failed"], + "skipped": counts["skipped"], + "results": aggregated_results, + } + + def _try_delete_sandbox(thread_id: int) -> None: """Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked.""" from app.agents.new_chat.sandbox import ( @@ -574,6 +829,7 @@ async def get_thread_messages( token_usage=TokenUsageSummary.model_validate(msg.token_usage) if msg.token_usage else None, + turn_id=msg.turn_id, ) for msg in db_messages ] @@ -1006,12 +1262,24 @@ async def append_message( # Check thread-level access based on visibility await check_thread_access(session, thread, user) - # Create message + # Create message. ``turn_id`` is the per-turn correlation id from + # ``configurable.turn_id`` (added in migration 136) — when the + # client streams it back to ``appendMessage``, we persist it so + # C1's edit-from-arbitrary-position can later map this message + # back to the LangGraph checkpoint that produced its turn. 
+ raw_turn_id = raw_body.get("turn_id") + turn_id_value = ( + str(raw_turn_id).strip() + if isinstance(raw_turn_id, str) and raw_turn_id.strip() + else None + ) + db_message = NewChatMessage( thread_id=thread_id, role=message_role, content=content, author_id=user.id, + turn_id=turn_id_value, ) session.add(db_message) @@ -1050,6 +1318,7 @@ async def append_message( created_at=db_message.created_at, author_id=db_message.author_id, token_usage=None, + turn_id=db_message.turn_id, ) except HTTPException: @@ -1373,43 +1642,123 @@ async def regenerate_response( user_query_to_use = request.user_query regenerate_image_urls: list[str] = [] - # Look through checkpoints to find the right one - # We want to find the checkpoint just before the last HumanMessage - for i, cp_tuple in enumerate(checkpoint_tuples): - # Access the checkpoint's channel_values which contains "messages" - checkpoint_data = cp_tuple.checkpoint - channel_values = checkpoint_data.get("channel_values", {}) - state_messages = channel_values.get("messages", []) + # --------------------------------------------------------------- + # Edit-from-arbitrary-position. When the client passes + # ``from_message_id`` we look up its persisted ``turn_id`` (added + # in migration 136) and pick the checkpoint immediately before + # that turn started. + # + # Legacy graceful-degradation contract: + # * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``. + # Returning 400 in that case is the wrong UX — the user is + # editing an old message in an existing thread and just wants + # it to work. We instead skip the checkpoint rewind (the + # stream falls back to the latest state) and skip the revert + # pass (no chat_turn_id available to walk). Deletion still + # uses ``created_at``, so the messages-after-cursor slice is + # correct on both legacy and post-136 rows. 
+ # --------------------------------------------------------------- + from_message_turn_id: str | None = None + from_message_created_at: datetime | None = None + legacy_from_message: bool = False + if request.from_message_id is not None: + from_msg_row = await session.execute( + select(NewChatMessage).filter( + NewChatMessage.id == request.from_message_id, + NewChatMessage.thread_id == thread_id, + ) + ) + from_msg = from_msg_row.scalars().first() + if from_msg is None: + raise HTTPException( + status_code=404, + detail="from_message_id not found in this thread.", + ) + from_message_created_at = from_msg.created_at + if not from_msg.turn_id: + # Legacy row — surface the degradation in logs but let + # the request proceed with the slice-based delete and a + # cold-start checkpoint. + legacy_from_message = True + _logger.warning( + "[regenerate] from_message_id=%s on thread=%s has no " + "turn_id (legacy row pre-migration-136). Falling back " + "to slice-based delete without checkpoint rewind. 
" + "revert_actions=%s will be ignored.", + request.from_message_id, + thread_id, + request.revert_actions, + ) + else: + from_message_turn_id = from_msg.turn_id - if state_messages: - last_msg = state_messages[-1] - # Find a checkpoint where the last message is NOT a HumanMessage - # This means we're at a state before the user's last message - if not isinstance(last_msg, HumanMessage): - # If no new user_query provided (reload), extract from a later checkpoint - if user_query_to_use is None and i > 0: - # Get the user query from a more recent checkpoint - for prev_cp_tuple in checkpoint_tuples[:i]: - prev_checkpoint_data = prev_cp_tuple.checkpoint - prev_channel_values = prev_checkpoint_data.get( - "channel_values", {} - ) - prev_messages = prev_channel_values.get("messages", []) - for msg in reversed(prev_messages): - if isinstance(msg, HumanMessage): - q, imgs = split_langchain_human_content(msg.content) - user_query_to_use = q - regenerate_image_urls = imgs - break - if user_query_to_use is not None and ( - str(user_query_to_use).strip() or regenerate_image_urls - ): - break - - target_checkpoint_id = cp_tuple.config["configurable"][ + # Walk oldest-to-newest and pick the LAST checkpoint whose + # ``turn_id`` differs from the edited turn — that's the state + # immediately before this turn started running. We read from + # ``metadata`` (the durable surface) rather than + # ``config["configurable"]`` so the lookup works across + # checkpointer implementations. + target_checkpoint_id = _find_pre_turn_checkpoint_id( + checkpoint_tuples, + turn_id=from_message_turn_id, + ) + if target_checkpoint_id is None and len(checkpoint_tuples) > 0: + # Fall back to the oldest checkpoint — better than + # 400ing when the agent didn't checkpoint pre-turn + # (e.g. very first turn of the thread). 
+ target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][ "checkpoint_id" ] - break + + # Look through checkpoints to find the right one + # We want to find the checkpoint just before the last HumanMessage. + # We enter this branch when: + # * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR + # * the client pinned ``from_message_id`` but the row is a + # legacy pre-migration-136 row with no ``turn_id`` (we + # downgraded to the same heuristic as a regular reload). + # We DO skip it when a real turn_id pinned ``target_checkpoint_id`` + # — that's the C1 happy path and the heuristic below would just + # re-derive a worse target. + if request.from_message_id is None or legacy_from_message: + for i, cp_tuple in enumerate(checkpoint_tuples): + # Access the checkpoint's channel_values which contains "messages" + checkpoint_data = cp_tuple.checkpoint + channel_values = checkpoint_data.get("channel_values", {}) + state_messages = channel_values.get("messages", []) + + if state_messages: + last_msg = state_messages[-1] + # Find a checkpoint where the last message is NOT a HumanMessage + # This means we're at a state before the user's last message + if not isinstance(last_msg, HumanMessage): + # If no new user_query provided (reload), extract from a later checkpoint + if user_query_to_use is None and i > 0: + # Get the user query from a more recent checkpoint + for prev_cp_tuple in checkpoint_tuples[:i]: + prev_checkpoint_data = prev_cp_tuple.checkpoint + prev_channel_values = prev_checkpoint_data.get( + "channel_values", {} + ) + prev_messages = prev_channel_values.get("messages", []) + for msg in reversed(prev_messages): + if isinstance(msg, HumanMessage): + q, imgs = split_langchain_human_content( + msg.content + ) + user_query_to_use = q + regenerate_image_urls = imgs + break + if user_query_to_use is not None and ( + str(user_query_to_use).strip() + or regenerate_image_urls + ): + break + + target_checkpoint_id = 
cp_tuple.config["configurable"][ + "checkpoint_id" + ] + break # If we couldn't find a good checkpoint, try alternative approaches if target_checkpoint_id is None and checkpoint_tuples: @@ -1472,18 +1821,51 @@ async def regenerate_response( detail="Could not determine user query for regeneration. Please provide a user_query.", ) - # Get the last two messages to delete AFTER streaming succeeds - # This prevents data loss if streaming fails - last_messages_result = await session.execute( - select(NewChatMessage) - .filter(NewChatMessage.thread_id == thread_id) - .order_by(NewChatMessage.created_at.desc()) - .limit(2) - ) + # Get the messages to delete AFTER streaming succeeds. + # This prevents data loss if streaming fails. + # + # When ``from_message_id`` is set we slice from that message + # forward (using ``created_at`` so we also catch any tool/system + # messages persisted into the same turn). Otherwise + # we keep the legacy "last 2 messages" rewind. + if request.from_message_id is not None and from_message_created_at is not None: + last_messages_result = await session.execute( + select(NewChatMessage) + .filter( + NewChatMessage.thread_id == thread_id, + NewChatMessage.created_at >= from_message_created_at, + ) + .order_by(NewChatMessage.created_at.desc()) + ) + else: + last_messages_result = await session.execute( + select(NewChatMessage) + .filter(NewChatMessage.thread_id == thread_id) + .order_by(NewChatMessage.created_at.desc()) + .limit(2) + ) messages_to_delete = list(last_messages_result.scalars().all()) message_ids_to_delete = [msg.id for msg in messages_to_delete] + # When revert_actions is requested, collect the set of + # ``chat_turn_id``s present in the slice we're about to delete. + # Each one will be reverted (best-effort) BEFORE the regenerate + # stream begins. Legacy rows have ``turn_id=None`` and silently + # contribute nothing — we already logged the degradation above. 
+ revert_turn_ids: list[str] = [] + if ( + request.revert_actions + and request.from_message_id is not None + and not legacy_from_message + ): + seen_turns: set[str] = set() + for msg in messages_to_delete: + tid = msg.turn_id + if tid and tid not in seen_turns: + seen_turns.add(tid) + revert_turn_ids.append(tid) + # Get search space for LLM config search_space_result = await session.execute( select(SearchSpace).filter(SearchSpace.id == request.search_space_id) @@ -1507,6 +1889,24 @@ async def regenerate_response( # This prevents data loss if streaming fails (network error, LLM error, etc.) async def stream_with_cleanup(): streaming_completed = False + # Best-effort revert pass BEFORE the regenerate stream begins. + # Each turn is reverted independently (per-row SAVEPOINTs + # inside the route helper) and the per-action results are surfaced + # on a single ``data-revert-results`` SSE event so the frontend + # can render any failed rows alongside the new turn. Failures here + # do NOT abort the regeneration — partial rollback is documented + # behaviour. 
+ if revert_turn_ids: + revert_results = await _revert_turns_for_regenerate( + thread_id=thread_id, + chat_turn_ids=revert_turn_ids, + requester_user_id=str(user.id), + ) + envelope = { + "type": "data-revert-results", + "data": revert_results, + } + yield f"data: {json.dumps(envelope, default=str)}\n\n".encode() try: async for chunk in stream_new_chat( user_query=str(user_query_to_use), diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 477fdf2ca..c7284e901 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel): author_display_name: str | None = None author_avatar_url: str | None = None token_usage: TokenUsageSummary | None = None + # Per-turn correlation id (``f"{chat_id}:{ms}"``) from + # ``configurable.turn_id`` at streaming time. Nullable because + # legacy rows predate the column; clients should treat NULL as + # "edit-from-this-message is unavailable". + turn_id: str | None = None model_config = ConfigDict(from_attributes=True) @@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel): For edit, optional user_images (when not None) replaces image URLs resolved from checkpoint/DB so the client can send the full user turn (text and/or images). + + Edit-from-arbitrary-position. When ``from_message_id`` is provided + the route slices conversation history starting at that message (instead of + the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by + matching ``configurable.turn_id`` stored on the message (added in migration 136), and + optionally reverts every reversible action emitted in turns at or after + ``from_message_id``. The revert step is best-effort and runs BEFORE the + regenerate stream — partial failures are surfaced via SSE + ``data-revert-results`` and do not abort the regeneration. 
""" search_space_id: int @@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel): default=None, description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB", ) + from_message_id: int | None = Field( + default=None, + description=( + "Message id to rewind to. When set, history is sliced " + "from this message forward and the LangGraph checkpoint is " + "rewound to the state immediately preceding this turn. Legacy " + "rows that predate migration 136 have ``turn_id=None`` and " + "still process — the route logs a warning, skips the " + "checkpoint rewind, and ignores ``revert_actions`` (no " + "chat_turn_id available to walk)." + ), + ) + revert_actions: bool = Field( + default=False, + description=( + "When true, every reversible action emitted at or " + "after ``from_message_id`` is reverted before the regenerate " + "stream begins. Per-action results are surfaced via the " + "``data-revert-results`` SSE event. Partial failures DO NOT " + "abort the regeneration." 
+ ), + ) @model_validator(mode="after") def _validate_regenerate_user_images(self) -> Self: @@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel): raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed") return self + @model_validator(mode="after") + def _validate_revert_actions_requires_from_message(self) -> Self: + if self.revert_actions and self.from_message_id is None: + raise ValueError( + "revert_actions requires from_message_id; specify which message to rewind to" + ) + return self + # ============================================================================= # Agent Tools Schemas diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 3e24c1376..5b7ef26d0 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -588,13 +588,24 @@ class VercelStreamingService: # Tool Parts # ========================================================================= - def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str: + def format_tool_input_start( + self, + tool_call_id: str, + tool_name: str, + *, + langchain_tool_call_id: str | None = None, + ) -> str: """ Format the start of tool input streaming. Args: - tool_call_id: The unique tool call identifier - tool_name: The name of the tool being called + tool_call_id: The unique tool call identifier (synthetic, derived + from LangGraph ``run_id`` so the frontend has a stable card id). + tool_name: The name of the tool being called. + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id``. When set, surfaces as + ``langchainToolCallId`` so the frontend can join this card + to the action-log row written by ``ActionLogMiddleware``. 
Returns: str: SSE formatted tool input start part @@ -602,13 +613,14 @@ class VercelStreamingService: Example output: data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"} """ - return self._format_sse( - { - "type": "tool-input-start", - "toolCallId": tool_call_id, - "toolName": tool_name, - } - ) + payload: dict[str, Any] = { + "type": "tool-input-start", + "toolCallId": tool_call_id, + "toolName": tool_name, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str: """ @@ -633,7 +645,12 @@ class VercelStreamingService: ) def format_tool_input_available( - self, tool_call_id: str, tool_name: str, input_data: dict[str, Any] + self, + tool_call_id: str, + tool_name: str, + input_data: dict[str, Any], + *, + langchain_tool_call_id: str | None = None, ) -> str: """ Format the completion of tool input. @@ -642,6 +659,8 @@ class VercelStreamingService: tool_call_id: The tool call identifier tool_name: The name of the tool input_data: The complete tool input parameters + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id`` (see ``format_tool_input_start``). 
Returns: str: SSE formatted tool input available part @@ -649,22 +668,34 @@ class VercelStreamingService: Example output: data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}} """ - return self._format_sse( - { - "type": "tool-input-available", - "toolCallId": tool_call_id, - "toolName": tool_name, - "input": input_data, - } - ) + payload: dict[str, Any] = { + "type": "tool-input-available", + "toolCallId": tool_call_id, + "toolName": tool_name, + "input": input_data, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) - def format_tool_output_available(self, tool_call_id: str, output: Any) -> str: + def format_tool_output_available( + self, + tool_call_id: str, + output: Any, + *, + langchain_tool_call_id: str | None = None, + ) -> str: """ Format tool execution output. Args: tool_call_id: The tool call identifier output: The tool execution result + langchain_tool_call_id: Optional authoritative LangChain + ``tool_call.id`` extracted from ``ToolMessage.tool_call_id``. + When set, the frontend can backfill any card whose + ``langchainToolCallId`` was not yet known at + ``tool-input-start`` time. 
Returns: str: SSE formatted tool output available part @@ -672,13 +703,14 @@ class VercelStreamingService: Example output: data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}} """ - return self._format_sse( - { - "type": "tool-output-available", - "toolCallId": tool_call_id, - "output": output, - } - ) + payload: dict[str, Any] = { + "type": "tool-output-available", + "toolCallId": tool_call_id, + "output": output, + } + if langchain_tool_call_id: + payload["langchainToolCallId"] = langchain_tool_call_id + return self._format_sse(payload) # ========================================================================= # Step Parts diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py index f3630e0b4..d02a31345 100644 --- a/surfsense_backend/app/services/revert_service.py +++ b/surfsense_backend/app/services/revert_service.py @@ -8,7 +8,9 @@ Operation outcomes mirror the plan: * **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from :class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows - written before the original mutation. + written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh + row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row + that was created; everything else is an in-place restore. * **Connector-owned actions with a declared ``reverse_descriptor``**: invoke the inverse tool through the agent's normal permission stack (NOT bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``. @@ -18,6 +20,11 @@ Operation outcomes mirror the plan: A successful revert appends a NEW row to ``agent_action_log`` with ``reverse_of=`` and the requesting user's ``user_id``, preserving an auditable chain. + +Dispatch must be exact-match (``tool_name == name``), NOT prefix matching. 
+``"rmdir".startswith("rm")`` would otherwise mis-route directory revert +to the document branch (and ``delete_note`` vs ``delete_folder`` is the +same trap waiting to happen). """ from __future__ import annotations @@ -25,17 +32,31 @@ from __future__ import annotations import logging from dataclasses import dataclass from datetime import UTC, datetime -from typing import Literal +from typing import Any, Literal -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + safe_filename, + safe_folder_segment, +) from app.db import ( AgentActionLog, + Chunk, + Document, DocumentRevision, + DocumentType, + Folder, FolderRevision, NewChatThread, ) +from app.utils.document_converters import ( + embed_texts, + generate_content_hash, + generate_unique_identifier_hash, +) logger = logging.getLogger(__name__) @@ -110,14 +131,244 @@ def can_revert( # --------------------------------------------------------------------------- -# Revert paths +# Helper: reconstruct virtual path from a snapshot # --------------------------------------------------------------------------- +async def _virtual_path_from_snapshot( + session: AsyncSession, + revision: DocumentRevision, +) -> str | None: + """Reconstruct the virtual_path the document was at before mutation. + + Preference order: + 1. ``metadata_before["virtual_path"]`` — written by every snapshot + helper since this PR. + 2. Compose ``"/"`` from + ``folder_id_before`` + ``title_before``. Walks the folder chain via + ``parent_id``. 
+ """ + metadata = revision.metadata_before or {} + candidate = metadata.get("virtual_path") if isinstance(metadata, dict) else None + if isinstance(candidate, str) and candidate.startswith(DOCUMENTS_ROOT): + return candidate + + title = revision.title_before + if not isinstance(title, str) or not title: + return None + + parts: list[str] = [] + cursor: int | None = revision.folder_id_before + visited: set[int] = set() + while cursor is not None and cursor not in visited: + visited.add(cursor) + folder = await session.get(Folder, cursor) + if folder is None: + return None + parts.append(safe_folder_segment(str(folder.name or ""))) + cursor = folder.parent_id + parts.reverse() + + base = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT + filename = safe_filename(title) + return f"{base}/{filename}" + + +# --------------------------------------------------------------------------- +# Document revision restore (write/edit/move/rm) +# --------------------------------------------------------------------------- + + +def _set_field(target: Any, field: str, value: Any) -> None: + if value is not None: + setattr(target, field, value) + + +async def _restore_in_place_document( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Apply an in-place restore to an existing :class:`Document`.""" + if revision.document_id is None: + return RevertOutcome( + status="tool_unavailable", + message=( + "Original document was hard-deleted; in-place restore is not possible." 
+ ), + ) + doc = await session.get(Document, revision.document_id) + if doc is None: + return RevertOutcome( + status="tool_unavailable", + message="Original document has been deleted; revert cannot proceed.", + ) + + _set_field(doc, "content", revision.content_before) + _set_field(doc, "source_markdown", revision.content_before) + _set_field(doc, "title", revision.title_before) + _set_field(doc, "folder_id", revision.folder_id_before) + metadata_before = revision.metadata_before or {} + if isinstance(metadata_before, dict) and metadata_before: + doc.document_metadata = dict(metadata_before) + + if isinstance(revision.content_before, str): + doc.content_hash = generate_content_hash( + revision.content_before, doc.search_space_id + ) + + virtual_path = await _virtual_path_from_snapshot(session, revision) + if virtual_path: + doc.unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + doc.search_space_id, + ) + + chunks_before = revision.chunks_before + if isinstance(chunks_before, list): + await session.execute(delete(Chunk).where(Chunk.document_id == doc.id)) + chunk_texts = [ + str(c.get("content")) + for c in chunks_before + if isinstance(c, dict) and isinstance(c.get("content"), str) + ] + if chunk_texts: + chunk_embeddings = embed_texts(chunk_texts) + session.add_all( + [ + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip( + chunk_texts, chunk_embeddings, strict=True + ) + ] + ) + if isinstance(revision.content_before, str): + doc.embedding = embed_texts([revision.content_before])[0] + + doc.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Document restored from snapshot.") + + +async def _reinsert_document_from_revision( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Re-INSERT a deleted :class:`Document` from a snapshot row (``rm`` revert).""" + if not isinstance(revision.title_before, str) or not 
revision.title_before: + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks title_before; cannot recreate document.", + ) + if not isinstance(revision.content_before, str): + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks content_before; cannot recreate document.", + ) + + virtual_path = await _virtual_path_from_snapshot(session, revision) + if not virtual_path: + return RevertOutcome( + status="not_reversible", + message=( + "Snapshot is missing both metadata_before['virtual_path'] AND " + "a resolvable (folder_id_before, title_before) pair." + ), + ) + + search_space_id = revision.search_space_id + unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + collision = await session.execute( + select(Document.id).where( + Document.search_space_id == search_space_id, + Document.unique_identifier_hash == unique_identifier_hash, + ) + ) + if collision.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( + f"A document already exists at '{virtual_path}'; revert would " + "collide. Move the live doc out of the way first." 
+ ), + ) + + metadata = revision.metadata_before or {} + if not isinstance(metadata, dict): + metadata = {} + metadata = dict(metadata) + metadata["virtual_path"] = virtual_path + + content = revision.content_before + new_doc = Document( + title=revision.title_before, + document_type=DocumentType.NOTE, + document_metadata=metadata, + content=content, + content_hash=generate_content_hash(content, search_space_id), + unique_identifier_hash=unique_identifier_hash, + source_markdown=content, + search_space_id=search_space_id, + folder_id=revision.folder_id_before, + updated_at=datetime.now(UTC), + ) + session.add(new_doc) + await session.flush() + + new_doc.embedding = embed_texts([content])[0] + chunk_texts = [] + chunks_before = revision.chunks_before + if isinstance(chunks_before, list): + chunk_texts = [ + str(c.get("content")) + for c in chunks_before + if isinstance(c, dict) and isinstance(c.get("content"), str) + ] + if chunk_texts: + chunk_embeddings = embed_texts(chunk_texts) + session.add_all( + [ + Chunk(document_id=new_doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True) + ] + ) + + # Repoint the snapshot at the recreated row so a follow-up revert of + # the same row works as expected. 
+ revision.document_id = new_doc.id + return RevertOutcome( + status="ok", + message=f"Re-inserted document '{revision.title_before}' from snapshot.", + ) + + +async def _delete_created_document( + session: AsyncSession, + *, + revision: DocumentRevision, +) -> RevertOutcome: + """Delete the document that ``write_file`` created (``content_before IS NULL``).""" + if revision.document_id is None: + return RevertOutcome( + status="ok", + message="No live row to delete (already removed elsewhere).", + ) + await session.execute(delete(Document).where(Document.id == revision.document_id)) + return RevertOutcome( + status="ok", + message="Deleted the document that was created by this action.", + ) + + async def _restore_document_revision( session: AsyncSession, *, action: AgentActionLog ) -> RevertOutcome: - """Restore the most recent :class:`DocumentRevision` for ``action``.""" + """Dispatch document-level revert based on ``action.tool_name``.""" stmt = ( select(DocumentRevision) .where(DocumentRevision.agent_action_id == action.id) @@ -132,23 +383,111 @@ async def _restore_document_revision( message="No document_revisions row tied to this action.", ) - from app.db import Document # late import to avoid cycles at module load + tool_name = (action.tool_name or "").lower() - doc = await session.get(Document, revision.document_id) - if doc is None: + if tool_name == "rm": + return await _reinsert_document_from_revision(session, revision=revision) + + if tool_name == "write_file" and revision.content_before is None: + return await _delete_created_document(session, revision=revision) + + return await _restore_in_place_document(session, revision=revision) + + +# --------------------------------------------------------------------------- +# Folder revision restore (mkdir/rmdir/rename/move) +# --------------------------------------------------------------------------- + + +async def _restore_in_place_folder( + session: AsyncSession, + *, + revision: FolderRevision, +) -> 
RevertOutcome: + if revision.folder_id is None: return RevertOutcome( status="tool_unavailable", - message="Original document has been deleted; revert cannot proceed.", + message="Original folder was hard-deleted; in-place restore is impossible.", + ) + folder = await session.get(Folder, revision.folder_id) + if folder is None: + return RevertOutcome( + status="tool_unavailable", + message="Original folder has been deleted; revert cannot proceed.", + ) + _set_field(folder, "name", revision.name_before) + _set_field(folder, "parent_id", revision.parent_id_before) + _set_field(folder, "position", revision.position_before) + folder.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Folder restored from snapshot.") + + +async def _reinsert_folder_from_revision( + session: AsyncSession, + *, + revision: FolderRevision, +) -> RevertOutcome: + if not isinstance(revision.name_before, str) or not revision.name_before: + return RevertOutcome( + status="not_reversible", + message="Snapshot lacks name_before; cannot recreate folder.", + ) + new_folder = Folder( + name=revision.name_before, + parent_id=revision.parent_id_before, + position=revision.position_before, + search_space_id=revision.search_space_id, + updated_at=datetime.now(UTC), + ) + session.add(new_folder) + await session.flush() + revision.folder_id = new_folder.id + return RevertOutcome( + status="ok", + message=f"Re-inserted folder '{revision.name_before}' from snapshot.", + ) + + +async def _delete_created_folder( + session: AsyncSession, + *, + revision: FolderRevision, +) -> RevertOutcome: + if revision.folder_id is None: + return RevertOutcome( + status="ok", + message="No live folder row to delete (already removed elsewhere).", + ) + folder_id = revision.folder_id + + has_doc = await session.execute( + select(Document.id).where(Document.folder_id == folder_id).limit(1) + ) + if has_doc.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( 
+ "Folder is no longer empty (documents have been added since " + "mkdir); cannot revert." + ), + ) + has_child = await session.execute( + select(Folder.id).where(Folder.parent_id == folder_id).limit(1) + ) + if has_child.scalar_one_or_none() is not None: + return RevertOutcome( + status="tool_unavailable", + message=( + "Folder is no longer empty (sub-folders have been added " + "since mkdir); cannot revert." + ), ) - if revision.content_before is not None: - doc.content = revision.content_before - if revision.title_before is not None: - doc.title = revision.title_before - if revision.folder_id_before is not None: - doc.folder_id = revision.folder_id_before - doc.updated_at = datetime.now(UTC) - return RevertOutcome(status="ok", message="Document restored from snapshot.") + await session.execute(delete(Folder).where(Folder.id == folder_id)) + return RevertOutcome( + status="ok", + message="Deleted the folder that was created by this action.", + ) async def _restore_folder_revision( @@ -168,41 +507,44 @@ async def _restore_folder_revision( message="No folder_revisions row tied to this action.", ) - from app.db import Folder + tool_name = (action.tool_name or "").lower() - folder = await session.get(Folder, revision.folder_id) - if folder is None: - return RevertOutcome( - status="tool_unavailable", - message="Original folder has been deleted; revert cannot proceed.", - ) + if tool_name == "rmdir": + return await _reinsert_folder_from_revision(session, revision=revision) - if revision.name_before is not None: - folder.name = revision.name_before - if revision.parent_id_before is not None: - folder.parent_id = revision.parent_id_before - if revision.position_before is not None: - folder.position = revision.position_before - folder.updated_at = datetime.now(UTC) - return RevertOutcome(status="ok", message="Folder restored from snapshot.") + if tool_name == "mkdir": + return await _delete_created_folder(session, revision=revision) + + return await 
_restore_in_place_folder(session, revision=revision) -# Tool-name prefixes that route to KB document / folder revert paths. Kept -# as data so a future PR adding new KB-owned tools doesn't have to touch -# this module's control flow. -_DOC_TOOL_PREFIXES: tuple[str, ...] = ( - "edit_file", - "write_file", - "update_memory", - "create_note", - "update_note", - "delete_note", +# --------------------------------------------------------------------------- +# Dispatch +# --------------------------------------------------------------------------- +# +# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``. +# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and +# ``delete_note``/``delete_folder``. + +_DOC_TOOLS: frozenset[str] = frozenset( + { + "edit_file", + "write_file", + "move_file", + "rm", + "update_memory", + "create_note", + "update_note", + "delete_note", + } ) -_FOLDER_TOOL_PREFIXES: tuple[str, ...] = ( - "mkdir", - "move_file", - "rename_folder", - "delete_folder", +_FOLDER_TOOLS: frozenset[str] = frozenset( + { + "mkdir", + "rmdir", + "rename_folder", + "delete_folder", + } ) @@ -220,9 +562,9 @@ async def revert_action( """ tool_name = (action.tool_name or "").lower() - if tool_name.startswith(_DOC_TOOL_PREFIXES): + if tool_name in _DOC_TOOLS: outcome = await _restore_document_revision(session, action=action) - elif tool_name.startswith(_FOLDER_TOOL_PREFIXES): + elif tool_name in _FOLDER_TOOLS: outcome = await _restore_folder_revision(session, action=action) elif action.reverse_descriptor: # Connector-owned reversibles run through the normal permission diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index d6ca5418c..90601a5bc 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -32,6 +32,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import 
create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, @@ -73,6 +74,91 @@ _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() +def _extract_chunk_parts(chunk: Any) -> dict[str, Any]: + """Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts. + + Returns a dict with three keys: + + * ``text`` — concatenated string content (empty string if the chunk + contributes none). + * ``reasoning`` — concatenated reasoning content (empty string if the + chunk contributes none). + * ``tool_call_chunks`` — flat list of LangChain ``tool_call_chunk`` + dicts surfaced from either the typed-block list or the + ``tool_call_chunks`` attribute. + + Background + ---------- + ``AIMessageChunk.content`` can be: + + * a ``str`` (most providers), or + * a ``list`` of typed blocks ``{type: 'text' | 'reasoning' | + 'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for + Anthropic, Bedrock, and several reasoning configurations. + + Reasoning may also live under + ``chunk.additional_kwargs['reasoning_content']`` (some providers + surface it that way instead of as a typed block). Tool-call chunks + may live under ``chunk.tool_call_chunks`` even when ``content`` is a + plain string. + + Earlier versions only handled the ``isinstance(content, str)`` branch + and silently dropped reasoning blocks + tool-call chunks emitted by + LangChain ``AIMessageChunk``s. 
+ """ + out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []} + if chunk is None: + return out + + content = getattr(chunk, "content", None) + if isinstance(content, str): + if content: + out["text"] = content + elif isinstance(content, list): + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type == "text": + value = block.get("text") or block.get("content") or "" + if isinstance(value, str) and value: + text_parts.append(value) + elif block_type == "reasoning": + value = ( + block.get("reasoning") + or block.get("text") + or block.get("content") + or "" + ) + if isinstance(value, str) and value: + reasoning_parts.append(value) + elif block_type in ("tool_call_chunk", "tool_use"): + out["tool_call_chunks"].append(block) + if text_parts: + out["text"] = "".join(text_parts) + if reasoning_parts: + out["reasoning"] = "".join(reasoning_parts) + + additional = getattr(chunk, "additional_kwargs", None) or {} + if isinstance(additional, dict): + extra_reasoning = additional.get("reasoning_content") + if isinstance(extra_reasoning, str) and extra_reasoning: + existing = out["reasoning"] + out["reasoning"] = ( + (existing + extra_reasoning) if existing else extra_reasoning + ) + + extra_tool_chunks = getattr(chunk, "tool_call_chunks", None) + if isinstance(extra_tool_chunks, list): + for tcc in extra_tool_chunks: + if isinstance(tcc, dict): + out["tool_call_chunks"].append(tcc) + + return out + + def format_mentioned_surfsense_docs_as_context( documents: list[SurfsenseDocsDocument], ) -> str: @@ -401,6 +487,7 @@ async def _stream_agent_events( fallback_commit_search_space_id: int | None = None, fallback_commit_created_by_id: str | None = None, fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, + fallback_commit_thread_id: int | None = None, ) -> AsyncGenerator[str, None]: """Shared async generator that 
streams and formats astream_events from the agent. @@ -433,6 +520,41 @@ async def _stream_agent_events( active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool called_update_memory: bool = False + # Reasoning-block streaming. We open a reasoning block on the + # first reasoning delta of a step, append deltas as they arrive, and + # close it when text starts (the model has switched to writing its + # answer) or ``on_chat_model_end`` fires for the model node. Reuses + # the same Vercel format-helpers as text-start/delta/end. + current_reasoning_id: str | None = None + + # Streaming-parity v2 feature flag. When OFF we keep the legacy + # shape: str-only content, no reasoning blocks, no + # ``langchainToolCallId`` propagation. The schema migrations + # (135 / 136) ship unconditionally because they're forward-compatible. + parity_v2 = bool(get_flags().enable_stream_parity_v2) + + # Best-effort attach of LangChain ``tool_call_id`` to the synthetic + # ``call_`` card id we already emit. We accumulate + # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by + # name, and pop the next unconsumed entry at ``on_tool_start``. The + # authoritative id is later filled in at ``on_tool_end`` from + # ``ToolMessage.tool_call_id``. + pending_tool_call_chunks: list[dict[str, Any]] = [] + lc_tool_call_id_by_run: dict[str, str] = {} + + # Per-tool-end mutable cache for the LangChain tool_call_id resolved + # at ``on_tool_end``. ``_emit_tool_output`` reads this so every + # ``format_tool_output_available`` call automatically carries the + # authoritative id without duplicating the kwarg at every call site. 
+ current_lc_tool_call_id: dict[str, str | None] = {"value": None} + + def _emit_tool_output(call_id: str, output: Any) -> str: + return streaming_service.format_tool_output_available( + call_id, + output, + langchain_tool_call_id=current_lc_tool_call_id["value"], + ) + def next_thinking_step_id() -> str: nonlocal thinking_step_counter thinking_step_counter += 1 @@ -461,22 +583,61 @@ async def _stream_agent_events( if "surfsense:internal" in event.get("tags", []): continue # Suppress middleware-internal LLM tokens (e.g. KB search classification) chunk = event.get("data", {}).get("chunk") - if chunk and hasattr(chunk, "content"): - content = chunk.content - if content and isinstance(content, str): - if current_text_id is None: - completion_event = complete_current_step() - if completion_event: - yield completion_event - if just_finished_tool: - last_active_step_id = None - last_active_step_title = "" - last_active_step_items = [] - just_finished_tool = False - current_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(current_text_id) - yield streaming_service.format_text_delta(current_text_id, content) - accumulated_text += content + if not chunk: + continue + parts = _extract_chunk_parts(chunk) + + # Accumulate any tool_call_chunks for best-effort + # correlation with ``on_tool_start`` below. We don't emit + # anything here; the matching is done at tool-start time. + if parity_v2 and parts["tool_call_chunks"]: + for tcc in parts["tool_call_chunks"]: + pending_tool_call_chunks.append(tcc) + + reasoning_delta = parts["reasoning"] + text_delta = parts["text"] + + # Reasoning streaming. Open a reasoning block on first + # delta; append every subsequent delta until text begins. + # When text starts we close the reasoning block first so the + # frontend sees the natural hand-off. Gated behind the + # parity-v2 flag so legacy deployments keep today's shape. 
+ if parity_v2 and reasoning_delta: + if current_text_id is not None: + yield streaming_service.format_text_end(current_text_id) + current_text_id = None + if current_reasoning_id is None: + completion_event = complete_current_step() + if completion_event: + yield completion_event + if just_finished_tool: + last_active_step_id = None + last_active_step_title = "" + last_active_step_items = [] + just_finished_tool = False + current_reasoning_id = streaming_service.generate_reasoning_id() + yield streaming_service.format_reasoning_start(current_reasoning_id) + yield streaming_service.format_reasoning_delta( + current_reasoning_id, reasoning_delta + ) + + if text_delta: + if current_reasoning_id is not None: + yield streaming_service.format_reasoning_end(current_reasoning_id) + current_reasoning_id = None + if current_text_id is None: + completion_event = complete_current_step() + if completion_event: + yield completion_event + if just_finished_tool: + last_active_step_id = None + last_active_step_title = "" + last_active_step_items = [] + just_finished_tool = False + current_text_id = streaming_service.generate_text_id() + yield streaming_service.format_text_start(current_text_id) + yield streaming_service.format_text_delta(current_text_id, text_delta) + accumulated_text += text_delta elif event_type == "on_tool_start": active_tool_depth += 1 @@ -596,6 +757,95 @@ async def _stream_agent_events( status="in_progress", items=last_active_step_items, ) + elif tool_name == "rm": + rm_path = ( + tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = rm_path if len(rm_path) <= 80 else "…" + rm_path[-77:] + last_active_step_title = "Deleting file" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Deleting file", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "rmdir": + rmdir_path = ( + 
tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = ( + rmdir_path if len(rmdir_path) <= 80 else "…" + rmdir_path[-77:] + ) + last_active_step_title = "Deleting folder" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Deleting folder", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "mkdir": + mkdir_path = ( + tool_input.get("path", "") + if isinstance(tool_input, dict) + else str(tool_input) + ) + display_path = ( + mkdir_path if len(mkdir_path) <= 80 else "…" + mkdir_path[-77:] + ) + last_active_step_title = "Creating folder" + last_active_step_items = [display_path] if display_path else [] + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Creating folder", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "move_file": + src = ( + tool_input.get("source_path", "") + if isinstance(tool_input, dict) + else "" + ) + dst = ( + tool_input.get("destination_path", "") + if isinstance(tool_input, dict) + else "" + ) + display_src = src if len(src) <= 60 else "…" + src[-57:] + display_dst = dst if len(dst) <= 60 else "…" + dst[-57:] + last_active_step_title = "Moving file" + last_active_step_items = ( + [f"{display_src} → {display_dst}"] if src or dst else [] + ) + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title="Moving file", + status="in_progress", + items=last_active_step_items, + ) + elif tool_name == "write_todos": + todos = ( + tool_input.get("todos", []) if isinstance(tool_input, dict) else [] + ) + todo_count = len(todos) if isinstance(todos, list) else 0 + last_active_step_title = "Planning tasks" + last_active_step_items = ( + [f"{todo_count} task{'s' if todo_count != 1 else ''}"] + if todo_count + else [] + ) + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + 
title="Planning tasks", + status="in_progress", + items=last_active_step_items, + ) elif tool_name == "save_document": doc_title = ( tool_input.get("title", "") @@ -703,7 +953,15 @@ async def _stream_agent_events( items=last_active_step_items, ) else: - last_active_step_title = f"Using {tool_name.replace('_', ' ')}" + # Fallback for tools without a curated thinking-step title + # (typically connector tools, MCP-registered tools, or + # newly added tools that haven't been wired up here yet). + # Render the snake_cased name as a sentence-cased phrase + # so non-technical users see e.g. "Send gmail email" + # rather than the raw identifier "send_gmail_email". + last_active_step_title = ( + tool_name.replace("_", " ").strip().capitalize() or tool_name + ) last_active_step_items = [] yield streaming_service.format_thinking_step( step_id=tool_step_id, @@ -716,7 +974,39 @@ async def _stream_agent_events( if run_id else streaming_service.generate_tool_call_id() ) - yield streaming_service.format_tool_input_start(tool_call_id, tool_name) + + # Best-effort attach the LangChain ``tool_call_id``. We + # pop the first chunk in ``pending_tool_call_chunks`` whose + # name matches; if none match (the chunked args may not yet + # carry a ``name`` field, or the model skipped the chunked + # form) we leave ``langchainToolCallId`` unset for now and + # fill it in authoritatively at ``on_tool_end`` from + # ``ToolMessage.tool_call_id``. 
+ langchain_tool_call_id: str | None = None + if parity_v2 and pending_tool_call_chunks: + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is not None: + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + langchain_tool_call_id = candidate + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + + yield streaming_service.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + ) # Sanitize tool_input: strip runtime-injected non-serializable # values (e.g. LangChain ToolRuntime) before sending over SSE. if isinstance(tool_input, dict): @@ -733,6 +1023,7 @@ async def _stream_agent_events( tool_call_id, tool_name, _safe_input, + langchain_tool_call_id=langchain_tool_call_id, ) elif event_type == "on_tool_end": @@ -774,6 +1065,23 @@ async def _stream_agent_events( ) completed_step_ids.add(original_step_id) + # Authoritative LangChain tool_call_id from the returned + # ``ToolMessage``. Falls back to whatever we matched + # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) + # if the output isn't a ToolMessage. The value is stored in + # ``current_lc_tool_call_id`` so ``_emit_tool_output`` + # picks it up for every output emit below. Stays None when + # parity_v2 is off so legacy emit paths are untouched. 
+ current_lc_tool_call_id["value"] = None + if parity_v2: + authoritative = getattr(raw_output, "tool_call_id", None) + if isinstance(authoritative, str) and authoritative: + current_lc_tool_call_id["value"] = authoritative + if run_id: + lc_tool_call_id_by_run[run_id] = authoritative + elif run_id and run_id in lc_tool_call_id_by_run: + current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] + if tool_name == "read_file": yield streaming_service.format_thinking_step( step_id=original_step_id, @@ -809,6 +1117,41 @@ async def _stream_agent_events( status="completed", items=last_active_step_items, ) + elif tool_name == "rm": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Deleting file", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "rmdir": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Deleting folder", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "mkdir": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Creating folder", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "move_file": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Moving file", + status="completed", + items=last_active_step_items, + ) + elif tool_name == "write_todos": + yield streaming_service.format_thinking_step( + step_id=original_step_id, + title="Planning tasks", + status="completed", + items=last_active_step_items, + ) elif tool_name == "save_document": result_str = ( tool_output.get("result", "") @@ -1060,9 +1403,14 @@ async def _stream_agent_events( items=completed_items, ) else: + # Fallback completion title — see the matching in-progress + # branch above for the wording rationale. 
+ fallback_title = ( + tool_name.replace("_", " ").strip().capitalize() or tool_name + ) yield streaming_service.format_thinking_step( step_id=original_step_id, - title=f"Using {tool_name.replace('_', ' ')}", + title=fallback_title, status="completed", items=last_active_step_items, ) @@ -1073,7 +1421,7 @@ async def _stream_agent_events( last_active_step_items = [] if tool_name == "generate_podcast": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1098,7 +1446,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_video_presentation": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1126,7 +1474,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_image": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1153,12 +1501,12 @@ async def _stream_agent_events( display_output["content_preview"] = ( content[:500] + "..." 
if len(content) > 500 else content ) - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, display_output, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, {"result": tool_output}, ) @@ -1186,7 +1534,7 @@ async def _stream_agent_events( ) result_text = _tool_output_to_text(tool_output) if _tool_output_has_error(tool_output): - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "status": "error", @@ -1195,7 +1543,7 @@ async def _stream_agent_events( }, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "status": "completed", @@ -1205,7 +1553,7 @@ async def _stream_agent_events( ) elif tool_name == "generate_report": # Stream the full report result so frontend can render the ReportCard - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1232,7 +1580,7 @@ async def _stream_agent_events( "error", ) elif tool_name == "generate_resume": - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1283,7 +1631,7 @@ async def _stream_agent_events( "update_confluence_page", "delete_confluence_page", ): - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, tool_output if isinstance(tool_output, dict) @@ -1311,7 +1659,7 @@ async def _stream_agent_events( if fpath and fpath not in result.sandbox_files: result.sandbox_files.append(fpath) - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, { "exit_code": exit_code, @@ -1346,12 +1694,12 @@ async def _stream_agent_events( citations[chunk_url]["snippet"] = ( content[:200] + "…" if len(content) > 200 else content ) - yield streaming_service.format_tool_output_available( + 
yield _emit_tool_output( tool_call_id, {"status": "completed", "citations": citations}, ) else: - yield streaming_service.format_tool_output_available( + yield _emit_tool_output( tool_call_id, {"status": "completed", "result_length": len(str(tool_output))}, ) @@ -1409,6 +1757,25 @@ async def _stream_agent_events( }, ) + elif event_type == "on_custom_event" and event.get("name") == "action_log": + # Surface a freshly committed AgentActionLog row so the chat + # tool card can render its Revert button immediately. + data = event.get("data", {}) + if data.get("id") is not None: + yield streaming_service.format_data("action-log", data) + + elif ( + event_type == "on_custom_event" + and event.get("name") == "action_log_updated" + ): + # Reversibility flipped in kb_persistence after the SAVEPOINT + # for a destructive op (rm/rmdir/move/edit/write) committed. + # Frontend uses this to flip the card's Revert + # button on without re-fetching the actions list. + data = event.get("data", {}) + if data.get("id") is not None: + yield streaming_service.format_data("action-log-updated", data) + elif event_type in ("on_chain_end", "on_agent_end"): if current_text_id is not None: yield streaming_service.format_text_end(current_text_id) @@ -1426,11 +1793,12 @@ async def _stream_agent_events( # Safety net: if astream_events was cancelled before # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work - # (dirty_paths / staged_dirs / pending_moves) will still be in the - # checkpointed state. Run the SAME shared commit helper here so the - # turn's writes don't get lost on client disconnect, then push the - # delta back into the graph using `as_node=...` so reducers fire as if - # the after_agent hook produced it. + # (dirty_paths / staged_dirs / pending_moves / pending_deletes / + # pending_dir_deletes) will still be in the checkpointed state. 
Run + # the SAME shared commit helper here so the turn's writes don't get + # lost on client disconnect, then push the delta back into the graph + # using `as_node=...` so reducers fire as if the after_agent hook + # produced it. if ( fallback_commit_filesystem_mode == FilesystemMode.CLOUD and fallback_commit_search_space_id is not None @@ -1438,6 +1806,8 @@ async def _stream_agent_events( (state_values.get("dirty_paths") or []) or (state_values.get("staged_dirs") or []) or (state_values.get("pending_moves") or []) + or (state_values.get("pending_deletes") or []) + or (state_values.get("pending_dir_deletes") or []) ) ): try: @@ -1446,6 +1816,7 @@ async def _stream_agent_events( search_space_id=fallback_commit_search_space_id, created_by_id=fallback_commit_created_by_id, filesystem_mode=fallback_commit_filesystem_mode, + thread_id=fallback_commit_thread_id, dispatch_events=False, ) if delta: @@ -1954,13 +2325,33 @@ async def stream_new_chat( config = { "configurable": configurable, - "recursion_limit": 80, # Increase from default 25 to allow more tool iterations + # Effectively uncapped, matching the agent-level + # ``with_config`` default in ``chat_deepagent.create_agent`` + # and the unbounded ``while(true)`` loop used by OpenCode's + # ``session/processor.ts``. Real circuit-breakers live in + # middleware: ``DoomLoopMiddleware`` (sliding-window tool + # signature check), plus ``enable_tool_call_limit`` / + # ``enable_model_call_limit`` when those flags are set. The + # original LangGraph default of 25 (and our previous 80 + # bump) hit users on legitimate multi-tool plans. + "recursion_limit": 10_000, } # Start the message stream yield streaming_service.format_message_start() yield streaming_service.format_start_step() + # Surface the per-turn correlation id at the very start of the + # stream so the frontend can stamp it onto the in-flight + # assistant message and replay it via ``appendMessage`` + # for durable storage. 
Tool/action-log events DO carry it later, + # but pure-text turns never produce action-log events; this + # event guarantees the frontend learns the turn id regardless. + yield streaming_service.format_data( + "turn-info", + {"chat_turn_id": stream_result.turn_id}, + ) + # Initial thinking step - analyzing the request if mentioned_surfsense_docs: initial_title = "Analyzing referenced content" @@ -2111,6 +2502,7 @@ async def stream_new_chat( if filesystem_selection else FilesystemMode.CLOUD ), + fallback_commit_thread_id=chat_id, ): if not _first_event_logged: _perf_log.info( @@ -2652,11 +3044,22 @@ async def stream_resume_chat( "request_id": request_id or "unknown", "turn_id": stream_result.turn_id, }, - "recursion_limit": 80, + # See ``stream_new_chat`` above for rationale: effectively + # uncapped to mirror the agent default and OpenCode's + # session loop. Doom-loop / call-limit middleware enforce + # the real ceiling. + "recursion_limit": 10_000, } yield streaming_service.format_message_start() yield streaming_service.format_start_step() + # Same rationale as ``stream_new_chat``: emit the turn id so + # resumed streams can be persisted with their correlation id + # intact. 
+ yield streaming_service.format_data( + "turn-info", + {"chat_turn_id": stream_result.turn_id}, + ) _t_stream_start = time.perf_counter() _first_event_logged = False @@ -2674,6 +3077,7 @@ async def stream_resume_chat( if filesystem_selection else FilesystemMode.CLOUD ), + fallback_commit_thread_id=chat_id, ): if not _first_event_logged: _perf_log.info( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py index aad1524c9..8ef1430a9 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py @@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware from app.agents.new_chat.tools.registry import ToolDefinition +@dataclass +class _FakeRuntime: + """Minimal stand-in for ``ToolRuntime`` used in unit tests. + + ``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']`` + to populate the new ``chat_turn_id`` column (see migration 135). + """ + + config: dict[str, Any] | None = None + + @dataclass class _FakeRequest: """Minimal stand-in for ToolCallRequest used in unit tests.""" @@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence: "args": {"color": "red", "size": 3}, "id": "tc-abc", }, + runtime=_FakeRuntime( + config={"configurable": {"turn_id": "42:1700000000000"}} + ), ) result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1") handler = AsyncMock(return_value=result_msg) @@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence: assert row.error is None assert row.reverse_descriptor is None assert row.reversible is False + # Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``; + # ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``. 
+ assert row.tool_call_id == "tc-abc" + assert row.turn_id == "tc-abc" + assert row.chat_turn_id == "42:1700000000000" + + @pytest.mark.asyncio + async def test_chat_turn_id_none_when_runtime_missing( + self, patch_get_flags, fake_session_factory + ) -> None: + """``chat_turn_id`` falls back to NULL when ``runtime.config`` is absent.""" + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc-1"}, + runtime=None, + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc-1")) + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + ): + await mw.awrap_tool_call(request, handler) + row = captured["rows"][0] + assert row.tool_call_id == "tc-1" + assert row.chat_turn_id is None @pytest.mark.asyncio async def test_writes_row_on_failure_and_reraises( @@ -293,6 +333,76 @@ class TestReverseDescriptor: assert row.reversible is False +class TestActionLogDispatch: + """Verify ``adispatch_custom_event`` fires after commit.""" + + @pytest.mark.asyncio + async def test_dispatches_action_log_event_on_success( + self, patch_get_flags, fake_session_factory + ) -> None: + _captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1") + request = _FakeRequest( + tool_call={ + "name": "make_widget", + "args": {"color": "red"}, + "id": "tc-evt", + }, + runtime=_FakeRuntime( + config={"configurable": {"turn_id": "42:1700000000000"}} + ), + ) + result_msg = ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42") + handler = AsyncMock(return_value=result_msg) + + dispatch_mock = AsyncMock() + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + patch( + "app.agents.new_chat.middleware.action_log.adispatch_custom_event", + dispatch_mock, + 
), + ): + await mw.awrap_tool_call(request, handler) + + dispatch_mock.assert_awaited_once() + call_args = dispatch_mock.await_args + assert call_args is not None + assert call_args.args[0] == "action_log" + payload = call_args.args[1] + assert payload["lc_tool_call_id"] == "tc-evt" + assert payload["chat_turn_id"] == "42:1700000000000" + assert payload["tool_name"] == "make_widget" + assert payload["reversible"] is False + assert payload["reverse_descriptor_present"] is False + assert payload["error"] is False + + @pytest.mark.asyncio + async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None: + """If commit fails the dispatch is suppressed (no row to surface).""" + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + dispatch_mock = AsyncMock() + + def _exploding_session(): + raise RuntimeError("DB is down") + + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=_exploding_session), + patch( + "app.agents.new_chat.middleware.action_log.adispatch_custom_event", + dispatch_mock, + ), + ): + await mw.awrap_tool_call(request, handler) + dispatch_mock.assert_not_awaited() + + class TestArgsTruncation: @pytest.mark.asyncio async def test_huge_args_payload_is_truncated( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py new file mode 100644 index 000000000..653175eab --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_desktop_safety_rules.py @@ -0,0 +1,122 @@ +"""Tests for the desktop-mode safety ruleset. + +In desktop mode the agent operates against the user's real disk with no +revision history, so destructive filesystem operations must require +explicit approval. 
These tests pin the set of tools that get the ``ask`` +gate so it cannot silently regress. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate_many, +) + +pytestmark = pytest.mark.unit + + +# Mirror the ruleset built inside ``chat_deepagent._build_compiled_agent_blocking`` +# when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. Keeping a +# copy here means the rule contract has a focused regression test even when +# the larger graph-build helper is hard to instantiate in unit tests. +DESKTOP_SAFETY_RULESET = Ruleset( + rules=[ + Rule(permission="rm", pattern="*", action="ask"), + Rule(permission="rmdir", pattern="*", action="ask"), + Rule(permission="move_file", pattern="*", action="ask"), + Rule(permission="edit_file", pattern="*", action="ask"), + Rule(permission="write_file", pattern="*", action="ask"), + ], + origin="desktop_safety", +) + +SURFSENSE_DEFAULTS = Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", +) + + +def _action_for(tool_name: str, *rulesets: Ruleset) -> str: + rules = evaluate_many(tool_name, [tool_name], *rulesets) + return aggregate_action(rules) + + +class TestDesktopSafetyRulesGateDestructiveOps: + @pytest.mark.parametrize( + "tool_name", + ["rm", "rmdir", "move_file", "edit_file", "write_file"], + ) + def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None: + # surfsense_defaults says "allow */*"; desktop_safety must override + # because it's layered later (last-match-wins). 
+ action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "ask", ( + f"{tool_name} must require approval in desktop mode " + f"(no revert path on real disk); got {action!r}" + ) + + @pytest.mark.parametrize( + "tool_name", + ["read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"], + ) + def test_safe_ops_remain_allowed(self, tool_name: str) -> None: + # Read-only and trivially-reversible tools must NOT get gated — + # otherwise every navigation in desktop mode pops an interrupt. + action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "allow", ( + f"{tool_name} should not be gated in desktop mode; got {action!r}" + ) + + +class TestDesktopSafetyOverridesAllowDefault: + def test_layer_order_last_match_wins(self) -> None: + # If desktop_safety is layered BEFORE surfsense_defaults, the allow + # default would win and the safety net would be inert. This test + # protects against accidentally swapping the rulesets in + # ``_build_compiled_agent_blocking``. + action = _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS) + # Layered "wrong way" — the broad allow now wins. + assert action == "allow" + + # Correct order: defaults < desktop_safety -> ask wins. + action = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET) + assert action == "ask" + + +class TestPermissionMiddlewareIntegration: + def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None: + from langchain_core.messages import AIMessage + + from app.agents.new_chat.errors import RejectedError + + mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]) + # Stub the interrupt to a "reject" decision so we can assert the + # ask path was taken without spinning up the LangGraph runtime. 
+ mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] + + state = { + "messages": [ + AIMessage( + content="", + tool_calls=[ + { + "name": "rm", + "args": {"path": "/Users/me/Documents/important.docx"}, + "id": "tc-rm", + } + ], + ) + ] + } + + class _FakeRuntime: + config: dict = {"configurable": {"thread_id": "test"}} + + with pytest.raises(RejectedError): + mw.after_model(state, _FakeRuntime()) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py new file mode 100644 index 000000000..0bbdf37bf --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_hitl_auto_approve.py @@ -0,0 +1,111 @@ +"""Tests for the default auto-approval list in ``hitl.request_approval``. + +These pin the policy that low-stakes connector creation tools (drafts, +new-file creates) skip the HITL interrupt by default. Without this set, +every "draft my newsletter" turn used to fire ~3 interrupts before any +useful work happened. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.tools.hitl import ( + DEFAULT_AUTO_APPROVED_TOOLS, + HITLResult, + request_approval, +) + +pytestmark = pytest.mark.unit + + +class TestDefaultAutoApprovedToolsList: + def test_set_contains_expected_creation_tools(self) -> None: + # If anyone changes the policy list, we want a single test to + # update so the contract is explicit. Keep this in sync with + # ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``. + expected = { + "create_gmail_draft", + "update_gmail_draft", + "create_notion_page", + "create_confluence_page", + "create_google_drive_file", + "create_dropbox_file", + "create_onedrive_file", + } + assert expected == DEFAULT_AUTO_APPROVED_TOOLS + + def test_set_is_immutable(self) -> None: + # frozenset prevents accidental at-runtime mutation that would + # silently widen the auto-approval surface. 
+ assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset) + + def test_send_tools_are_not_auto_approved(self) -> None: + # External-broadcast tools must always prompt. + for tool_name in ( + "send_gmail_email", + "send_discord_message", + "send_teams_message", + "delete_notion_page", + "create_calendar_event", + "delete_calendar_event", + ): + assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, ( + f"{tool_name} must remain HITL-gated" + ) + + +class TestRequestApprovalAutoBypass: + def test_auto_approved_tool_skips_interrupt(self) -> None: + # No interrupt mock set up — if the function attempted to call + # ``langgraph.types.interrupt`` it would raise GraphInterrupt. + # The fact that we get a clean HITLResult proves the bypass. + result = request_approval( + action_type="gmail_draft_creation", + tool_name="create_gmail_draft", + params={"to": "alice@example.com", "subject": "hi", "body": "hey"}, + ) + assert isinstance(result, HITLResult) + assert result.rejected is False + assert result.decision_type == "auto_approved" + # Original params are preserved untouched (no user edits possible). + assert result.params == { + "to": "alice@example.com", + "subject": "hi", + "body": "hey", + } + + def test_non_listed_tool_still_attempts_interrupt(self) -> None: + # A tool NOT in the default list must reach ``langgraph.interrupt``. + # Outside a runnable context that call raises a RuntimeError — + # which is exactly the signal we want: the bypass did NOT fire. + with pytest.raises(RuntimeError, match="runnable context"): + request_approval( + action_type="gmail_email_send", + tool_name="send_gmail_email", + params={"to": "alice@example.com", "subject": "hi", "body": "hey"}, + ) + + def test_user_trusted_tools_still_take_precedence(self) -> None: + # ``trusted_tools`` (per-connector "always allow" from MCP/UI) + # was checked BEFORE the default list and must keep working + # for tools outside the default list. 
+ result = request_approval( + action_type="mcp_tool_call", + tool_name="my_custom_mcp_tool", + params={"x": 1}, + trusted_tools=["my_custom_mcp_tool"], + ) + assert result.decision_type == "trusted" + assert result.rejected is False + + def test_auto_approved_overrides_no_trusted_tools(self) -> None: + # When trusted_tools is empty and tool is in the default list, + # we should still bypass — proves the order in request_approval. + result = request_approval( + action_type="notion_page_creation", + tool_name="create_notion_page", + params={"title": "Plan"}, + trusted_tools=[], + ) + assert result.decision_type == "auto_approved" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py b/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py new file mode 100644 index 000000000..7cabb6524 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_rm_rmdir_cloud.py @@ -0,0 +1,333 @@ +"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools. + +The tools build ``Command(update=...)`` payloads that the persistence +middleware applies at end of turn. These tests stub out the backend and +runtime to assert the staging payload shape: + +* ``rm`` queues into ``pending_deletes`` and tombstones state files. +* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc. +* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs. +* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete. +* ``rmdir`` refuses to drop the cwd or any of its ancestors. +* ``KBPostgresBackend`` view-helpers honor staged deletes. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware +from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend + +pytestmark = pytest.mark.unit + + +def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD): + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + middleware._filesystem_mode = mode + middleware._custom_tool_descriptions = {} + return middleware + + +def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"): + state = state or {} + state.setdefault("cwd", "/documents") + return SimpleNamespace(state=state, tool_call_id=tool_call_id) + + +class _KBBackendStub(KBPostgresBackend): + """Construct-able subclass of :class:`KBPostgresBackend` for tests. + + We bypass the real ``__init__`` (which expects a runtime + DB session) + and inject just the methods the rm/rmdir tools touch. The class + inheritance keeps ``isinstance(backend, KBPostgresBackend)`` checks + inside the tools happy, which is what gates them from the desktop + code path. 
+ """ + + def __init__(self, *, children=None, file_data=None) -> None: + self.als_info = AsyncMock(return_value=children or []) + self._load_file_data = AsyncMock( + return_value=(file_data, 17) if file_data is not None else None + ) + + +def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend: + return _KBBackendStub(children=children, file_data=file_data) + + +def _bind_backend(middleware, backend): + """Inject a backend resolver onto the middleware test instance.""" + middleware._get_backend = lambda runtime: backend + return backend + + +# --------------------------------------------------------------------------- +# rm +# --------------------------------------------------------------------------- + + +class TestRmStaging: + @pytest.mark.asyncio + async def test_stages_delete_and_tombstones_state(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime( + { + "cwd": "/documents", + "files": {"/documents/notes.md": {"content": ["hello"]}}, + "doc_id_by_path": {"/documents/notes.md": 17}, + }, + tool_call_id="tc-1", + ) + + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + + assert hasattr(result, "update"), f"expected Command, got {result!r}" + update = result.update + assert update["pending_deletes"] == [ + {"path": "/documents/notes.md", "tool_call_id": "tc-1"} + ] + assert update["files"] == {"/documents/notes.md": None} + assert update["doc_id_by_path"] == {"/documents/notes.md": None} + + @pytest.mark.asyncio + async def test_rejects_documents_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rm_tool() + result = await tool.coroutine("/documents", runtime=runtime) + assert isinstance(result, str) + assert "refusing to rm" in result + + @pytest.mark.asyncio + async def test_rejects_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rm_tool() + result = 
await tool.coroutine("/", runtime=runtime) + assert isinstance(result, str) + assert "refusing to rm" in result + + @pytest.mark.asyncio + async def test_rejects_directory_via_staged_dirs(self): + m = _make_middleware() + runtime = _runtime( + { + "staged_dirs": ["/documents/team-x"], + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/team-x", runtime=runtime) + assert isinstance(result, str) + assert "directory" in result.lower() + assert "rmdir" in result + + @pytest.mark.asyncio + async def test_rejects_directory_via_listing(self): + m = _make_middleware() + _bind_backend( + m, + _make_backend_stub( + children=[{"path": "/documents/foo/x.md", "is_dir": False}] + ), + ) + runtime = _runtime() + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/foo", runtime=runtime) + assert isinstance(result, str) + assert "directory" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_anonymous_doc(self): + m = _make_middleware() + runtime = _runtime( + { + "kb_anon_doc": { + "path": "/documents/uploaded.xml", + "title": "uploaded", + "content": "", + "chunks": [], + } + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/uploaded.xml", runtime=runtime) + assert isinstance(result, str) + assert "read-only" in result + + @pytest.mark.asyncio + async def test_drops_path_from_dirty_paths(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime( + { + "files": {"/documents/notes.md": {"content": ["x"]}}, + "doc_id_by_path": {"/documents/notes.md": 17}, + "dirty_paths": ["/documents/notes.md"], + } + ) + tool = m._create_rm_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + update = result.update + # First element is _CLEAR sentinel; the rest must NOT contain the + # rm'd path. 
+ dirty = update.get("dirty_paths") or [] + assert "/documents/notes.md" not in dirty[1:] + + +# --------------------------------------------------------------------------- +# rmdir +# --------------------------------------------------------------------------- + + +class TestRmdirStaging: + @pytest.mark.asyncio + async def test_stages_dir_delete_when_empty_and_db_backed(self): + m = _make_middleware() + backend = _bind_backend(m, _make_backend_stub(children=[])) + # Override _load_file_data to return None (folder, not a file) and + # parent listing to claim the folder exists. + backend._load_file_data = AsyncMock(return_value=None) + backend.als_info = AsyncMock( + side_effect=[ + [], # children of /documents/proj + [ + {"path": "/documents/proj", "is_dir": True}, + ], # parent listing + ] + ) + runtime = _runtime( + { + "cwd": "/documents", + }, + tool_call_id="tc-rd", + ) + + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + + assert hasattr(result, "update") + update = result.update + assert update["pending_dir_deletes"] == [ + {"path": "/documents/proj", "tool_call_id": "tc-rd"} + ] + + @pytest.mark.asyncio + async def test_rejects_non_empty(self): + m = _make_middleware() + _bind_backend( + m, + _make_backend_stub( + children=[{"path": "/documents/proj/x.md", "is_dir": False}] + ), + ) + runtime = _runtime() + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "not empty" in result + + @pytest.mark.asyncio + async def test_unstages_same_turn_mkdir(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[])) + runtime = _runtime( + { + "cwd": "/documents", + "staged_dirs": ["/documents/scratch"], + }, + tool_call_id="tc-rd", + ) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/scratch", runtime=runtime) + + assert hasattr(result, "update") + update = result.update + assert 
"pending_dir_deletes" not in update + # _CLEAR sentinel + remaining items (in this case, none). + staged_after = update["staged_dirs"] + assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00" + assert "/documents/scratch" not in staged_after[1:] + + @pytest.mark.asyncio + async def test_rejects_root(self): + m = _make_middleware() + runtime = _runtime() + tool = m._create_rmdir_tool() + for victim in ("/", "/documents"): + result = await tool.coroutine(victim, runtime=runtime) + assert isinstance(result, str) + assert "refusing to rmdir" in result + + @pytest.mark.asyncio + async def test_rejects_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/proj"}) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "cwd" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_ancestor_of_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/proj/sub"}) + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/proj", runtime=runtime) + assert isinstance(result, str) + assert "cwd" in result.lower() + + @pytest.mark.asyncio + async def test_rejects_files(self): + m = _make_middleware() + _bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]})) + runtime = _runtime() + tool = m._create_rmdir_tool() + result = await tool.coroutine("/documents/notes.md", runtime=runtime) + assert isinstance(result, str) + assert "is a file" in result + + +# --------------------------------------------------------------------------- +# KBPostgresBackend view filter +# --------------------------------------------------------------------------- + + +class TestKBPostgresBackendDeleteFilter: + """als_info / glob / grep should suppress paths queued for delete.""" + + def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend: + runtime = SimpleNamespace(state=state) + backend = 
KBPostgresBackend(search_space_id=1, runtime=runtime) + return backend + + def test_pending_filesystem_view_returns_deleted_paths(self): + backend = self._make_backend( + { + "pending_deletes": [ + {"path": "/documents/x.md", "tool_call_id": "t1"}, + ], + "pending_dir_deletes": [ + {"path": "/documents/d1", "tool_call_id": "t2"}, + ], + } + ) + removed, alias, deleted_dirs = backend._pending_filesystem_view({}) + assert "/documents/x.md" in removed + assert "/documents/d1" in deleted_dirs + assert alias == {} + + def test_dir_suppressed_covers_descendants(self): + backend = self._make_backend({}) + deleted_dirs = {"/documents/d"} + assert backend._is_dir_suppressed("/documents/d", deleted_dirs) + assert backend._is_dir_suppressed("/documents/d/x.md", deleted_dirs) + assert backend._is_dir_suppressed("/documents/d/sub/y.md", deleted_dirs) + assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py index 3caeb9a34..185753990 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py +++ b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py @@ -98,10 +98,54 @@ class TestInitialFilesystemState: state = _initial_filesystem_state() assert state["cwd"] == "/documents" assert state["staged_dirs"] == [] + assert state["staged_dir_tool_calls"] == {} assert state["pending_moves"] == [] + assert state["pending_deletes"] == [] + assert state["pending_dir_deletes"] == [] assert state["doc_id_by_path"] == {} assert state["dirty_paths"] == [] + assert state["dirty_path_tool_calls"] == {} assert state["kb_priority"] == [] assert state["kb_matched_chunk_ids"] == {} assert state["kb_anon_doc"] is None assert state["tree_version"] == 0 + + +class TestMultiEditSamePathCoalescing: + """Multi-edit-same-path turns must coalesce into ONE binding record. 
+ + The persistence body uses ``dirty_path_tool_calls[path]`` to find the + tool_call_id that produced the current state on disk. Because + ``dirty_paths`` dedupes via :func:`_add_unique_reducer` the second + edit doesn't append a new path entry — and because + ``_dict_merge_with_tombstones_reducer`` lets the right-hand side + overwrite, the LATEST tool_call_id wins. That's the correct behavior + for snapshotting: revert restores to the pre-mutation state, and + multiple back-to-back edits in one turn coalesce into a single + revisible op (the user sees ONE Revert button per turn-per-path, + not N). + """ + + def test_dirty_paths_dedupes_repeated_writes(self): + # ``_add_unique_reducer`` is applied to ``dirty_paths``. Two writes + # to the same path produce one entry, not two. + first = _add_unique_reducer([], ["/documents/a.md"]) + second = _add_unique_reducer(first, ["/documents/a.md"]) + assert second == ["/documents/a.md"] + + def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self): + # First write tags the path with tcid-1. + merged = _dict_merge_with_tombstones_reducer({}, {"/documents/a.md": "tcid-1"}) + # Second write to the same path tags it with tcid-2 (latest wins). + merged = _dict_merge_with_tombstones_reducer( + merged, {"/documents/a.md": "tcid-2"} + ) + assert merged == {"/documents/a.md": "tcid-2"} + + def test_rm_tombstones_dirty_path_tool_call(self): + # ``rm`` writes ``{path: None}`` into dirty_path_tool_calls to + # prevent a stale binding from leaking past the delete. 
+ merged = _dict_merge_with_tombstones_reducer( + {"/documents/a.md": "tcid-1"}, {"/documents/a.md": None} + ) + assert merged == {} diff --git a/surfsense_backend/tests/unit/db/__init__.py b/surfsense_backend/tests/unit/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py b/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py new file mode 100644 index 000000000..82c299488 --- /dev/null +++ b/surfsense_backend/tests/unit/db/test_relax_revision_fks_migration.py @@ -0,0 +1,83 @@ +"""Smoke test for the ``134_relax_revision_fks`` Alembic migration. + +A full apply/rollback test would require a live Postgres; here we verify +the migration module's static contract: + +* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``. +* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'`` + (one for ``document_revisions.document_id``, one for + ``folder_revisions.folder_id``). +* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining + orphaned revisions. + +If any of these invariants regress the snapshot/revert pipeline silently +loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the +migration "down" or never ran it at all. 
+""" + +from __future__ import annotations + +import importlib.util +import inspect +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.unit + + +_MIGRATION_PATH = ( + Path(__file__).resolve().parents[3] + / "alembic" + / "versions" + / "134_relax_revision_fks.py" +) + + +def _load_migration(): + """Load the migration module by file path (no package import needed).""" + spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH) + assert spec and spec.loader, "could not load migration spec" + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_migration_chain_revision_ids() -> None: + module = _load_migration() + # The migration file uses short numeric revision IDs to match the + # in-tree convention (cf. ``133`` -> ``134``); the ``134_.py`` + # filename is documentation, not the canonical revision string. + assert getattr(module, "revision", None) == "134" + assert getattr(module, "down_revision", None) == "133" + + +def test_migration_exposes_upgrade_and_downgrade() -> None: + module = _load_migration() + upgrade = getattr(module, "upgrade", None) + downgrade = getattr(module, "downgrade", None) + assert callable(upgrade), "upgrade() is required" + assert callable(downgrade), "downgrade() is required" + + +def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None: + module = _load_migration() + src = inspect.getsource(module.upgrade) + assert "document_revisions" in src + assert "folder_revisions" in src + # Both new FKs MUST be ON DELETE SET NULL — that's the entire point + # of the migration: snapshots must outlive their parent row. + assert src.count('ondelete="SET NULL"') >= 2 + # And the ``document_id`` / ``folder_id`` columns become nullable. 
+ assert "nullable=True" in src + + +def test_downgrade_drains_orphans_then_restores_cascade() -> None: + module = _load_migration() + src = inspect.getsource(module.downgrade) + # Drain orphaned rows BEFORE we can re-impose NOT NULL. + assert "DELETE FROM document_revisions WHERE document_id IS NULL" in src + assert "DELETE FROM folder_revisions WHERE folder_id IS NULL" in src + # Then restore the original CASCADE/NOT NULL contract. + assert src.count('ondelete="CASCADE"') >= 2 + assert "nullable=False" in src diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py index c2e304399..70430f4ca 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py @@ -168,6 +168,8 @@ class TestModeSpecificPrompts: "edit_file", "move_file", "mkdir", + "rm", + "rmdir", "list_tree", "grep", ): @@ -182,6 +184,8 @@ class TestModeSpecificPrompts: "edit_file", "move_file", "mkdir", + "rm", + "rmdir", "list_tree", "grep", ): @@ -190,6 +194,18 @@ class TestModeSpecificPrompts: assert "/documents/" not in text, f"{name} mentions cloud namespace" assert "temp_" not in text, f"{name} mentions cloud temp_ semantics" + def test_cloud_descs_include_rm_and_rmdir(self): + descs = _build_tool_descriptions(FilesystemMode.CLOUD) + assert "rm" in descs and "rmdir" in descs + assert "Deletes a single file" in descs["rm"] + assert "Deletes an empty directory" in descs["rmdir"] + assert "rmdir" in descs["rmdir"] and "POSIX" in descs["rmdir"] + + def test_desktop_descs_warn_about_irreversibility(self): + descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER) + assert "NOT reversible" in descs["rm"] + assert "NOT reversible" in descs["rmdir"] + def test_sandbox_addendum_appended_when_available(self): prompt = _build_filesystem_system_prompt( FilesystemMode.CLOUD, sandbox_available=True diff 
--git a/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py b/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py new file mode 100644 index 000000000..feca23d27 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_kb_persistence_revisions.py @@ -0,0 +1,309 @@ +"""Unit tests for the kb_persistence snapshot helpers. + +The full ``commit_staged_filesystem_state`` body exercises a real session +in integration tests; here we verify the building blocks used by the +snapshot/revert pipeline: + +* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids + (regression guard against the N+1 lookup pattern). +* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``. +* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the + shape the snapshot helpers consume. + +These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so +the assertions run in milliseconds and don't require Postgres. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.agents.new_chat.middleware import kb_persistence + +pytestmark = pytest.mark.unit + + +class _FakeResult: + def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None: + self._rows = rows or [] + self._scalar = scalar + + def all(self) -> list[Any]: + return list(self._rows) + + def scalar_one_or_none(self) -> Any: + return self._scalar + + +class _FakeSession: + def __init__(self) -> None: + self.execute = AsyncMock() + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_issues_single_query() -> None: + """The lookup MUST be a single ``IN (...)`` SELECT, not N selects.""" + session = _FakeSession() + session.execute.return_value = _FakeResult( + rows=[ + MagicMock(id=11, tool_call_id="tc-a"), + MagicMock(id=22, tool_call_id="tc-b"), + MagicMock(id=33, tool_call_id="tc-c"), + ] + ) + + mapping = await 
kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=1, + tool_call_ids={"tc-a", "tc-b", "tc-c"}, + ) + + assert mapping == {"tc-a": 11, "tc-b": 22, "tc-c": 33} + assert session.execute.await_count == 1, ( + "Snapshot binding must batch into ONE query; got " + f"{session.execute.await_count} (regression: N+1 lookup pattern)." + ) + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None: + session = _FakeSession() + mapping = await kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=None, + tool_call_ids={"tc-a"}, + ) + assert mapping == {} + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None: + session = _FakeSession() + mapping = await kb_persistence._find_action_ids_batch( + session, # type: ignore[arg-type] + thread_id=42, + tool_call_ids=set(), + ) + assert mapping == {} + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_mark_action_reversible_is_noop_for_null_id() -> None: + session = _FakeSession() + await kb_persistence._mark_action_reversible(session, action_id=None) # type: ignore[arg-type] + assert session.execute.await_count == 0 + + +@pytest.mark.asyncio +async def test_mark_action_reversible_runs_update_for_real_id() -> None: + session = _FakeSession() + await kb_persistence._mark_action_reversible(session, action_id=99) # type: ignore[arg-type] + assert session.execute.await_count == 1 + + +def test_doc_revision_payload_captures_metadata_virtual_path() -> None: + """Snapshot helpers must capture ``metadata_before`` for revert reuse.""" + doc = MagicMock() + doc.content = "body" + doc.title = "notes.md" + doc.folder_id = 7 + doc.document_metadata = {"virtual_path": "/documents/team/notes.md"} + + payload = kb_persistence._doc_revision_payload( + doc, chunks_before=[{"content": "x"}] + ) + + assert 
payload["title_before"] == "notes.md" + assert payload["folder_id_before"] == 7 + assert payload["content_before"] == "body" + assert payload["chunks_before"] == [{"content": "x"}] + assert payload["metadata_before"] == {"virtual_path": "/documents/team/notes.md"} + + +def test_doc_revision_payload_handles_missing_metadata() -> None: + doc = MagicMock() + doc.content = "" + doc.title = "" + doc.folder_id = None + doc.document_metadata = None + payload = kb_persistence._doc_revision_payload(doc) + assert payload["metadata_before"] is None + + +@pytest.mark.asyncio +async def test_load_chunks_for_snapshot_returns_content_only() -> None: + """Snapshot chunks intentionally omit embeddings (regenerated on revert).""" + session = _FakeSession() + session.execute.return_value = _FakeResult( + rows=[ + MagicMock(content="alpha"), + MagicMock(content="beta"), + ] + ) + chunks = await kb_persistence._load_chunks_for_snapshot( + session, + doc_id=42, # type: ignore[arg-type] + ) + assert chunks == [{"content": "alpha"}, {"content": "beta"}] + + +# --------------------------------------------------------------------------- +# Deferred reversibility-flip dispatches. +# +# The snapshot helpers used to dispatch ``action_log_updated`` directly +# from inside the SAVEPOINT block. That meant the SSE side-channel +# could tell the UI a row was reversible while the OUTER transaction +# was still pending — and if the outer commit failed, every SAVEPOINT +# rolled back too, leaving the UI in a state inconsistent with +# durable storage. The deferred-dispatch contract fixes that: +# +# • when a ``deferred_dispatches`` list is provided, the helper +# APPENDS the action_id and does NOT dispatch; +# • the caller (``commit_staged_filesystem_state``) flushes the list +# only AFTER ``await session.commit()`` succeeds; on rollback it +# clears the list so nothing is emitted. 
+# --------------------------------------------------------------------------- + + +class _NestedCtx: + """Async context manager mimicking ``session.begin_nested()``.""" + + async def __aenter__(self) -> _NestedCtx: + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + +@pytest.mark.asyncio +async def test_pre_write_snapshot_defers_dispatch_when_list_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Helpers MUST queue dispatches when ``deferred_dispatches`` is set.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock(return_value=_FakeResult(rows=[])) + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 17 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + deferred: list[int] = [] + doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"}) + doc.title = "x.md" + doc.folder_id = None + doc.content = "body" + + rev_id = await kb_persistence._snapshot_document_pre_write( + session, # type: ignore[arg-type] + doc=doc, + action_id=42, + search_space_id=1, + turn_id="t-1", + deferred_dispatches=deferred, + ) + + assert rev_id == 17 + # Inline dispatch must NOT have fired; the action_id is queued. 
+ assert dispatched == [] + assert deferred == [42] + + +@pytest.mark.asyncio +async def test_pre_write_snapshot_dispatches_inline_when_list_omitted( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Direct callers (no outer transaction) keep the legacy inline dispatch.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock(return_value=_FakeResult(rows=[])) + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 7 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + doc = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"}) + doc.title = "y.md" + doc.folder_id = None + doc.content = "body" + + await kb_persistence._snapshot_document_pre_write( + session, # type: ignore[arg-type] + doc=doc, + action_id=88, + search_space_id=1, + turn_id="t-1", + # No deferred_dispatches arg — fall back to inline dispatch. 
+ ) + + assert dispatched == [88] + + +@pytest.mark.asyncio +async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Folder mkdir snapshots honour the same deferred-dispatch contract.""" + session = MagicMock() + session.begin_nested = MagicMock(return_value=_NestedCtx()) + session.execute = AsyncMock() # _mark_action_reversible calls execute + session.flush = AsyncMock() + + def _add(rev: Any) -> None: + rev.id = 3 + + session.add = MagicMock(side_effect=_add) + + dispatched: list[int] = [] + + async def _fake_dispatch(action_id: int | None) -> None: + if action_id is not None: + dispatched.append(int(action_id)) + + monkeypatch.setattr( + kb_persistence, "_dispatch_reversibility_update", _fake_dispatch + ) + + deferred: list[int] = [] + folder = MagicMock(id=2, name="f", parent_id=None, position="a0") + + await kb_persistence._snapshot_folder_pre_mkdir( + session, # type: ignore[arg-type] + folder=folder, + action_id=55, + search_space_id=1, + turn_id="t-1", + deferred_dispatches=deferred, + ) + + assert dispatched == [] + assert deferred == [55] diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py b/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py new file mode 100644 index 000000000..caaec3114 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_tree.py @@ -0,0 +1,139 @@ +"""Unit tests for ``KnowledgeTreeMiddleware`` rendering. + +The empty-folder marker is critical UX: without it, the LLM cannot +distinguish a leaf folder containing one document from a leaf folder +that has no descendants at all, and ends up firing ``rmdir`` on +non-empty folders. These tests pin the rendering contract so that +contract cannot silently regress. 
+""" + +from __future__ import annotations + +from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware +from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT + + +def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]: + return KnowledgeTreeMiddleware._compute_non_empty_folders(folder_paths, doc_paths) + + +class TestComputeNonEmptyFolders: + def test_folder_with_direct_document_is_non_empty(self): + folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"] + doc_paths = [ + f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml", + ] + non_empty = _compute(folder_paths, doc_paths) + assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" in non_empty + + def test_truly_empty_leaf_folder_is_not_non_empty(self): + folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"] + doc_paths: list[str] = [] + assert _compute(folder_paths, doc_paths) == set() + + def test_documents_propagate_up_to_all_ancestors(self): + folder_paths = [ + f"{DOCUMENTS_ROOT}/A", + f"{DOCUMENTS_ROOT}/A/B", + f"{DOCUMENTS_ROOT}/A/B/C", + ] + doc_paths = [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"] + non_empty = _compute(folder_paths, doc_paths) + assert non_empty == { + f"{DOCUMENTS_ROOT}/A", + f"{DOCUMENTS_ROOT}/A/B", + f"{DOCUMENTS_ROOT}/A/B/C", + } + + def test_chain_with_subfolders_marks_only_leaf_empty(self): + # POSIX-like semantic: a folder is "empty" only if it has no + # immediate children (docs OR sub-folders). The model needs this + # because parallel ``rmdir`` calls all see the same starting state, + # so trying to rmdir a parent before its children is never safe. + folder_paths = [ + f"{DOCUMENTS_ROOT}/X", + f"{DOCUMENTS_ROOT}/X/Y", + f"{DOCUMENTS_ROOT}/X/Y/Z", + ] + non_empty = _compute(folder_paths, []) + # Only ``X/Y/Z`` (the leaf) is empty. ``X`` and ``X/Y`` each have a + # sub-folder child, so they are non-empty and should NOT carry the + # ``(empty)`` marker. 
+ assert non_empty == {f"{DOCUMENTS_ROOT}/X", f"{DOCUMENTS_ROOT}/X/Y"} + + def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self): + # Mirrors a real DB layout where every intermediate folder is + # materialized in the ``folders`` table. + folder_paths = [ + f"{DOCUMENTS_ROOT}/Travel", + f"{DOCUMENTS_ROOT}/Travel/Boarding Pass", + f"{DOCUMENTS_ROOT}/Travel/Notes", + ] + doc_paths = [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"] + non_empty = _compute(folder_paths, doc_paths) + # ``Travel`` is non-empty because it has children, ``Notes`` is non-empty + # because of the doc, but ``Boarding Pass`` (sibling leaf) is empty. + assert f"{DOCUMENTS_ROOT}/Travel" in non_empty + assert f"{DOCUMENTS_ROOT}/Travel/Notes" in non_empty + assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" not in non_empty + + +class TestFormatTreeRendering: + """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't.""" + + def _render( + self, + folder_paths: list[str], + doc_specs: list[dict], + ) -> str: + from app.agents.new_chat.path_resolver import PathIndex + + index = PathIndex( + folder_paths={i + 1: p for i, p in enumerate(folder_paths)}, + ) + + class _Row: + def __init__(self, **kw): + self.__dict__.update(kw) + + docs = [_Row(**spec) for spec in doc_specs] + + mw = KnowledgeTreeMiddleware( + search_space_id=1, + filesystem_mode=None, # type: ignore[arg-type] + ) + return mw._format_tree(index, docs) + + def test_renders_empty_marker_only_for_truly_empty_folders(self): + # Reproduces the failure scenario from the bug report: + # ``Boarding Pass`` is empty (its only doc was just deleted), while + # ``Tax Returns`` still has ``federal.pdf``. All intermediate + # folders are present in the index, mirroring the real DB layout. 
+ folder_paths = [ + "/documents/File Upload", + "/documents/File Upload/2026-04-08", + "/documents/File Upload/2026-04-08/Travel", + "/documents/File Upload/2026-04-08/Travel/Boarding Pass", + "/documents/File Upload/2026-04-15", + "/documents/File Upload/2026-04-15/Finance", + "/documents/File Upload/2026-04-15/Finance/Tax Returns", + ] + tax_returns_folder_id = ( + folder_paths.index("/documents/File Upload/2026-04-15/Finance/Tax Returns") + + 1 + ) + rendered = self._render( + folder_paths=folder_paths, + doc_specs=[ + { + "id": 100, + "title": "federal.pdf", + "folder_id": tax_returns_folder_id, + }, + ], + ) + assert "Boarding Pass/ (empty)" in rendered + assert "Tax Returns/ (empty)" not in rendered + # Intermediate ancestors of the doc must NOT be marked empty. + assert "Finance/ (empty)" not in rendered + assert "2026-04-15/ (empty)" not in rendered diff --git a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py index 7dfc68402..6e81ecf8e 100644 --- a/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py +++ b/surfsense_backend/tests/unit/middleware/test_local_folder_backend.py @@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path): assert write.error is not None assert "parent directory" in write.error assert not (tmp_path / "tempoo").exists() + + +def test_local_backend_delete_file_success(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "delete-me.md").write_text("bye") + + res = backend.delete_file("/delete-me.md") + assert res.error is None + assert res.path == "/delete-me.md" + assert not (tmp_path / "delete-me.md").exists() + + +def test_local_backend_delete_file_rejects_directory(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "subdir").mkdir() + + res = backend.delete_file("/subdir") + assert res.error is not None + assert "directory" in res.error 
+ assert (tmp_path / "subdir").exists() + + +def test_local_backend_delete_file_missing_returns_error(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + + res = backend.delete_file("/nope.md") + assert res.error is not None + assert "not found" in res.error + + +def test_local_backend_rmdir_success(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "empty").mkdir() + + res = backend.rmdir("/empty") + assert res.error is None + assert res.path == "/empty" + assert not (tmp_path / "empty").exists() + + +def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "withkid").mkdir() + (tmp_path / "withkid" / "child.md").write_text("x") + + res = backend.rmdir("/withkid") + assert res.error is not None + assert "not empty" in res.error + assert (tmp_path / "withkid" / "child.md").exists() + + +def test_local_backend_rmdir_rejects_file(tmp_path: Path): + backend = LocalFolderBackend(str(tmp_path)) + (tmp_path / "f.md").write_text("x") + + res = backend.rmdir("/f.md") + assert res.error is not None + assert "not a directory" in res.error + + +def test_local_backend_rmdir_rejects_root(tmp_path: Path): + """``rmdir /`` MUST fail. 
The exact error wording comes from + ``_resolve_virtual`` (root resolves to outside the sandbox); what + matters is that the call returns an error and does NOT delete the + sandbox root on disk.""" + backend = LocalFolderBackend(str(tmp_path)) + + res = backend.rmdir("/") + assert res.error is not None + assert "Invalid path" in res.error or "root" in res.error + assert tmp_path.exists() diff --git a/surfsense_backend/tests/unit/routes/__init__.py b/surfsense_backend/tests/unit/routes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py b/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py new file mode 100644 index 000000000..709014d55 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_regenerate_from_message_id.py @@ -0,0 +1,143 @@ +"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``. + +The regenerate route's edit-from-position path introduces: +* ``_find_pre_turn_checkpoint_id`` — walks LangGraph checkpoint tuples + oldest-first and returns the last checkpoint seen before the first one + whose ``metadata["turn_id"]`` matches the edited turn. That checkpoint + is the rewind target (state immediately before the edited turn started). +* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions`` + with a validator that prevents callers from requesting a revert pass + without specifying which turn to roll back. + +These are pure-Python helpers that don't need a live DB, so we exercise +them with a small ``CheckpointTuple``-shaped namespace and direct +schema instantiation. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id +from app.schemas.new_chat import RegenerateRequest + + +def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace: + """Build a fake ``CheckpointTuple`` with the metadata shape we read.""" + return SimpleNamespace( + config={"configurable": {"checkpoint_id": checkpoint_id}}, + metadata={"turn_id": turn_id} if turn_id is not None else {}, + ) + + +class TestFindPreTurnCheckpointId: + def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None: + # Newest-first: T2 is the most-recent turn. The latest non-T2 + # checkpoint (cp2) is the rewind target — state immediately + # before T2 began. + tuples = [ + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None: + # Regression for the bug where walking newest-first returned the + # FIRST cp with ``turn_id != target`` — which is one of the + # later-turn checkpoints, NOT the pre-turn boundary. Editing + # T2 must rewind to the latest T1 checkpoint (cp2), not to the + # latest T3 checkpoint (cp6). + tuples = [ + _cp("cp6", "T3"), + _cp("cp5", "T3"), + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_returns_none_when_editing_first_turn(self) -> None: + # No pre-turn boundary exists; caller is expected to fall back + # to the oldest checkpoint or special-case "first turn of the + # thread". 
+ tuples = [ + _cp("cp4", "T2"), + _cp("cp3", "T2"), + _cp("cp2", "T1"), + _cp("cp1", "T1"), + ] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T1") is None + + def test_returns_none_when_only_edited_turn_present(self) -> None: + tuples = [_cp("cp2", "T2"), _cp("cp1", "T2")] + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") is None + + def test_returns_none_for_empty_history(self) -> None: + assert _find_pre_turn_checkpoint_id([], turn_id="T1") is None + + def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None: + # Checkpoints written before migration 136 have no + # ``metadata.turn_id``. They should be eligible rewind targets + # — they came before the + # edited turn began. + tuples = [ + _cp("cp3", "T2"), + SimpleNamespace( + config={"configurable": {"checkpoint_id": "cp2"}}, + metadata=None, + ), + _cp("cp1", "T1"), + ] + # Walking oldest-first: cp1(T1) tracked, cp2(legacy/None) tracked, + # then cp3(T2) crosses the boundary -> return cp2. + assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2" + + def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None: + # If a checkpoint tuple's ``config["configurable"]`` is missing + # the ``checkpoint_id`` key (corrupt / partial), we keep the + # last known good target instead of crashing. + broken = SimpleNamespace( + config={"configurable": {}}, metadata={"turn_id": "T1"} + ) + tuples = [ + _cp("cp3", "T2"), + broken, + _cp("cp1", "T1"), + ] + # cp1(T1) tracked, broken skipped, cp3(T2) -> return cp1. 
+ assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp1" + + +class TestRegenerateRequestValidation: + def test_revert_actions_requires_from_message_id(self) -> None: + with pytest.raises(Exception) as exc: + RegenerateRequest( + search_space_id=1, + user_query="hi", + revert_actions=True, + ) + msg = str(exc.value).lower() + assert "from_message_id" in msg + + def test_from_message_id_without_revert_is_allowed(self) -> None: + req = RegenerateRequest( + search_space_id=1, + user_query="hi", + from_message_id=42, + ) + assert req.from_message_id == 42 + assert req.revert_actions is False + + def test_revert_actions_with_from_message_id_passes(self) -> None: + req = RegenerateRequest( + search_space_id=1, + user_query="hi", + from_message_id=42, + revert_actions=True, + ) + assert req.revert_actions is True diff --git a/surfsense_backend/tests/unit/routes/test_revert_turn_route.py b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py new file mode 100644 index 000000000..1e1cbffb3 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_revert_turn_route.py @@ -0,0 +1,530 @@ +"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + +The per-turn batch revert route walks rows in reverse ``created_at`` +order, reverts each independently, and returns a per-action result +list. Partial success is normal — the response status +is ``"partial"`` whenever any row could not be reverted, but we never +collapse the whole batch into a 4xx. + +These tests stub ``load_thread`` / ``revert_action`` and feed a fake +session, so they exercise the route's dispatch logic without a real DB. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.routes import agent_revert_route +from app.services.revert_service import RevertOutcome + + +@dataclass +class _FakeAction: + id: int + tool_name: str + user_id: str | None = "u1" + reverse_of: int | None = None + error: dict | None = None + + +@dataclass +class _FakeUser: + id: str = "u1" + + +@dataclass +class _ScalarResult: + rows: list[Any] + + def first(self) -> Any: + return self.rows[0] if self.rows else None + + def all(self) -> list[Any]: + return list(self.rows) + + +@dataclass +class _Result: + rows: list[Any] = field(default_factory=list) + + def scalars(self) -> _ScalarResult: + return _ScalarResult(self.rows) + + def all(self) -> list[Any]: + # ``_was_already_reverted_batch`` calls ``.all()`` directly on + # the row-tuple result (no ``.scalars()`` indirection). The + # rows queued for that helper are list[(revert_id, original_id)]. + return list(self.rows) + + +class _FakeNestedCtx: + """Async context manager that mimics ``session.begin_nested()``. + + The route raises a sentinel exception inside this block to roll back + bad rows. We just pass the exception through. + """ + + async def __aenter__(self) -> _FakeNestedCtx: + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + # Returning False (or None) propagates the exception; the route + # catches its own sentinel above this layer. + return False + + +class _FakeSession: + """Minimal AsyncSession stand-in for the revert-turn route. + + Holds a queue of result objects; each ``execute(...)`` pops the next + one. The route calls ``execute`` exactly once per query so this maps + cleanly onto the assertion order of the test. 
+ """ + + def __init__(self) -> None: + self._results: list[_Result] = [] + self.committed = False + self.rolled_back = False + # Count execute() calls to assert "no N+1 reverts". + self.execute_call_count = 0 + + def queue(self, *results: _Result) -> None: + self._results.extend(results) + + async def execute(self, _stmt: Any) -> _Result: + self.execute_call_count += 1 + if not self._results: + return _Result(rows=[]) + return self._results.pop(0) + + def begin_nested(self) -> _FakeNestedCtx: + return _FakeNestedCtx() + + async def commit(self) -> None: + self.committed = True + + async def rollback(self) -> None: + self.rolled_back = True + + +def _enabled_flags() -> AgentFeatureFlags: + return AgentFeatureFlags( + disable_new_agent_stack=False, + enable_action_log=True, + enable_revert_route=True, + ) + + +@pytest.fixture +def patch_get_flags(): + def _patch(flags: AgentFeatureFlags): + return patch( + "app.routes.agent_revert_route.get_flags", + return_value=flags, + ) + + return _patch + + +class TestFlagGuard: + @pytest.mark.asyncio + async def test_returns_503_when_revert_route_disabled( + self, patch_get_flags + ) -> None: + flags = AgentFeatureFlags( + disable_new_agent_stack=False, + enable_action_log=True, + enable_revert_route=False, + ) + session = _FakeSession() + with patch_get_flags(flags), pytest.raises(Exception) as exc: + await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="42:1700000000000", + session=session, + user=_FakeUser(), + ) + assert getattr(exc.value, "status_code", None) == 503 + + +class TestRevertTurnDispatch: + @pytest.mark.asyncio + async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None: + session = _FakeSession() + session.queue(_Result(rows=[])) # rows query returns nothing + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + ): + response = await agent_revert_route.revert_agent_turn( + 
thread_id=1, + chat_turn_id="ct-empty", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.total == 0 + assert response.results == [] + assert session.committed is True + + @pytest.mark.asyncio + async def test_walks_rows_in_reverse_and_reverts_each( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=10, tool_name="rm"), + _FakeAction(id=9, tool_name="write_file"), + _FakeAction(id=8, tool_name="mkdir"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched ``_was_already_reverted_batch`` probe replaces + # the previous N per-row SELECTs. + session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + return RevertOutcome( + status="ok", + message=f"reverted-{action.id}", + new_action_id=100 + action.id, + ) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-3", + session=session, + user=_FakeUser(), + ) + + assert response.status == "ok" + assert response.total == 3 + assert response.reverted == 3 + assert [r.action_id for r in response.results] == [10, 9, 8] + assert all(r.status == "reverted" for r in response.results) + assert response.results[0].new_action_id == 110 + # Only TWO ``execute`` calls regardless of the row count: one + # for the rows query, one for the batched + # ``_was_already_reverted_batch`` probe. Regression guard + # against re-introducing the per-row N+1 lookup. + assert session.execute_call_count == 2, ( + "revert-turn loop must batch idempotency probes; got " + f"{session.execute_call_count} execute() calls (expected 2)." 
+ ) + + @pytest.mark.asyncio + async def test_already_reverted_rows_are_marked_idempotent( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=5, tool_name="edit_file")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch probe returns ``[(revert_id, original_id)]``. + session.queue(_Result(rows=[(42, 5)])) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-i", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.already_reverted == 1 + assert response.results[0].status == "already_reverted" + assert response.results[0].new_action_id == 42 + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_revert_action_skips_existing_revert_rows( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)] + session = _FakeSession() + session.queue(_Result(rows=rows)) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-rev", + session=session, + user=_FakeUser(), + ) + assert response.status == "ok" + assert response.results[0].status == "skipped" + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_partial_success_when_some_rows_not_reversible( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=2, tool_name="send_email"), + _FakeAction(id=1, tool_name="edit_file"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched idempotency probe. 
+ session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.tool_name == "send_email": + return RevertOutcome( + status="not_reversible", + message="connector revert not yet implemented", + ) + return RevertOutcome(status="ok", message="ok", new_action_id=500) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-mix", + session=session, + user=_FakeUser(), + ) + assert response.status == "partial" + assert response.reverted == 1 + assert response.not_reversible == 1 + statuses = sorted(r.status for r in response.results) + assert statuses == ["not_reversible", "reverted"] + + @pytest.mark.asyncio + async def test_unexpected_exception_marks_row_failed_not_batch( + self, patch_get_flags + ) -> None: + rows = [ + _FakeAction(id=20, tool_name="edit_file"), + _FakeAction(id=21, tool_name="edit_file"), + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched idempotency probe. 
+ session.queue(_Result(rows=[])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.id == 20: + raise RuntimeError("disk on fire") + return RevertOutcome(status="ok", message="ok", new_action_id=999) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert) + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-fail", + session=session, + user=_FakeUser(), + ) + assert response.status == "partial" + assert response.failed == 1 + assert response.reverted == 1 + bad = next(r for r in response.results if r.action_id == 20) + assert bad.status == "failed" + assert "disk on fire" in (bad.error or "") + good = next(r for r in response.results if r.action_id == 21) + assert good.status == "reverted" + + @pytest.mark.asyncio + async def test_permission_denied_when_other_user_owns_action( + self, patch_get_flags + ) -> None: + rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch idempotency probe (no prior reverts). 
+ session.queue(_Result(rows=[])) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert, + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-perm", + session=session, + user=_FakeUser(id="not-owner"), + ) + assert response.status == "partial" + assert response.results[0].status == "permission_denied" + # ``permission_denied`` has its own dedicated counter so the + # response invariant ``total == sum(counters)`` always holds + # without overloading ``not_reversible`` (which historically + # absorbed this case and confused frontend toasts). + assert response.permission_denied == 1 + assert response.not_reversible == 0 + revert.assert_not_called() + + @pytest.mark.asyncio + async def test_counter_invariant_holds_across_mixed_outcomes( + self, patch_get_flags + ) -> None: + """Every row is accounted for in EXACTLY ONE counter. + + Mixes one of every supported outcome (reverted, already_reverted, + not_reversible, permission_denied, failed, skipped) and asserts + that the sum of counters equals ``response.total``. + """ + rows = [ + _FakeAction(id=10, tool_name="edit_file"), # ok + _FakeAction(id=9, tool_name="edit_file"), # already_reverted + _FakeAction(id=8, tool_name="send_email"), # not_reversible + _FakeAction(id=7, tool_name="rm", user_id="other"), # permission_denied + _FakeAction(id=6, tool_name="edit_file"), # failed + _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99), # skipped + ] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Single batched probe; only id=9 has a prior revert. + # Schema: list[(revert_id, original_id)]. 
+ session.queue(_Result(rows=[(42, 9)])) + + async def _fake_revert(_session, *, action, requester_user_id): + if action.id == 10: + return RevertOutcome(status="ok", message="ok", new_action_id=500) + if action.id == 8: + return RevertOutcome( + status="not_reversible", + message="connector revert not yet implemented", + ) + if action.id == 6: + raise RuntimeError("boom") + raise AssertionError(f"unexpected revert call for {action.id}") + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, + "revert_action", + AsyncMock(side_effect=_fake_revert), + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-mixed-all", + session=session, + user=_FakeUser(), # only id=7 has a different user_id + ) + + assert response.total == len(rows) == 6 + bucket_sum = ( + response.reverted + + response.already_reverted + + response.not_reversible + + response.permission_denied + + response.failed + + response.skipped + ) + assert bucket_sum == response.total, ( + "Counter invariant broken: total " + f"({response.total}) != sum of counters ({bucket_sum}). " + f"Counters: reverted={response.reverted}, " + f"already_reverted={response.already_reverted}, " + f"not_reversible={response.not_reversible}, " + f"permission_denied={response.permission_denied}, " + f"failed={response.failed}, skipped={response.skipped}" + ) + assert response.reverted == 1 + assert response.already_reverted == 1 + assert response.not_reversible == 1 + assert response.permission_denied == 1 + assert response.failed == 1 + assert response.skipped == 1 + + @pytest.mark.asyncio + async def test_integrity_error_translates_to_already_reverted( + self, patch_get_flags + ) -> None: + """The partial unique index on ``reverse_of`` raises + ``IntegrityError`` when a concurrent revert wins the race against + the pre-flight ``_was_already_reverted`` SELECT. 
The route MUST + recover by re-querying for the winning revert id and returning + ``status="already_reverted"`` (not ``"failed"``) so racing + clients see consistent idempotent semantics. + """ + from sqlalchemy.exc import IntegrityError + + rows = [_FakeAction(id=33, tool_name="edit_file")] + session = _FakeSession() + session.queue(_Result(rows=rows)) + # Batch pre-flight probe: nothing yet (we'll race). + session.queue(_Result(rows=[])) + # Post-IntegrityError fallback uses the SCALAR + # ``_was_already_reverted`` (single-id lookup) so it pulls + # ``[777]`` via ``.scalars().first()``. + session.queue(_Result(rows=[777])) + + async def _racing_revert(_session, *, action, requester_user_id): + raise IntegrityError("INSERT", {}, Exception("dup reverse_of")) + + with ( + patch_get_flags(_enabled_flags()), + patch.object( + agent_revert_route, "load_thread", AsyncMock(return_value=object()) + ), + patch.object( + agent_revert_route, + "revert_action", + AsyncMock(side_effect=_racing_revert), + ), + ): + response = await agent_revert_route.revert_agent_turn( + thread_id=1, + chat_turn_id="ct-race", + session=session, + user=_FakeUser(), + ) + + assert response.failed == 0, ( + "IntegrityError must NOT surface as a failed row; the unique " + "index is the durable expression of idempotency." + ) + assert response.already_reverted == 1 + assert response.results[0].status == "already_reverted" + assert response.results[0].new_action_id == 777 diff --git a/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py b/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py new file mode 100644 index 000000000..95314741a --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_revert_filesystem_tools.py @@ -0,0 +1,370 @@ +"""Unit tests for the filesystem-tool branches of ``revert_service``. 
+ +Covers: + +* Exact-name dispatch — ``rmdir`` does NOT mis-route to the document + branch (``"rmdir".startswith("rm")`` would mis-route under the legacy + prefix-based dispatch). +* ``rm`` revert re-INSERTs a fresh document from the snapshot, including + re-creating chunks. Falls back to ``(folder_id_before, title_before)`` + when ``metadata_before["virtual_path"]`` is missing. +* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the + document. +* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot. +* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable`` + when the folder gained children. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest + +from app.services import revert_service + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + revert_service, + "embed_texts", + lambda texts: [np.zeros(8, dtype=np.float32) for _ in texts], + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _FakeResult: + def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None: + self._rows = rows or [] + self._scalar = scalar + + def all(self) -> list[Any]: + return list(self._rows) + + def scalar_one_or_none(self) -> Any: + return self._scalar + + def scalars(self) -> Any: + return _FakeScalarsProxy(self._rows) + + +class _FakeScalarsProxy: + def __init__(self, rows: list[Any]) -> None: + self._rows = rows + + def first(self) -> Any: + return self._rows[0] if self._rows else None + + +class _FakeSession: + def __init__(self) -> None: + self.execute = AsyncMock() + self.added: list[Any] = [] + self.deleted: list[Any] = [] + self.flush = AsyncMock() + # session.get(Model, pk) lookup + 
self.get = AsyncMock(return_value=None) + + async def _flush_assigning_ids() -> None: + for obj in self.added: + if getattr(obj, "id", None) is None: + obj.id = 999 + + self.flush.side_effect = _flush_assigning_ids + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: list[Any]) -> None: + self.added.extend(objs) + + +def _action(*, tool_name: str, action_id: int = 7): + return MagicMock( + id=action_id, + tool_name=tool_name, + thread_id=1, + search_space_id=2, + user_id="user-1", + reverse_descriptor=None, + ) + + +def _doc_revision( + *, + document_id: int | None = None, + content_before: str | None = "old content", + title_before: str | None = "notes.md", + folder_id_before: int | None = 5, + chunks_before: list[dict[str, str]] | None = None, + metadata_before: dict[str, str] | None = None, +): + revision = MagicMock() + revision.id = 100 + revision.document_id = document_id + revision.search_space_id = 2 + revision.content_before = content_before + revision.title_before = title_before + revision.folder_id_before = folder_id_before + revision.chunks_before = chunks_before or [] + revision.metadata_before = metadata_before + return revision + + +def _folder_revision( + *, + folder_id: int | None = None, + name_before: str | None = "team", + parent_id_before: int | None = None, + position_before: str | None = "a0", +): + revision = MagicMock() + revision.id = 200 + revision.folder_id = folder_id + revision.search_space_id = 2 + revision.name_before = name_before + revision.parent_id_before = parent_id_before + revision.position_before = position_before + return revision + + +# --------------------------------------------------------------------------- +# Exact-name dispatch regression guards +# --------------------------------------------------------------------------- + + +class TestExactDispatch: + """Regression: ``rmdir`` MUST NOT route to the document branch.""" + + @pytest.mark.asyncio + async def 
test_rmdir_does_not_misroute_to_document(self) -> None: + # If dispatch used `startswith("rm")` we'd hit the document branch + # here. With exact-name lookup `rmdir` lands in `_FOLDER_TOOLS`. + session = _FakeSession() + action = _action(tool_name="rmdir") + # No folder revisions exist for this action. + session.execute.return_value = _FakeResult(rows=[]) + outcome = await revert_service.revert_action( + session, # type: ignore[arg-type] + action=action, + requester_user_id="user-1", + ) + assert outcome.status == "not_reversible" + assert "folder_revisions" in outcome.message + + def test_dispatch_sets_split_doc_and_folder(self) -> None: + # Static guards on the dispatch tables themselves so a future + # refactor doesn't accidentally reintroduce the prefix bug. + assert "rm" in revert_service._DOC_TOOLS + assert "rmdir" in revert_service._FOLDER_TOOLS + assert "rmdir" not in revert_service._DOC_TOOLS + assert "rm" not in revert_service._FOLDER_TOOLS + # ``move_file`` lives only in document tools (it's a doc rename). + assert "move_file" in revert_service._DOC_TOOLS + assert "move_file" not in revert_service._FOLDER_TOOLS + + +# --------------------------------------------------------------------------- +# rm revert (re-INSERT) +# --------------------------------------------------------------------------- + + +class TestRmRevert: + @pytest.mark.asyncio + async def test_re_inserts_document_with_chunks(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=None, # row was hard-deleted + content_before="hello world", + title_before="x.md", + folder_id_before=None, + chunks_before=[{"content": "alpha"}, {"content": "beta"}], + metadata_before={"virtual_path": "/documents/x.md"}, + ) + # No collision check hit and the resulting query returns nothing. 
+ session.execute.return_value = _FakeResult(scalar=None) + + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + + assert outcome.status == "ok" + # New Document + 2 chunks must have been added. + from app.db import Chunk, Document + + added_docs = [obj for obj in session.added if isinstance(obj, Document)] + added_chunks = [obj for obj in session.added if isinstance(obj, Chunk)] + assert len(added_docs) == 1 + assert added_docs[0].title == "x.md" + assert len(added_chunks) == 2 + # Snapshot was repointed at the new doc id so a follow-up revert works. + assert revision.document_id == added_docs[0].id + + @pytest.mark.asyncio + async def test_falls_back_to_folder_id_and_title_for_virtual_path( + self, + ) -> None: + session = _FakeSession() + # Snapshot with NO metadata_before — the fallback path must kick in. + revision = _doc_revision( + document_id=None, + content_before="hello", + title_before="cap.md", + folder_id_before=42, + chunks_before=[], + metadata_before=None, + ) + # session.get(Folder, 42) returns a folder with a name. + folder = MagicMock() + folder.name = "team" + folder.parent_id = None + # First .get is for the folder lookup in the path-derivation. 
+ session.get = AsyncMock(return_value=folder) + session.execute.return_value = _FakeResult(scalar=None) + + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + + @pytest.mark.asyncio + async def test_falls_back_to_root_path_when_no_folder( + self, + ) -> None: + """metadata_before is None and folder_id_before is None still + resolves: title fallback yields ``/documents/`` so revert + proceeds at the root of the documents tree.""" + session = _FakeSession() + revision = _doc_revision( + document_id=None, + content_before="hello", + title_before="x.md", + folder_id_before=None, + metadata_before=None, + ) + # No collision in the documents tree at /documents/x.md. + session.execute.return_value = _FakeResult(scalar=None) + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + + @pytest.mark.asyncio + async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=None, + content_before="hi", + title_before="x.md", + folder_id_before=None, + metadata_before={"virtual_path": "/documents/x.md"}, + ) + # SELECT for unique_identifier_hash collision hits an existing row. 
+ session.execute.return_value = _FakeResult(scalar=42) + outcome = await revert_service._reinsert_document_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "tool_unavailable" + assert "collide" in outcome.message + + +# --------------------------------------------------------------------------- +# write_file create revert (DELETE) +# --------------------------------------------------------------------------- + + +class TestWriteFileCreateRevert: + @pytest.mark.asyncio + async def test_deletes_created_doc(self) -> None: + session = _FakeSession() + revision = _doc_revision( + document_id=99, + content_before=None, # marker for "created in this action" + title_before=None, + ) + outcome = await revert_service._delete_created_document( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + # Exactly one DELETE was issued. + assert session.execute.await_count == 1 + + +# --------------------------------------------------------------------------- +# rmdir revert (re-INSERT folder) +# --------------------------------------------------------------------------- + + +class TestRmdirRevert: + @pytest.mark.asyncio + async def test_re_inserts_folder_from_snapshot(self) -> None: + session = _FakeSession() + revision = _folder_revision( + folder_id=None, + name_before="team", + parent_id_before=None, + position_before="a0", + ) + outcome = await revert_service._reinsert_folder_from_revision( + session, # type: ignore[arg-type] + revision=revision, + ) + from app.db import Folder + + assert outcome.status == "ok" + added_folders = [obj for obj in session.added if isinstance(obj, Folder)] + assert len(added_folders) == 1 + assert added_folders[0].name == "team" + assert revision.folder_id == added_folders[0].id + + +# --------------------------------------------------------------------------- +# mkdir revert (DELETE folder) +# 
--------------------------------------------------------------------------- + + +class TestMkdirRevert: + @pytest.mark.asyncio + async def test_deletes_empty_folder(self) -> None: + session = _FakeSession() + revision = _folder_revision(folder_id=42) + # Both the doc-existence check and the child-folder check return None. + session.execute.side_effect = [ + _FakeResult(scalar=None), # docs + _FakeResult(scalar=None), # children + _FakeResult(scalar=None), # delete (no return value) + ] + outcome = await revert_service._delete_created_folder( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "ok" + # 3 executes: docs check, children check, delete. + assert session.execute.await_count == 3 + + @pytest.mark.asyncio + async def test_reports_tool_unavailable_when_folder_has_children(self) -> None: + session = _FakeSession() + revision = _folder_revision(folder_id=42) + # First check (docs) returns "row found". + session.execute.return_value = _FakeResult(scalar=1) + outcome = await revert_service._delete_created_folder( + session, # type: ignore[arg-type] + revision=revision, + ) + assert outcome.status == "tool_unavailable" + assert "no longer empty" in outcome.message diff --git a/surfsense_backend/tests/unit/tasks/__init__.py b/surfsense_backend/tests/unit/tasks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/__init__.py b/surfsense_backend/tests/unit/tasks/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py new file mode 100644 index 000000000..7f32bf456 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py @@ -0,0 +1,185 @@ +"""Unit tests for ``stream_new_chat._extract_chunk_parts``. 
+ +Earlier versions only handled ``isinstance(chunk.content, str)`` and +silently dropped every other shape (Anthropic typed-block lists, +Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from +a few providers). These regression tests pin those four shapes plus the +defensive cases (``None`` chunk, mixed types, missing fields). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.tasks.chat.stream_new_chat import _extract_chunk_parts + + +@dataclass +class _FakeChunk: + """Minimal stand-in for ``AIMessageChunk`` used in unit tests.""" + + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +class TestStringContent: + def test_plain_string_content_extracts_as_text(self) -> None: + chunk = _FakeChunk(content="hello world") + out = _extract_chunk_parts(chunk) + assert out["text"] == "hello world" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] + + def test_empty_string_content_yields_empty_text(self) -> None: + chunk = _FakeChunk(content="") + out = _extract_chunk_parts(chunk) + assert out["text"] == "" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] + + +class TestListContent: + def test_list_of_text_blocks_concatenates(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text", "text": "Hello "}, + {"type": "text", "text": "world"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "Hello world" + assert out["reasoning"] == "" + + def test_mixed_text_and_reasoning_blocks(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "reasoning", "reasoning": "Let me think... "}, + {"type": "reasoning", "text": "still thinking."}, + {"type": "text", "text": "The answer is 42."}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "The answer is 42." 
+ assert out["reasoning"] == "Let me think... still thinking." + + def test_tool_call_chunks_in_content_list_extracted(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text", "text": "Calling tool..."}, + { + "type": "tool_call_chunk", + "id": "call_123", + "name": "make_widget", + "args": '{"color":"red"}', + }, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "Calling tool..." + assert out["reasoning"] == "" + assert len(out["tool_call_chunks"]) == 1 + assert out["tool_call_chunks"][0]["id"] == "call_123" + assert out["tool_call_chunks"][0]["name"] == "make_widget" + + def test_tool_use_blocks_also_extracted(self) -> None: + """Some providers (Anthropic) emit ``type='tool_use'`` instead.""" + chunk = _FakeChunk( + content=[ + { + "type": "tool_use", + "id": "call_xyz", + "name": "search", + }, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["tool_call_chunks"] == [ + {"type": "tool_use", "id": "call_xyz", "name": "search"} + ] + + def test_unknown_block_types_are_ignored(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "image_url", "url": "https://example.com/x.png"}, + {"type": "text", "text": "ok"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "ok" + + def test_blocks_without_text_field_are_ignored(self) -> None: + chunk = _FakeChunk( + content=[ + {"type": "text"}, # no text/content key + {"type": "text", "text": "kept"}, + ] + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "kept" + + +class TestAdditionalKwargsReasoning: + def test_reasoning_content_in_additional_kwargs(self) -> None: + """Some providers stash reasoning in ``additional_kwargs.reasoning_content``.""" + chunk = _FakeChunk( + content="visible answer", + additional_kwargs={"reasoning_content": "internal monologue"}, + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "visible answer" + assert out["reasoning"] == "internal monologue" + + def test_reasoning_appended_to_typed_block_reasoning(self) -> 
None: + chunk = _FakeChunk( + content=[{"type": "reasoning", "text": "from blocks. "}], + additional_kwargs={"reasoning_content": "from kwargs."}, + ) + out = _extract_chunk_parts(chunk) + assert out["reasoning"] == "from blocks. from kwargs." + + +class TestToolCallChunksAttribute: + def test_tool_call_chunks_attribute_extracted_alongside_string_content( + self, + ) -> None: + chunk = _FakeChunk( + content="streaming text", + tool_call_chunks=[ + {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"} + ], + ) + out = _extract_chunk_parts(chunk) + assert out["text"] == "streaming text" + assert len(out["tool_call_chunks"]) == 1 + assert out["tool_call_chunks"][0]["id"] == "tc-9" + + def test_attribute_and_typed_block_chunks_both_collected(self) -> None: + chunk = _FakeChunk( + content=[ + { + "type": "tool_call_chunk", + "id": "from-block", + "name": "x", + } + ], + tool_call_chunks=[{"id": "from-attr", "name": "y"}], + ) + out = _extract_chunk_parts(chunk) + ids = [tcc.get("id") for tcc in out["tool_call_chunks"]] + assert ids == ["from-block", "from-attr"] + + +class TestDefensive: + @pytest.mark.parametrize( + "chunk_value", + [None, _FakeChunk(content=None), _FakeChunk(content=42)], + ) + def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None: + out = _extract_chunk_parts(chunk_value) + assert out["text"] == "" + assert out["reasoning"] == "" + assert out["tool_call_chunks"] == [] diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 80ee9e9cd..f21a0a30b 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -14,6 +14,13 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; import { disabledToolsAtom } from 
"@/atoms/agent-tools/agent-tools.atoms"; +import { + agentActionsByChatTurnIdAtom, + markAgentActionRevertedAtom, + resetAgentActionMapAtom, + updateAgentActionReversibleAtom, + upsertAgentActionAtom, +} from "@/atoms/chat/agent-actions.atom"; import { clearTargetCommentIdAtom, currentThreadAtom, @@ -37,6 +44,11 @@ import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { membersAtom } from "@/atoms/members/members-query.atoms"; import { removeChatTabAtom, updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; +import { + EditMessageDialog, + type EditMessageDialogChoice, +} from "@/components/assistant-ui/edit-message-dialog"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { Thread } from "@/components/assistant-ui/thread"; import { @@ -60,15 +72,20 @@ import { setActivePodcastTaskId, } from "@/lib/chat/podcast-state"; import { + addStepSeparator, addToolCall, + appendReasoning, appendText, buildContentForPersistence, buildContentForUI, type ContentPartsState, + endReasoning, FrameBatchedUpdater, + findToolCallIdByLcId, readSSEStream, type SSEEvent, type ThinkingStepData, + type ToolUIGate, updateThinkingSteps, updateToolCall, } from "@/lib/chat/streaming-state"; @@ -257,44 +274,38 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { } /** - * Tools that should render custom UI in the chat. + * Every tool call renders a card. The legacy + * ``BASE_TOOLS_WITH_UI`` allowlist used to drop unknown tool calls on the + * floor; we now route everything through ``ToolFallback``. Persisted + * payload size stays bounded because the backend's + * ``format_thinking_step`` summarisation and the + * ``result_length``-only default for unknown tools (see + * ``stream_new_chat.py``) keep the JSON from ballooning. 
*/ -const BASE_TOOLS_WITH_UI = new Set([ - "web_search", - "generate_podcast", - "generate_report", - "generate_resume", - "generate_video_presentation", - "display_image", - "generate_image", - "delete_notion_page", - "create_notion_page", - "update_notion_page", - "create_linear_issue", - "update_linear_issue", - "delete_linear_issue", - "create_google_drive_file", - "delete_google_drive_file", - "create_onedrive_file", - "delete_onedrive_file", - "create_dropbox_file", - "delete_dropbox_file", - "create_calendar_event", - "update_calendar_event", - "delete_calendar_event", - "create_gmail_draft", - "update_gmail_draft", - "send_gmail_email", - "trash_gmail_email", - "create_jira_issue", - "update_jira_issue", - "delete_jira_issue", - "create_confluence_page", - "update_confluence_page", - "delete_confluence_page", - "execute", - // "write_todos", // Disabled for now -]); +const TOOLS_WITH_UI_ALL: ToolUIGate = "all"; + +/** + * When a streamed message is persisted, the backend returns the durable + * ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it + * into the assistant-ui message metadata so the per-turn "Revert turn" + * button can scope to this turn's actions even after a full chat reload. + */ +function mergeChatTurnIdIntoMessage( + msg: ThreadMessageLike, + turnId: string | null | undefined +): ThreadMessageLike { + if (!turnId) return msg; + const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> }; + const existingCustom = existingMeta.custom ?? 
{}; + if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg; + return { + ...msg, + metadata: { + ...existingMeta, + custom: { ...existingCustom, chatTurnId: turnId }, + }, + }; +} export default function NewChatPage() { const params = useParams(); @@ -311,7 +322,7 @@ export default function NewChatPage() { assistantMsgId: string; interruptData: Record<string, unknown>; } | null>(null); - const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []); + const toolsWithUI = TOOLS_WITH_UI_ALL; const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const persistAssistantErrorMessage = useCallback( @@ -364,12 +375,14 @@ export default function NewChatPage() { userMsgId, content, mentionedDocs, + turnId, logContext, }: { threadId: number | null; userMsgId: string; content: unknown; mentionedDocs?: MentionedDocumentInfo[]; + turnId?: string | null; logContext: string; }) => { if (!threadId) return null; @@ -390,10 +403,18 @@ export default function NewChatPage() { const savedUserMessage = await appendMessage(threadId, { role: "user", content: normalizedContent as AppendMessage["content"], + turn_id: turnId, }); const newUserMsgId = `msg-${savedUserMessage.id}`; setMessages((prev) => - prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) + prev.map((m) => + m.id === userMsgId + ? 
mergeChatTurnIdIntoMessage( + { ...m, id: newUserMsgId }, + savedUserMessage.turn_id + ) + : m + ) ); if (mentionedDocs && mentionedDocs.length > 0) { setMessageDocumentsMap((prev) => { @@ -419,6 +440,7 @@ export default function NewChatPage() { assistantMsgId, content, tokenUsage, + turnId, logContext, onRemapped, }: { @@ -426,6 +448,7 @@ export default function NewChatPage() { assistantMsgId: string; content: unknown; tokenUsage?: Record<string, unknown>; + turnId?: string | null; logContext: string; onRemapped?: (newMsgId: string) => void; }) => { @@ -435,11 +458,19 @@ export default function NewChatPage() { role: "assistant", content: content as AppendMessage["content"], token_usage: tokenUsage, + turn_id: turnId, }); const newMsgId = `msg-${savedMessage.id}`; tokenUsageStore.rename(assistantMsgId, newMsgId); setMessages((prev) => - prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + prev.map((m) => + m.id === assistantMsgId + ? mergeChatTurnIdIntoMessage( + { ...m, id: newMsgId }, + savedMessage.turn_id + ) + : m + ) ); onRemapped?.(newMsgId); return newMsgId; @@ -470,6 +501,25 @@ export default function NewChatPage() { const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); + // Agent action log SSE side-channel. + const upsertAgentAction = useSetAtom(upsertAgentActionAtom); + const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom); + const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom); + const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom); + // Chat-turn-keyed action map for the edit-from-position pre-flight + // that decides whether to show the confirmation dialog. + const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom); + // Edit dialog state. 
Holds the message id being edited and + // the (already extracted) regenerate args so we can resume the edit + // after the user picks "revert all" / "continue" / "cancel". + const [editDialogState, setEditDialogState] = useState<{ + fromMessageId: number; + userQuery: string | null; + userMessageContent: ThreadMessageLike["content"]; + userImages: NewChatUserImagePayload[]; + downstreamReversibleCount: number; + downstreamTotalCount: number; + } | null>(null); // Get current user for author info in shared chats const { data: currentUser } = useAtomValue(currentUserAtom); @@ -678,6 +728,7 @@ export default function NewChatPage() { clearPlanOwnerRegistry(); closeReportPanel(); closeEditorPanel(); + resetAgentActionMap(); try { if (urlChatId > 0) { @@ -746,6 +797,7 @@ export default function NewChatPage() { removeChatTab, searchSpaceId, tokenUsageStore, + resetAgentActionMap, ]); // Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same) @@ -990,6 +1042,7 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; @@ -997,6 +1050,8 @@ export default function NewChatPage() { let tokenUsageData: Record<string, unknown> | null = null; let newAccepted = false; let userPersisted = false; + // Captured from ``data-turn-info`` at stream start. 
+ let streamedChatTurnId: string | null = null; try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; @@ -1088,21 +1143,52 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": { if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); @@ -1110,7 +1196,10 @@ export default function NewChatPage() { } case "tool-output-available": { - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); markInterruptsCompleted(contentParts); if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { const idx = toolCallIndices.get(parsed.toolCallId); @@ -1216,6 +1305,50 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? 
findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId: currentThreadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -1237,6 +1370,7 @@ export default function NewChatPage() { userMsgId, content: persistContent, mentionedDocs: allMentionedDocs, + turnId: streamedChatTurnId, logContext: "new chat", }); userPersisted = Boolean(persistedUserMsgId); @@ -1250,6 +1384,7 @@ export default function NewChatPage() { assistantMsgId, content: finalContent, tokenUsage: tokenUsageData ?? 
undefined, + turnId: streamedChatTurnId, logContext: "new chat", onRemapped: (newMsgId) => { setPendingInterrupt((prev) => @@ -1278,6 +1413,7 @@ export default function NewChatPage() { userMsgId, content: persistContent, mentionedDocs: allMentionedDocs, + turnId: streamedChatTurnId, logContext: "new chat (aborted)", }); userPersisted = Boolean(persistedUserMsgId); @@ -1286,11 +1422,12 @@ export default function NewChatPage() { } } - // Request was cancelled by user - persist partial response if any content was received const hasContent = contentParts.some( (part) => (part.type === "text" && part.text.length > 0) || - (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && + (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) ); if (hasContent && currentThreadId) { const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); @@ -1298,6 +1435,7 @@ export default function NewChatPage() { threadId: currentThreadId, assistantMsgId, content: partialContent, + turnId: streamedChatTurnId, logContext: "partial new chat", }); } @@ -1309,6 +1447,7 @@ export default function NewChatPage() { userMsgId, content: persistContent, mentionedDocs: allMentionedDocs, + turnId: streamedChatTurnId, logContext: "new chat (stream error)", }); userPersisted = Boolean(persistedUserMsgId); @@ -1347,7 +1486,8 @@ export default function NewChatPage() { tokenUsageStore, pendingUserImageUrls, setPendingUserImageUrls, - toolsWithUI, + upsertAgentAction, + updateAgentActionReversible, handleStreamTerminalError, handleChatFailure, persistAssistantTurn, @@ -1384,11 +1524,14 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; let tokenUsageData: Record<string, unknown> | null 
= null; let resumeAccepted = false; + // Captured from ``data-turn-info`` at stream start. + let streamedChatTurnId: string | null = null; const existingMsg = messages.find((m) => m.id === assistantMsgId); if (existingMsg && Array.isArray(existingMsg.content)) { @@ -1492,8 +1635,34 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; @@ -1501,6 +1670,7 @@ export default function NewChatPage() { if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, }); } else { addToolCall( @@ -1508,7 +1678,9 @@ export default function NewChatPage() { toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); @@ -1517,6 +1689,7 @@ export default function NewChatPage() { case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, }); markInterruptsCompleted(contentParts); batcher.flush(); @@ -1578,6 +1751,50 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? 
findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId: resumeThreadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -1597,6 +1814,7 @@ export default function NewChatPage() { assistantMsgId, content: finalContent, tokenUsage: tokenUsageData ?? 
undefined, + turnId: streamedChatTurnId, logContext: "resumed chat", }); } @@ -1613,7 +1831,9 @@ export default function NewChatPage() { const hasContent = contentParts.some( (part) => (part.type === "text" && part.text.length > 0) || - (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && + (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) ); if (!hasContent) return; const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); @@ -1621,6 +1841,7 @@ export default function NewChatPage() { threadId: resumeThreadId, assistantMsgId, content: partialContent, + turnId: streamedChatTurnId, logContext: "partial resumed chat", }); }, @@ -1635,7 +1856,8 @@ export default function NewChatPage() { messages, searchSpaceId, tokenUsageStore, - toolsWithUI, + upsertAgentAction, + updateAgentActionReversible, handleStreamTerminalError, persistAssistantTurn, ] @@ -1716,6 +1938,12 @@ export default function NewChatPage() { userMessageContent: ThreadMessageLike["content"]; userImages: NewChatUserImagePayload[]; sourceUserMessageId?: string; + }, + editFromPosition?: { + /** Message id (numeric, parsed from ``msg-<n>``) to rewind to. */ + fromMessageId?: number | null; + /** When true, revert reversible downstream actions before stream. 
*/ + revertActions?: boolean; } ) => { if (!threadId) { @@ -1775,6 +2003,7 @@ export default function NewChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { contentParts, toolCallIndices } = contentPartsState; @@ -1782,6 +2011,10 @@ export default function NewChatPage() { let tokenUsageData: Record<string, unknown> | null = null; let regenerateAccepted = false; let userPersisted = false; + // Captured from ``data-turn-info`` at stream start; stamped + // onto persisted messages so future edits can locate the + // right LangGraph checkpoint. + let streamedChatTurnId: string | null = null; // Add placeholder messages to UI // Always add back the user message (with new query for edit, or original content for reload) @@ -1814,6 +2047,16 @@ export default function NewChatPage() { if (isEdit) { requestBody.user_images = editExtras?.userImages ?? []; } + // Explicit edit-from-arbitrary-position. Only send + // ``from_message_id`` / ``revert_actions`` when the + // caller asked for them; otherwise the backend keeps the + // legacy "last 2 messages" behaviour for back-compat. + if (editFromPosition?.fromMessageId != null) { + requestBody.from_message_id = editFromPosition.fromMessageId; + if (editFromPosition.revertActions) { + requestBody.revert_actions = true; + } + } const response = await fetch(getRegenerateUrl(threadId), { method: "POST", headers: { @@ -1831,8 +2074,21 @@ export default function NewChatPage() { // Only switch UI to regenerated placeholder messages after the backend accepts // regenerate. This avoids local message loss when regenerate fails early (e.g. 400). + // + // When an explicit ``editFromPosition.fromMessageId`` is passed, slice from + // that message forward so edit-from-arbitrary-position drops every downstream + // message; otherwise fall back to the legacy "drop the last 2" behaviour. 
setMessages((prev) => { - const base = prev.length >= 2 ? prev.slice(0, -2) : prev; + let base = prev; + if (editFromPosition?.fromMessageId != null) { + const targetId = `msg-${editFromPosition.fromMessageId}`; + const sliceIndex = prev.findIndex((m) => m.id === targetId); + if (sliceIndex >= 0) { + base = prev.slice(0, sliceIndex); + } + } else if (prev.length >= 2) { + base = prev.slice(0, -2); + } return [ ...base, userMessage, @@ -1869,28 +2125,62 @@ export default function NewChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + toolsWithUI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); break; case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); markInterruptsCompleted(contentParts); if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { const idx = 
toolCallIndices.get(parsed.toolCallId); @@ -1916,6 +2206,82 @@ export default function NewChatPage() { break; } + case "data-action-log": { + const al = parsed.data; + const matchedToolCallId = al.lc_tool_call_id + ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) + : null; + upsertAgentAction({ + action: { + id: al.id, + threadId, + lcToolCallId: al.lc_tool_call_id, + chatTurnId: al.chat_turn_id, + toolName: al.tool_name, + reversible: al.reversible, + reverseDescriptorPresent: al.reverse_descriptor_present, + error: al.error, + revertedByActionId: null, + isRevertAction: false, + createdAt: al.created_at, + }, + toolCallId: matchedToolCallId, + }); + break; + } + + case "data-action-log-updated": { + updateAgentActionReversible({ + id: parsed.data.id, + reversible: parsed.data.reversible, + }); + break; + } + + case "data-turn-info": { + streamedChatTurnId = parsed.data.chat_turn_id || null; + if (streamedChatTurnId) { + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m + ) + ); + } + break; + } + + case "data-revert-results": { + const summary = parsed.data; + // failureCount must include every "not undone" bucket + // (not_reversible, permission_denied, failed) so the + // toast's "X could not be rolled back" math matches + // the response invariant ``total === sum(counters)``. + // ``skipped`` rows are batch revert artefacts (revert + // rows themselves) and are not user-facing failures. + const failureCount = + summary.failed + summary.not_reversible + (summary.permission_denied ?? 0); + if (failureCount > 0) { + toast.warning( + `Pre-revert: ${summary.reverted}/${summary.total} undone, ${failureCount} could not be rolled back.` + ); + } else if (summary.reverted > 0) { + toast.success( + summary.reverted === 1 + ? "Reverted 1 downstream action before regenerating." 
+ : `Reverted ${summary.reverted} downstream actions before regenerating.` + ); + } + for (const r of summary.results) { + if (r.status === "reverted" || r.status === "already_reverted") { + markAgentActionReverted({ + id: r.action_id, + newActionId: r.new_action_id ?? null, + }); + } + } + break; + } + case "data-token-usage": tokenUsageData = parsed.data; tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData); @@ -1936,6 +2302,7 @@ export default function NewChatPage() { userMsgId, content: userContentToPersist, mentionedDocs: sourceMentionedDocs, + turnId: streamedChatTurnId, logContext: "regenerated", }); userPersisted = Boolean(persistedUserMsgId); @@ -1945,6 +2312,7 @@ export default function NewChatPage() { assistantMsgId, content: finalContent, tokenUsage: tokenUsageData ?? undefined, + turnId: streamedChatTurnId, logContext: "regenerated", }); @@ -1966,6 +2334,7 @@ export default function NewChatPage() { userMsgId, content: userContentToPersist, mentionedDocs: sourceMentionedDocs, + turnId: streamedChatTurnId, logContext: "regenerated (aborted)", }); userPersisted = Boolean(persistedUserMsgId); @@ -1973,7 +2342,9 @@ export default function NewChatPage() { const hasContent = contentParts.some( (part) => (part.type === "text" && part.text.length > 0) || - (part.type === "tool-call" && toolsWithUI.has(part.toolName)) + (part.type === "reasoning" && part.text.length > 0) || + (part.type === "tool-call" && + (toolsWithUI === "all" || toolsWithUI.has(part.toolName))) ); if (!hasContent) return; const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI); @@ -1982,6 +2353,7 @@ export default function NewChatPage() { assistantMsgId, content: partialContent, tokenUsage: tokenUsageData ?? 
undefined, + turnId: streamedChatTurnId, logContext: "partial regenerated chat", }); }, @@ -1992,6 +2364,7 @@ export default function NewChatPage() { userMsgId, content: userContentToPersist, mentionedDocs: sourceMentionedDocs, + turnId: streamedChatTurnId, logContext: "regenerated (stream error)", }); userPersisted = Boolean(persistedUserMsgId); @@ -2011,14 +2384,23 @@ export default function NewChatPage() { messageDocumentsMap, setMessageDocumentsMap, tokenUsageStore, - toolsWithUI, + upsertAgentAction, + updateAgentActionReversible, + markAgentActionReverted, handleStreamTerminalError, persistAssistantTurn, persistUserTurn, ] ); - // Handle editing a message - truncates history and regenerates with new query + // Handle editing a message - truncates history and regenerates with new query. + // + // When ``message.sourceId`` is set (the assistant-ui way to say + // "this edit replaces an older message"), we pin + // ``from_message_id`` so the backend rewinds to the right LangGraph + // checkpoint instead of relying on the legacy "last 2 messages" + // rewind. We also count downstream reversible actions and prompt the + // user to revert / continue / cancel before regenerating. const onEdit = useCallback( async (message: AppendMessage) => { const { userQuery, userImages } = extractUserTurnForNewChatApi(message, []); @@ -2029,17 +2411,100 @@ export default function NewChatPage() { } const userMessageContent = message.content as unknown as ThreadMessageLike["content"]; - const sourceUserMessageId = - typeof (message as { id?: unknown }).id === "string" - ? ((message as { id?: string }).id ?? undefined) - : undefined; - await handleRegenerate(queryForApi, { + + // ``sourceId`` per @assistant-ui/core's ``AppendMessage`` is + // "the ID of the message that was edited". Parse the numeric + // suffix so we can map it back to a DB row. + const sourceId = (message as { sourceId?: string }).sourceId; + const fromMessageId = + sourceId && /^msg-\d+$/.test(sourceId) + ? 
Number.parseInt(sourceId.replace(/^msg-/, ""), 10) + : null; + + if (fromMessageId == null) { + // No source id (or non-DB id) — fall back to today's + // last-2 behaviour. The user gets the legacy edit flow. + await handleRegenerate(queryForApi, { + userMessageContent, + userImages, + sourceUserMessageId: sourceId, + }); + return; + } + + // Pre-flight: count reversible downstream actions so we can + // auto-skip the dialog for harmless edits. + // + // "Downstream" means messages AFTER the edited one. The + // previous slice ``messages.slice(editedIndex)`` included + // the edited message itself in both the total + // count and the reversibility scan (any actions on the + // edited turn would be double-counted). Slice from + // ``editedIndex + 1`` so the dialog text matches reality: + // "N downstream messages will be dropped". + const editedIndex = messages.findIndex((m) => m.id === `msg-${fromMessageId}`); + let downstreamReversibleCount = 0; + let downstreamTotalCount = 0; + if (editedIndex >= 0) { + const downstream = messages.slice(editedIndex + 1); + downstreamTotalCount = downstream.length; + const seenTurns = new Set<string>(); + for (const m of downstream) { + const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } }; + const tid = meta.custom?.chatTurnId; + if (!tid || seenTurns.has(tid)) continue; + seenTurns.add(tid); + const turnActions = agentActionsByChatTurnId.get(tid) ?? []; + for (const a of turnActions) { + if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) { + downstreamReversibleCount += 1; + } + } + } + } + + if (downstreamReversibleCount === 0) { + // Nothing to revert — submit silently. 
+ await handleRegenerate( + queryForApi, + { userMessageContent, userImages, sourceUserMessageId: sourceId }, + { fromMessageId, revertActions: false } + ); + return; + } + + setEditDialogState({ + fromMessageId, + userQuery: queryForApi, userMessageContent, userImages, - sourceUserMessageId, + downstreamReversibleCount, + downstreamTotalCount, }); }, - [handleRegenerate] + [handleRegenerate, messages, agentActionsByChatTurnId] + ); + + const handleEditDialogChoice = useCallback( + async (choice: EditMessageDialogChoice) => { + const pending = editDialogState; + if (!pending) return; + setEditDialogState(null); + if (choice === "cancel") return; + await handleRegenerate( + pending.userQuery, + { + userMessageContent: pending.userMessageContent, + userImages: pending.userImages, + sourceUserMessageId: `msg-${pending.fromMessageId}`, + }, + { + fromMessageId: pending.fromMessageId, + revertActions: choice === "revert", + } + ); + }, + [editDialogState, handleRegenerate] ); // Handle reloading/refreshing the last AI response @@ -2089,6 +2554,7 @@ export default function NewChatPage() { <TokenUsageProvider store={tokenUsageStore}> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> <div key={searchSpaceId} className="flex h-full overflow-hidden"> <div className="flex-1 flex flex-col min-w-0 overflow-hidden"> <Thread /> @@ -2097,6 +2563,15 @@ export default function NewChatPage() { <MobileEditorPanel /> <MobileHitlEditPanel /> </div> + <EditMessageDialog + open={editDialogState !== null} + onOpenChange={(open) => { + if (!open) setEditDialogState(null); + }} + downstreamReversibleCount={editDialogState?.downstreamReversibleCount ?? 0} + downstreamTotalCount={editDialogState?.downstreamTotalCount ?? 
0} + onChoose={handleEditDialogChoice} + /> </AssistantRuntimeProvider> </TokenUsageProvider> ); diff --git a/surfsense_web/atoms/chat/agent-actions.atom.ts b/surfsense_web/atoms/chat/agent-actions.atom.ts new file mode 100644 index 000000000..7830c8751 --- /dev/null +++ b/surfsense_web/atoms/chat/agent-actions.atom.ts @@ -0,0 +1,194 @@ +"use client"; + +import { atom } from "jotai"; + +/** + * Minimal per-row projection of ``AgentActionLog`` that the tool card + * needs to decide whether to render a Revert button. + * + * Fields are deliberately a subset of the full ``AgentAction`` so the + * SSE side-channel (``data-action-log`` / ``data-action-log-updated``) + * can populate them without depending on the REST endpoint + * ``GET /threads/.../actions`` (which 503s when + * ``SURFSENSE_ENABLE_ACTION_LOG`` is off). + */ +export interface AgentActionLite { + id: number; + threadId: number | null; + lcToolCallId: string | null; + chatTurnId: string | null; + toolName: string; + reversible: boolean; + reverseDescriptorPresent: boolean; + error: boolean; + revertedByActionId: number | null; + isRevertAction: boolean; + createdAt: string | null; +} + +/** + * Map keyed off the LangChain ``tool_call.id`` (mirrors ``ContentPart + * tool-call.langchainToolCallId``). + */ +export const agentActionByLcIdAtom = atom<Map<string, AgentActionLite>>(new Map()); + +/** + * Parallel map keyed off the synthetic chat-card ``toolCallId`` + * (``call_<run-id>``) so ``ToolFallback`` (which only receives the + * synthetic id from assistant-ui) can join its card to the action log. + * + * Both maps are kept in sync by ``upsertAgentActionAtom``. + */ +export const agentActionByToolCallIdAtom = atom<Map<string, AgentActionLite>>(new Map()); + +/** + * Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer + * "how many reversible actions does this assistant turn contain?" in + * O(1). 
Each entry's array is ordered by insertion (which + * for a single turn matches ``created_at`` because action-log writes + * happen synchronously). + */ +export const agentActionsByChatTurnIdAtom = atom<Map<string, AgentActionLite[]>>(new Map()); + +/** + * Action to upsert one ``AgentActionLite`` row. + * + * ``toolCallId`` is the synthetic card id (``call_<run-id>`` from + * ``stream_new_chat.py``). When provided alongside ``lcToolCallId``, the + * action is indexed under BOTH ids so the tool card can perform the + * lookup without going via the streaming state. + */ +export const upsertAgentActionAtom = atom( + null, + (_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => { + const { action, toolCallId } = payload; + const upsertInto = ( + prev: Map<string, AgentActionLite>, + key: string + ): Map<string, AgentActionLite> => { + const next = new Map(prev); + const existing = next.get(key); + next.set(key, { + ...action, + // Preserve the local "reverted" bookkeeping if a reversibility + // flip arrives AFTER the user already reverted via the REST + // route. We never want a stale ``reversible=true`` event to + // resurrect a Reverted card. + revertedByActionId: existing?.revertedByActionId ?? action.revertedByActionId, + isRevertAction: existing?.isRevertAction ?? action.isRevertAction, + }); + return next; + }; + if (action.lcToolCallId) { + set(agentActionByLcIdAtom, (prev) => upsertInto(prev, action.lcToolCallId as string)); + } + if (toolCallId) { + set(agentActionByToolCallIdAtom, (prev) => upsertInto(prev, toolCallId)); + } + if (action.chatTurnId) { + set(agentActionsByChatTurnIdAtom, (prev) => { + const next = new Map(prev); + const turnId = action.chatTurnId as string; + const existing = next.get(turnId) ?? []; + const priorEntry = existing.find((row) => row.id === action.id); + const merged: AgentActionLite = { + ...action, + revertedByActionId: priorEntry?.revertedByActionId ?? 
action.revertedByActionId, + isRevertAction: priorEntry?.isRevertAction ?? action.isRevertAction, + }; + const others = existing.filter((row) => row.id !== action.id); + next.set(turnId, [...others, merged]); + return next; + }); + } + } +); + +function mutateById( + prev: Map<string, AgentActionLite>, + id: number, + mutator: (entry: AgentActionLite) => AgentActionLite +): Map<string, AgentActionLite> { + let mutated = false; + const next = new Map(prev); + for (const [key, value] of next) { + if (value.id === id) { + next.set(key, mutator(value)); + mutated = true; + } + } + return mutated ? next : prev; +} + +function mutateByIdInTurnIndex( + prev: Map<string, AgentActionLite[]>, + id: number, + mutator: (entry: AgentActionLite) => AgentActionLite +): Map<string, AgentActionLite[]> { + let mutated = false; + const next = new Map(prev); + for (const [key, list] of next) { + let listMutated = false; + const updated = list.map((row) => { + if (row.id === id) { + listMutated = true; + return mutator(row); + } + return row; + }); + if (listMutated) { + next.set(key, updated); + mutated = true; + } + } + return mutated ? next : prev; +} + +/** + * Action to flip an existing entry's ``reversible`` flag, keyed by the + * AgentActionLog row id (the SSE ``data-action-log-updated`` payload + * does NOT carry ``lcToolCallId``). + */ +export const updateAgentActionReversibleAtom = atom( + null, + (_get, set, payload: { id: number; reversible: boolean }) => { + const apply = (entry: AgentActionLite): AgentActionLite => ({ + ...entry, + reversible: payload.reversible, + }); + set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); + } +); + +/** Action to mark an existing entry as reverted (post-revert call). 
*/ +export const markAgentActionRevertedAtom = atom( + null, + (_get, set, payload: { id: number; newActionId: number | null }) => { + const apply = (entry: AgentActionLite): AgentActionLite => ({ + ...entry, + revertedByActionId: payload.newActionId ?? -1, + }); + set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); + set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); + } +); + +/** Mark every action in a turn as reverted, given a list of (id, newActionId) pairs. */ +export const markAgentActionsRevertedBatchAtom = atom( + null, + (_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => { + for (const entry of payload.entries) { + set(markAgentActionRevertedAtom, entry); + } + } +); + +/** Reset all maps (e.g. when the active thread changes). */ +export const resetAgentActionMapAtom = atom(null, (_get, set) => { + set(agentActionByLcIdAtom, new Map()); + set(agentActionByToolCallIdAtom, new Map()); + set(agentActionsByChatTurnIdAtom, new Map()); +}); diff --git a/surfsense_web/components/agent-action-log/action-log-item.tsx b/surfsense_web/components/agent-action-log/action-log-item.tsx index 425714c1f..673189709 100644 --- a/surfsense_web/components/agent-action-log/action-log-item.tsx +++ b/surfsense_web/components/agent-action-log/action-log-item.tsx @@ -17,16 +17,12 @@ import { import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Separator } from "@/components/ui/separator"; -import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; 
-function formatToolName(name: string): string { - return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); -} - interface ActionLogItemProps { action: AgentAction; threadId: number; @@ -43,7 +39,7 @@ export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogIt const hasError = action.error !== null && action.error !== undefined; const Icon = getToolIcon(action.tool_name); - const displayName = formatToolName(action.tool_name); + const displayName = getToolDisplayName(action.tool_name); const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null; const truncatedArgs = diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index 6b9c2c87e..bfe0434b4 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -33,6 +33,8 @@ import { useAllCitationMetadata, } from "@/components/assistant-ui/citation-metadata-context"; import { MarkdownText } from "@/components/assistant-ui/markdown-text"; +import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; +import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button"; import { useTokenUsage } from "@/components/assistant-ui/token-usage-context"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; @@ -491,6 +493,7 @@ const AssistantMessageInner: FC = () => { <MessagePrimitive.Parts components={{ Text: MarkdownText, + Reasoning: ReasoningMessagePart, tools: { by_name: { generate_report: GenerateReportToolUI, @@ -699,6 +702,13 @@ const AssistantActionBar: FC = () => { const isLast = useAuiState((s) => s.message.isLast); const aui = useAui(); const api = useElectronAPI(); + // Surface the persisted ``chat_turn_id`` so the per-turn revert + // affordance can scope to 
just this message's actions. Streamed + // turns get their id once the assistant message is hydrated/finalised. + const chatTurnId = useAuiState(({ message }) => { + const meta = message?.metadata as { custom?: { chatTurnId?: string | null } } | undefined; + return meta?.custom?.chatTurnId ?? null; + }); const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW; @@ -743,6 +753,9 @@ const AssistantActionBar: FC = () => { </TooltipIconButton> )} <MessageInfoDropdown /> + <div className="ml-auto"> + <RevertTurnButton chatTurnId={chatTurnId} /> + </div> </ActionBarPrimitive.Root> ); }; diff --git a/surfsense_web/components/assistant-ui/edit-message-dialog.tsx b/surfsense_web/components/assistant-ui/edit-message-dialog.tsx new file mode 100644 index 000000000..807f16fe7 --- /dev/null +++ b/surfsense_web/components/assistant-ui/edit-message-dialog.tsx @@ -0,0 +1,106 @@ +"use client"; + +/** + * Confirmation dialog shown when the user edits a message that has + * reversible downstream actions. Three buttons: + * + * • "Revert all & resubmit" — POST regenerate with revert_actions=true + * • "Continue without revert" — POST regenerate with revert_actions=false + * • "Cancel" — abort the edit entirely + * + * The dialog is auto-skipped when zero reversible downstream actions + * exist (the caller checks first via ``downstreamReversibleCount``). 
+ */ + +import { useEffect, useRef, useState } from "react"; +import { + AlertDialog, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; + +export type EditMessageDialogChoice = "revert" | "continue" | "cancel"; + +export interface EditMessageDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + downstreamReversibleCount: number; + downstreamTotalCount: number; + onChoose: (choice: EditMessageDialogChoice) => void | Promise<void>; +} + +export function EditMessageDialog({ + open, + onOpenChange, + downstreamReversibleCount, + downstreamTotalCount, + onChoose, +}: EditMessageDialogProps) { + const [busy, setBusy] = useState<EditMessageDialogChoice | null>(null); + + // The parent's ``handleEditDialogChoice`` calls + // ``setEditDialogState(null)`` BEFORE awaiting ``handleRegenerate``. + // That collapses the dialog (Radix unmounts it) while ``onChoose`` + // is still awaiting the long-running stream. Without this guard, + // the ``finally { setBusy(null) }`` below ran after unmount and + // produced a "state update on unmounted component" dev warning. + const mountedRef = useRef(true); + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); + + const handle = async (choice: EditMessageDialogChoice) => { + setBusy(choice); + try { + await onChoose(choice); + } finally { + if (mountedRef.current) { + setBusy(null); + } + } + }; + + return ( + <AlertDialog open={open} onOpenChange={onOpenChange}> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Edit this message?</AlertDialogTitle> + <AlertDialogDescription> + This edit drops {downstreamTotalCount} downstream message + {downstreamTotalCount === 1 ? "" : "s"} from the thread. {downstreamReversibleCount}{" "} + action + {downstreamReversibleCount === 1 ? 
"" : "s"} (e.g. file writes, connector changes) can + be rolled back. Pick how to handle them before regenerating. + </AlertDialogDescription> + </AlertDialogHeader> + + <div className="grid gap-2"> + <Button variant="default" disabled={busy !== null} onClick={() => handle("revert")}> + {busy === "revert" + ? "Reverting & resubmitting…" + : `Revert ${downstreamReversibleCount} action${ + downstreamReversibleCount === 1 ? "" : "s" + } & resubmit`} + </Button> + <Button variant="outline" disabled={busy !== null} onClick={() => handle("continue")}> + {busy === "continue" ? "Resubmitting…" : "Continue without reverting"} + </Button> + </div> + + <AlertDialogFooter className="sm:justify-start"> + <AlertDialogCancel disabled={busy !== null} onClick={() => handle("cancel")}> + Cancel + </AlertDialogCancel> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + ); +} diff --git a/surfsense_web/components/assistant-ui/reasoning-message-part.tsx b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx new file mode 100644 index 000000000..70636eab8 --- /dev/null +++ b/surfsense_web/components/assistant-ui/reasoning-message-part.tsx @@ -0,0 +1,81 @@ +"use client"; + +import type { ReasoningMessagePartComponent } from "@assistant-ui/react"; +import { ChevronRightIcon } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { TextShimmerLoader } from "@/components/prompt-kit/loader"; +import { cn } from "@/lib/utils"; + +/** + * Renders the structured `reasoning` part emitted by the backend's + * stream-parity v2 path (A1). + * + * Behaviour mirrors the existing `ThinkingStepsDisplay`: + * - collapsed by default; + * - auto-expanded while the part is still `running`; + * - auto-collapsed once status flips to `complete`. 
+ * + * The component is registered via the `Reasoning` slot on + * `MessagePrimitive.Parts` in `assistant-message.tsx` so it lives at the + * exact ordinal position of the reasoning block in the message content + * array (i.e. above the assistant text that follows it). + */ +export const ReasoningMessagePart: ReasoningMessagePartComponent = ({ text, status }) => { + const isRunning = status?.type === "running"; + const [isOpen, setIsOpen] = useState(() => isRunning); + + useEffect(() => { + if (isRunning) { + setIsOpen(true); + } else if (status?.type === "complete") { + setIsOpen(false); + } + }, [isRunning, status?.type]); + + const headerLabel = useMemo(() => { + if (isRunning) return "Thinking"; + if (status?.type === "incomplete") return "Thinking interrupted"; + return "Thought"; + }, [isRunning, status?.type]); + + if (!text || text.length === 0) { + if (!isRunning) return null; + } + + return ( + <div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2"> + <div className="rounded-lg"> + <button + type="button" + onClick={() => setIsOpen((prev) => !prev)} + className={cn( + "flex w-full items-center gap-1.5 text-left text-sm transition-colors", + "text-muted-foreground hover:text-foreground" + )} + > + {isRunning ? ( + <TextShimmerLoader text={headerLabel} size="sm" /> + ) : ( + <span>{headerLabel}</span> + )} + <ChevronRightIcon + className={cn("size-4 transition-transform duration-200", isOpen && "rotate-90")} + /> + </button> + + <div + className={cn( + "grid transition-[grid-template-rows] duration-300 ease-out", + isOpen ? 
"grid-rows-[1fr]" : "grid-rows-[0fr]" + )} + > + <div className="overflow-hidden"> + <div className="mt-2 border-l border-muted-foreground/30 pl-3 text-sm leading-relaxed text-muted-foreground whitespace-pre-wrap wrap-break-word"> + {text} + </div> + </div> + </div> + </div> + </div> + ); +}; diff --git a/surfsense_web/components/assistant-ui/revert-turn-button.tsx b/surfsense_web/components/assistant-ui/revert-turn-button.tsx new file mode 100644 index 000000000..af71299d0 --- /dev/null +++ b/surfsense_web/components/assistant-ui/revert-turn-button.tsx @@ -0,0 +1,229 @@ +"use client"; + +/** + * "Revert turn" button rendered at the bottom of every completed + * assistant turn that has at least one reversible action. + * + * The button reads the action map keyed by ``chat_turn_id`` from the + * SSE side-channel (``data-action-log`` events). It shows a confirmation + * dialog summarising "N reversible / M total" and, on confirm, calls + * ``POST /threads/{id}/revert-turn/{chat_turn_id}``. + * + * The route returns a per-action result list and never collapses the + * batch into a 4xx — so we render any failed/not_reversible rows inline + * with their messages. 
+ */ + +import { useAtomValue, useSetAtom } from "jotai"; +import { selectAtom } from "jotai/utils"; +import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react"; +import { useMemo, useState } from "react"; +import { toast } from "sonner"; +import { + type AgentActionLite, + agentActionsByChatTurnIdAtom, + markAgentActionsRevertedBatchAtom, +} from "@/atoms/chat/agent-actions.atom"; +import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; +import { + agentActionsApiService, + type RevertTurnActionResult, +} from "@/lib/apis/agent-actions-api.service"; +import { AppError } from "@/lib/error"; +import { cn } from "@/lib/utils"; + +interface RevertTurnButtonProps { + chatTurnId: string | null | undefined; +} + +// Empty-array sentinel so the per-turn ``selectAtom`` slice returns a +// stable reference when the turn has no recorded actions yet. Without +// this every render allocates a fresh ``[]`` and Jotai's +// equality check would re-render the button on unrelated turn updates. +const EMPTY_ACTIONS: readonly AgentActionLite[] = Object.freeze([]); + +export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { + const session = useAtomValue(chatSessionStateAtom); + const markRevertedBatch = useSetAtom(markAgentActionsRevertedBatchAtom); + const [isReverting, setIsReverting] = useState(false); + const [confirmOpen, setConfirmOpen] = useState(false); + const [resultsOpen, setResultsOpen] = useState(false); + const [results, setResults] = useState<RevertTurnActionResult[]>([]); + + // Subscribe ONLY to the slice of the global action map that belongs + // to ``chatTurnId``. 
Previously the button read the whole + // ``agentActionsByChatTurnIdAtom``, which meant every action + // upsert (one per tool call) re-rendered every Revert button on + // the page. With ``selectAtom`` we re-render only when our turn's + // list reference changes — and the upsert/mark atoms produce a + // fresh list reference for the affected turn only. + const sliceAtom = useMemo( + () => + selectAtom( + agentActionsByChatTurnIdAtom, + (turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? EMPTY_ACTIONS + ), + [chatTurnId] + ); + const actions = useAtomValue(sliceAtom); + + const reversibleCount = useMemo( + () => + actions.filter( + (a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error + ).length, + [actions] + ); + const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]); + + if (!chatTurnId) return null; + if (reversibleCount === 0) return null; + const threadId = session?.threadId; + if (!threadId) return null; + + const handleRevertTurn = async () => { + setIsReverting(true); + try { + const response = await agentActionsApiService.revertTurn(threadId, chatTurnId); + setResults(response.results); + const revertedEntries = response.results + .filter((r) => r.status === "reverted" || r.status === "already_reverted") + .map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null })); + if (revertedEntries.length > 0) { + markRevertedBatch({ entries: revertedEntries }); + } + if (response.status === "ok") { + toast.success( + response.reverted === 1 ? "Reverted 1 action." : `Reverted ${response.reverted} actions.` + ); + } else { + // Every "not undone" bucket counts as a failure for the + // user-facing summary. ``skipped`` rows are batch + // artefacts (revert rows themselves) and intentionally + // excluded from the failure tally. + const failureCount = + response.failed + response.not_reversible + (response.permission_denied ?? 
0); + toast.warning( + `Reverted ${response.reverted} of ${response.total}. ${failureCount} could not be undone.` + ); + setResultsOpen(true); + } + } catch (err) { + if (err instanceof AppError && err.status === 503) { + return; + } + const message = + err instanceof AppError + ? err.message + : err instanceof Error + ? err.message + : "Failed to revert turn."; + toast.error(message); + } finally { + setIsReverting(false); + setConfirmOpen(false); + } + }; + + return ( + <> + <AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}> + <AlertDialogTrigger asChild> + <Button + size="sm" + variant="ghost" + className="text-muted-foreground hover:text-foreground gap-1.5" + onClick={(e) => { + e.stopPropagation(); + setConfirmOpen(true); + }} + > + <RotateCcw className="size-3.5" /> + <span>Revert turn</span> + <span className="text-xs tabular-nums opacity-70"> + {reversibleCount}/{totalCount} + </span> + </Button> + </AlertDialogTrigger> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert this turn?</AlertDialogTitle> + <AlertDialogDescription> + This will undo {reversibleCount} of {totalCount} action + {totalCount === 1 ? "" : "s"} from this turn in reverse order. The chat history and + any read-only actions are preserved. Some rows may not be reversible — partial success + is normal. + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={(e) => { + e.preventDefault(); + handleRevertTurn(); + }} + disabled={isReverting} + > + {isReverting ? "Reverting…" : "Revert turn"} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + + <AlertDialog open={resultsOpen} onOpenChange={setResultsOpen}> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert results</AlertDialogTitle> + <AlertDialogDescription> + Some actions could not be reverted. Review per-row outcomes below. 
+ </AlertDialogDescription> + </AlertDialogHeader> + <ul className="max-h-72 overflow-y-auto space-y-2 text-sm"> + {results.map((r) => ( + <RevertResultRow key={r.action_id} result={r} /> + ))} + </ul> + <AlertDialogFooter> + <AlertDialogAction onClick={() => setResultsOpen(false)}>Close</AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + </> + ); +} + +function RevertResultRow({ result }: { result: RevertTurnActionResult }) { + const isOk = result.status === "reverted" || result.status === "already_reverted"; + const Icon = isOk ? CheckIcon : XCircleIcon; + return ( + <li className="flex items-start gap-2 rounded-md border bg-muted/30 px-3 py-2"> + <Icon + className={cn("size-4 mt-0.5 shrink-0", isOk ? "text-emerald-500" : "text-destructive")} + /> + <div className="min-w-0 flex-1"> + <p className="font-medium truncate"> + {getToolDisplayName(result.tool_name)}{" "} + <span className="ml-1 text-xs text-muted-foreground"> + {result.status.replace(/_/g, " ")} + </span> + </p> + {(result.message || result.error) && ( + <p className="text-xs text-muted-foreground mt-0.5">{result.error ?? result.message}</p> + )} + </div> + </li> + ); +} diff --git a/surfsense_web/components/assistant-ui/step-separator.tsx b/surfsense_web/components/assistant-ui/step-separator.tsx new file mode 100644 index 000000000..f59130661 --- /dev/null +++ b/surfsense_web/components/assistant-ui/step-separator.tsx @@ -0,0 +1,27 @@ +"use client"; + +import { makeAssistantDataUI } from "@assistant-ui/react"; + +/** + * Renders a thin horizontal divider between model steps within a single + * assistant turn. The data part is pushed by `addStepSeparator` in + * `streaming-state.ts` whenever a `start-step` SSE event arrives after + * the message already has non-step content. + * + * Today the backend emits one `start-step` / `finish-step` pair per turn, + * so most messages won't contain a separator. 
The renderer is wired up so + * the planned per-model-step refactor (A2 follow-up) can light up without + * touching the persistence path. + */ +function StepSeparatorDataRenderer() { + return ( + <div className="mx-auto my-3 w-full max-w-(--thread-max-width) px-2"> + <div className="border-t border-border/60" /> + </div> + ); +} + +export const StepSeparatorDataUI = makeAssistantDataUI({ + name: "step-separator", + render: StepSeparatorDataRenderer, +}); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 3095556dc..3e27e7adb 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -85,6 +85,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import { CONNECTOR_ICON_TO_TYPES, CONNECTOR_TOOL_ICON_PATHS, + getToolDisplayName, getToolIcon, } from "@/contracts/enums/toolIcons"; import type { Document } from "@/contracts/types/document.types"; @@ -1354,12 +1355,14 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false ); }; -/** Convert snake_case tool names to human-readable labels */ +/** + * Friendly tool name for display in the chat UI. Delegates to the + * shared map in ``contracts/enums/toolIcons`` so unix-style identifiers + * (``rm``, ``ls``, ``grep`` …) and snake_cased function names render as + * plain English (e.g. "Delete file", "List files", "Search in files"). 
+ */ function formatToolName(name: string): string { - return name - .split("_") - .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) - .join(" "); + return getToolDisplayName(name); } interface ToolGroup { diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index 112f3e1d8..cc7582695 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,20 +1,130 @@ import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; -import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react"; +import { useAtomValue, useSetAtom } from "jotai"; +import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { useMemo, useState } from "react"; +import { toast } from "sonner"; +import { + agentActionByToolCallIdAtom, + markAgentActionRevertedAtom, +} from "@/atoms/chat/agent-actions.atom"; +import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { DoomLoopApprovalToolUI, isDoomLoopInterrupt, } from "@/components/tool-ui/doom-loop-approval"; import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval"; -import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; +import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; import { cn } from "@/lib/utils"; -function formatToolName(name: string): string { - return 
name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); +/** + * Inline Revert button rendered on a tool card when the matching + * ``AgentActionLog`` row is reversible and hasn't been reverted yet. + * Reads from the SSE side-channel atom keyed by the synthetic + * ``toolCallId`` so it lights up even when ``GET /threads/.../actions`` + * is gated behind ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503). + */ +function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { + const session = useAtomValue(chatSessionStateAtom); + const actionMap = useAtomValue(agentActionByToolCallIdAtom); + const markReverted = useSetAtom(markAgentActionRevertedAtom); + const action = actionMap.get(toolCallId); + const [isReverting, setIsReverting] = useState(false); + const [confirmOpen, setConfirmOpen] = useState(false); + + if (!action) return null; + if (!action.reversible) return null; + if (action.revertedByActionId !== null) return null; + if (action.isRevertAction) return null; + if (action.error) return null; + const threadId = session?.threadId; + if (!threadId) return null; + + const handleRevert = async () => { + setIsReverting(true); + try { + const response = await agentActionsApiService.revert(threadId, action.id); + markReverted({ id: action.id, newActionId: response.new_action_id ?? null }); + toast.success(response.message || "Action reverted."); + } catch (err) { + // 503 means revert is gated off on this deployment — hide the + // button silently rather than nagging the user. Any other error + // is surfaced as a toast so the operator can investigate. + if (err instanceof AppError && err.status === 503) { + return; + } + const message = + err instanceof AppError + ? err.message + : err instanceof Error + ? 
err.message + : "Failed to revert action."; + toast.error(message); + } finally { + setIsReverting(false); + setConfirmOpen(false); + } + }; + + return ( + <AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}> + <AlertDialogTrigger asChild> + <Button + size="sm" + variant="outline" + className="gap-1.5" + onClick={(e) => { + e.stopPropagation(); + setConfirmOpen(true); + }} + > + <RotateCcw className="size-3.5" /> + Revert + </Button> + </AlertDialogTrigger> + <AlertDialogContent> + <AlertDialogHeader> + <AlertDialogTitle>Revert this action?</AlertDialogTitle> + <AlertDialogDescription> + This will undo{" "} + <span className="font-medium">{getToolDisplayName(action.toolName)}</span> and add a + new entry to the history. Your chat is preserved — only the changes the agent made to + your knowledge base or connected apps will be rolled back where possible. + </AlertDialogDescription> + </AlertDialogHeader> + <AlertDialogFooter> + <AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel> + <AlertDialogAction + onClick={(e) => { + e.preventDefault(); + handleRevert(); + }} + disabled={isReverting} + > + {isReverting ? "Reverting…" : "Revert"} + </AlertDialogAction> + </AlertDialogFooter> + </AlertDialogContent> + </AlertDialog> + ); } const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ + toolCallId, toolName, argsText, result, @@ -51,7 +161,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ : null; const Icon = getToolIcon(toolName); - const displayName = formatToolName(toolName); + const displayName = getToolDisplayName(toolName); return ( <div @@ -102,7 +212,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ ? 
`Failed: ${displayName}` : displayName} </p> - {isRunning && <p className="text-xs text-muted-foreground mt-0.5">Running...</p>} + {isRunning && <p className="text-xs text-muted-foreground mt-0.5">Working…</p>} {cancelledReason && ( <p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p> )} @@ -128,7 +238,7 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ <div className="px-5 py-3 space-y-3"> {argsText && ( <div> - <p className="text-xs font-medium text-muted-foreground mb-1">Arguments</p> + <p className="text-xs font-medium text-muted-foreground mb-1">Inputs</p> <pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> {argsText} </pre> @@ -145,6 +255,9 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ </div> </> )} + <div className="flex justify-end"> + <ToolCardRevertButton toolCallId={toolCallId} /> + </div> </div> </> )} diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index dd6693b35..a287b9dc5 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -9,6 +9,7 @@ import { import { Turnstile, type TurnstileInstance } from "@marsidev/react-turnstile"; import { ShieldCheck } from "lucide-react"; import { useCallback, useEffect, useRef, useState } from "react"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { createTokenUsageStore, @@ -17,10 +18,13 @@ import { } from "@/components/assistant-ui/token-usage-context"; import { useAnonymousMode } from "@/contexts/anonymous-mode"; import { + addStepSeparator, addToolCall, + appendReasoning, appendText, buildContentForUI, type ContentPartsState, + endReasoning, FrameBatchedUpdater, readSSEStream, type ThinkingStepData, @@ -32,7 +36,9 @@ import { trackAnonymousChatMessageSent 
} from "@/lib/posthog/events"; import { FreeModelSelector } from "./free-model-selector"; import { FreeThread } from "./free-thread"; -const TOOLS_WITH_UI = new Set(["web_search", "document_qna"]); +// Render all tool calls via ToolFallback; backend keeps persisted +// payloads bounded by summarising / truncating outputs. +const TOOLS_WITH_UI = "all" as const; const TURNSTILE_SITE_KEY = process.env.NEXT_PUBLIC_TURNSTILE_SITE_KEY ?? ""; /** Try to parse a CAPTCHA_REQUIRED or CAPTCHA_INVALID code from a non-ok response. */ @@ -167,6 +173,7 @@ export function FreeChatPage() { const contentPartsState: ContentPartsState = { contentParts: [], currentTextPartIndex: -1, + currentReasoningPartIndex: -1, toolCallIndices: new Map(), }; const { toolCallIndices } = contentPartsState; @@ -190,28 +197,62 @@ export function FreeChatPage() { scheduleFlush(); break; + case "reasoning-delta": + appendReasoning(contentPartsState, parsed.delta); + scheduleFlush(); + break; + + case "reasoning-end": + endReasoning(contentPartsState); + scheduleFlush(); + break; + + case "start-step": + addStepSeparator(contentPartsState); + scheduleFlush(); + break; + + case "finish-step": + break; + case "tool-input-start": - addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); + addToolCall( + contentPartsState, + TOOLS_WITH_UI, + parsed.toolCallId, + parsed.toolName, + {}, + false, + parsed.langchainToolCallId + ); batcher.flush(); break; case "tool-input-available": if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + langchainToolCallId: parsed.langchainToolCallId, + }); } else { addToolCall( contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, - parsed.input || {} + parsed.input || {}, + false, + parsed.langchainToolCallId ); } batcher.flush(); break; case "tool-output-available": - 
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + updateToolCall(contentPartsState, parsed.toolCallId, { + result: parsed.output, + langchainToolCallId: parsed.langchainToolCallId, + }); batcher.flush(); break; @@ -413,6 +454,7 @@ export function FreeChatPage() { <TokenUsageProvider store={tokenUsageStore}> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> <div className="flex h-full flex-col overflow-hidden"> <div className="flex h-14 shrink-0 items-center justify-between border-b border-border/40 px-4"> <FreeModelSelector /> diff --git a/surfsense_web/components/public-chat/public-chat-view.tsx b/surfsense_web/components/public-chat/public-chat-view.tsx index f8dd6db5a..e47ba9bf1 100644 --- a/surfsense_web/components/public-chat/public-chat-view.tsx +++ b/surfsense_web/components/public-chat/public-chat-view.tsx @@ -1,6 +1,7 @@ "use client"; import { AssistantRuntimeProvider } from "@assistant-ui/react"; +import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { Navbar } from "@/components/homepage/navbar"; import { ReportPanel } from "@/components/report-panel/report-panel"; @@ -41,6 +42,7 @@ export function PublicChatView({ shareToken }: PublicChatViewProps) { <Navbar scrolledBgClassName={navbarScrolledBg} /> <AssistantRuntimeProvider runtime={runtime}> <ThinkingStepsDataUI /> + <StepSeparatorDataUI /> <div className="flex h-screen pt-16 overflow-hidden"> <div className="flex-1 flex flex-col min-w-0 overflow-hidden"> <PublicThread footer={<PublicChatFooter shareToken={shareToken} />} /> diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index 627baf831..22e914988 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -13,6 
+13,7 @@ import Image from "next/image"; import { type FC, type ReactNode, useState } from "react"; import { CitationMetadataProvider } from "@/components/assistant-ui/citation-metadata-context"; import { MarkdownText } from "@/components/assistant-ui/markdown-text"; +import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { GenerateImageToolUI } from "@/components/tool-ui/generate-image"; @@ -157,6 +158,7 @@ const PublicAssistantMessage: FC = () => { <MessagePrimitive.Parts components={{ Text: MarkdownText, + Reasoning: ReasoningMessagePart, tools: { by_name: { generate_podcast: GeneratePodcastToolUI, diff --git a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx index ceb1d0209..a584084ff 100644 --- a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx +++ b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx @@ -8,6 +8,7 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Textarea } from "@/components/ui/textarea"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; import { useHitlPhase } from "@/hooks/use-hitl-phase"; import { connectorsApiService } from "@/lib/apis/connectors-api.service"; import type { HitlDecision, InterruptResult } from "@/lib/hitl"; @@ -77,7 +78,7 @@ function GenericApprovalCard({ const [editedParams, setEditedParams] = useState<Record<string, unknown>>(args); const [isEditing, setIsEditing] = useState(false); - const displayName = toolName.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); + const displayName = getToolDisplayName(toolName); const mcpServer = interruptData.context?.mcp_server as string | undefined; 
const toolDescription = interruptData.context?.tool_description as string | undefined; @@ -186,12 +187,11 @@ function GenericApprovalCard({ </> )} - {/* Parameters */} {Object.keys(args).length > 0 && ( <> <div className="mx-5 h-px bg-border/50" /> <div className="px-5 py-4 space-y-2"> - <p className="text-xs font-medium text-muted-foreground">Parameters</p> + <p className="text-xs font-medium text-muted-foreground">Inputs</p> {phase === "pending" && isEditing ? ( <ParamEditor params={editedParams} diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index bc63bc1b0..bdb8222cb 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,33 +1,223 @@ import { BookOpen, Brain, + Calendar, + Check, + FileEdit, + FilePlus, FileText, FileUser, + FileX, Film, + FolderPlus, + FolderTree, + FolderX, Globe, ImageIcon, + ListTodo, type LucideIcon, + Mail, + MessagesSquare, + Move, + Plus, Podcast, ScanLine, + Search, + Send, + Trash2, Wrench, } from "lucide-react"; +/** + * Every tool now renders a card via ``ToolFallback``. The icon map is + * keyed on the canonical backend tool name (registered in + * ``surfsense_backend/app/agents/new_chat/tools/registry.py``); unknown + * names fall back to the generic ``Wrench`` icon so the card still + * communicates "this is a tool call". 
+ */ const TOOL_ICONS: Record<string, LucideIcon> = { + // Generators generate_podcast: Podcast, generate_video_presentation: Film, generate_report: FileText, generate_resume: FileUser, generate_image: ImageIcon, + display_image: ImageIcon, + // Web / search scrape_webpage: ScanLine, web_search: Globe, search_surfsense_docs: BookOpen, + // Memory update_memory: Brain, + // Filesystem (built-in deepagent + middleware) + read_file: FileText, + write_file: FilePlus, + edit_file: FileEdit, + move_file: Move, + rm: FileX, + rmdir: FolderX, + mkdir: FolderPlus, + ls: FolderTree, + write_todos: ListTodo, + // Calendar + search_calendar_events: Search, + create_calendar_event: Calendar, + update_calendar_event: Calendar, + delete_calendar_event: Calendar, + // Gmail + search_gmail: Search, + read_gmail_email: Mail, + create_gmail_draft: Mail, + update_gmail_draft: FileEdit, + send_gmail_email: Send, + trash_gmail_email: Trash2, + // Notion / Confluence pages + create_notion_page: FilePlus, + update_notion_page: FileEdit, + delete_notion_page: FileX, + create_confluence_page: FilePlus, + update_confluence_page: FileEdit, + delete_confluence_page: FileX, + // Linear / Jira issues + create_linear_issue: Plus, + update_linear_issue: FileEdit, + delete_linear_issue: Trash2, + create_jira_issue: Plus, + update_jira_issue: FileEdit, + delete_jira_issue: Trash2, + // Drive-like file connectors + create_google_drive_file: FilePlus, + delete_google_drive_file: FileX, + create_dropbox_file: FilePlus, + delete_dropbox_file: FileX, + create_onedrive_file: FilePlus, + delete_onedrive_file: FileX, + // Chat connectors + list_discord_channels: MessagesSquare, + read_discord_messages: MessagesSquare, + send_discord_message: Send, + list_teams_channels: MessagesSquare, + read_teams_messages: MessagesSquare, + send_teams_message: Send, + // Luma + list_luma_events: Calendar, + read_luma_event: Calendar, + create_luma_event: Calendar, + // Misc + get_connected_accounts: Check, + execute: 
Wrench, + execute_code: Wrench, }; export function getToolIcon(name: string): LucideIcon { return TOOL_ICONS[name] ?? Wrench; } +/** + * Friendly display names for tools shown in the chat UI. + * + * Most users aren't engineers; they shouldn't see raw unix-style + * identifiers like ``rm`` / ``rmdir`` / ``ls`` / ``grep`` / ``glob`` or + * snake_cased function names. The map below renders each tool with + * plain English wording (verb + object) so non-technical users + * understand what the agent is doing at a glance. + * + * Unmapped tool names fall back to a snake_case-to-Title-Case + * conversion via :func:`getToolDisplayName`. + */ +const TOOL_DISPLAY_NAMES: Record<string, string> = { + // Filesystem / knowledge base + read_file: "Read file", + write_file: "Write file", + edit_file: "Edit file", + move_file: "Move file", + rm: "Delete file", + rmdir: "Delete folder", + mkdir: "Create folder", + ls: "List files", + glob: "Find files", + grep: "Search in files", + write_todos: "Plan tasks", + save_document: "Save document", + // Generators + generate_podcast: "Generate podcast", + generate_video_presentation: "Generate video presentation", + generate_report: "Generate report", + generate_resume: "Generate resume", + generate_image: "Generate image", + display_image: "Show image", + // Web / search + scrape_webpage: "Read webpage", + web_search: "Search the web", + search_surfsense_docs: "Search knowledge base", + // Memory + update_memory: "Update memory", + // Calendar + search_calendar_events: "Search calendar", + create_calendar_event: "Create event", + update_calendar_event: "Update event", + delete_calendar_event: "Delete event", + // Gmail + search_gmail: "Search Gmail", + read_gmail_email: "Read email", + create_gmail_draft: "Draft email", + update_gmail_draft: "Update draft", + send_gmail_email: "Send email", + trash_gmail_email: "Move email to trash", + // Notion + create_notion_page: "Create Notion page", + update_notion_page: "Update Notion page", + 
delete_notion_page: "Delete Notion page", + // Confluence + create_confluence_page: "Create Confluence page", + update_confluence_page: "Update Confluence page", + delete_confluence_page: "Delete Confluence page", + // Linear + create_linear_issue: "Create Linear issue", + update_linear_issue: "Update Linear issue", + delete_linear_issue: "Delete Linear issue", + // Jira + create_jira_issue: "Create Jira issue", + update_jira_issue: "Update Jira issue", + delete_jira_issue: "Delete Jira issue", + // Drive-like file connectors + create_google_drive_file: "Create Google Drive file", + delete_google_drive_file: "Delete Google Drive file", + create_dropbox_file: "Create Dropbox file", + delete_dropbox_file: "Delete Dropbox file", + create_onedrive_file: "Create OneDrive file", + delete_onedrive_file: "Delete OneDrive file", + // Discord + list_discord_channels: "List Discord channels", + read_discord_messages: "Read Discord messages", + send_discord_message: "Send Discord message", + // Teams + list_teams_channels: "List Teams channels", + read_teams_messages: "Read Teams messages", + send_teams_message: "Send Teams message", + // Luma + list_luma_events: "List Luma events", + read_luma_event: "Read Luma event", + create_luma_event: "Create Luma event", + // Misc + get_connected_accounts: "Check connected accounts", + execute: "Run command", + execute_code: "Run code", +}; + +/** + * Format a tool's canonical (snake_case) name for display in the chat UI. + * + * Looks up :data:`TOOL_DISPLAY_NAMES` first; falls back to a + * snake_case-to-Title-Case rewrite for tools that don't have a curated + * label (e.g. dynamically registered MCP tools). 
+ */ +export function getToolDisplayName(name: string): string { + const friendly = TOOL_DISPLAY_NAMES[name]; + if (typeof friendly === "string" && friendly) return friendly; // only curated labels pass: inherited Object.prototype members ("constructor", "toString", …) are never strings + return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); +} + export const CONNECTOR_TOOL_ICON_PATHS: Record<string, { src: string; alt: string }> = { gmail: { src: "/connectors/google-gmail.svg", alt: "Gmail" }, google_calendar: { src: "/connectors/google-calendar.svg", alt: "Google Calendar" }, diff --git a/surfsense_web/lib/apis/agent-actions-api.service.ts b/surfsense_web/lib/apis/agent-actions-api.service.ts index 007bb131e..6634a11f7 100644 --- a/surfsense_web/lib/apis/agent-actions-api.service.ts +++ b/surfsense_web/lib/apis/agent-actions-api.service.ts @@ -15,6 +15,12 @@ const AgentActionReadSchema = z.object({ reverse_of: z.number().nullable(), reverted_by_action_id: z.number().nullable(), is_revert_action: z.boolean(), + // Correlation ids added in migration 135. The LangChain + // ``tool_call_id`` joins this row to the chat tool card via the + // ``data-action-log.lc_tool_call_id`` SSE event, and + // ``chat_turn_id`` keys the per-turn revert endpoint. + tool_call_id: z.string().nullable().optional(), + chat_turn_id: z.string().nullable().optional(), created_at: z.string(), }); @@ -38,6 +44,48 @@ const RevertResponseSchema = z.object({ export type RevertResponse = z.infer<typeof RevertResponseSchema>; +// Per-turn batch revert. The route never returns whole-batch 4xx; +// partial success is the common case and surfaced as +// ``status === "partial"`` with a per-action result list.
+const RevertTurnActionResultSchema = z.object({ + action_id: z.number(), + tool_name: z.string(), + status: z.enum([ + "reverted", + "already_reverted", + "not_reversible", + "permission_denied", + "failed", + "skipped", + ]), + message: z.string().nullable().optional(), + new_action_id: z.number().nullable().optional(), + error: z.string().nullable().optional(), +}); + +export type RevertTurnActionResult = z.infer<typeof RevertTurnActionResultSchema>; + +const RevertTurnResponseSchema = z.object({ + status: z.enum(["ok", "partial"]), + chat_turn_id: z.string(), + total: z.number(), + reverted: z.number(), + already_reverted: z.number(), + not_reversible: z.number(), + // ``permission_denied`` and ``skipped`` are first-class counters so + // ``total === reverted + already_reverted + + // not_reversible + permission_denied + failed + skipped`` always + // holds. ``.default(0)`` keeps the schema backwards-compatible + // with older deployments that haven't shipped the response model + // update yet. 
+ permission_denied: z.number().default(0), + failed: z.number(), + skipped: z.number().default(0), + results: z.array(RevertTurnActionResultSchema), +}); + +export type RevertTurnResponse = z.infer<typeof RevertTurnResponseSchema>; + class AgentActionsApiService { listForThread = async ( threadId: number, @@ -59,6 +107,14 @@ class AgentActionsApiService { { body: {} } ); }; + + revertTurn = async (threadId: number, chatTurnId: string): Promise<RevertTurnResponse> => { + return baseApiService.post( + `/api/v1/threads/${threadId}/revert-turn/${encodeURIComponent(chatTurnId)}`, + RevertTurnResponseSchema, + { body: {} } + ); + }; } export const agentActionsApiService = new AgentActionsApiService(); diff --git a/surfsense_web/lib/chat/message-utils.ts b/surfsense_web/lib/chat/message-utils.ts index 2d1a6976f..004542489 100644 --- a/surfsense_web/lib/chat/message-utils.ts +++ b/surfsense_web/lib/chat/message-utils.ts @@ -40,7 +40,7 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike { } const metadata = - msg.author_id || msg.token_usage + msg.author_id || msg.token_usage || msg.turn_id ? { custom: { ...(msg.author_id && { @@ -50,6 +50,10 @@ export function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike { }, }), ...(msg.token_usage && { usage: msg.token_usage }), + // Surface ``chat_turn_id`` so the assistant message + // footer can scope its "Revert turn" button to just + // this turn's actions. Null on legacy rows. 
+ ...(msg.turn_id && { chatTurnId: msg.turn_id }), }, } : undefined; diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 9f2ac87a5..9dad198e3 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -9,21 +9,42 @@ export interface ThinkingStepData { export type ContentPart = | { type: "text"; text: string } + | { type: "reasoning"; text: string } | { type: "tool-call"; toolCallId: string; toolName: string; args: Record<string, unknown>; result?: unknown; + /** + * Authoritative LangChain ``tool_call.id`` propagated by the backend + * via ``langchainToolCallId`` on tool-input-start/available and + * tool-output-available events. Used to join a card to the + * matching ``AgentActionLog`` row exposed by + * ``GET /threads/{id}/actions`` and the streamed + * ``data-action-log`` events. + */ + langchainToolCallId?: string; } | { type: "data-thinking-steps"; data: { steps: ThinkingStepData[] }; + } + | { + /** + * Between-step separator. Pushed by `addStepSeparator` when + * a `start-step` SSE event arrives AFTER the message already + * has non-step content. Rendered by `StepSeparatorDataUI` + * (see assistant-ui/step-separator.tsx). 
+ */ + type: "data-step-separator"; + data: { stepIndex: number }; }; export interface ContentPartsState { contentParts: ContentPart[]; currentTextPartIndex: number; + currentReasoningPartIndex: number; toolCallIndices: Map<string, number>; } @@ -74,6 +95,9 @@ export function updateThinkingSteps( if (state.currentTextPartIndex >= 0) { state.currentTextPartIndex += 1; } + if (state.currentReasoningPartIndex >= 0) { + state.currentReasoningPartIndex += 1; + } for (const [id, idx] of state.toolCallIndices) { state.toolCallIndices.set(id, idx + 1); } @@ -131,6 +155,12 @@ export class FrameBatchedUpdater { } export function appendText(state: ContentPartsState, delta: string): void { + // First text delta after a reasoning block: close the reasoning so + // the assistant-ui renderer treats them as separate parts (the + // reasoning block collapses; the answer streams below). + if (state.currentReasoningPartIndex >= 0) { + state.currentReasoningPartIndex = -1; + } if ( state.currentTextPartIndex >= 0 && state.contentParts[state.currentTextPartIndex]?.type === "text" @@ -143,36 +173,129 @@ export function appendText(state: ContentPartsState, delta: string): void { } } +export function appendReasoning(state: ContentPartsState, delta: string): void { + // Symmetric to appendText: open a fresh reasoning block on first + // delta, then accumulate into it. ``endReasoning`` simply closes + // the active block; subsequent reasoning deltas would open a new + // one (matching ``text-start/end`` semantics on the wire). 
+ if (state.currentTextPartIndex >= 0) { + state.currentTextPartIndex = -1; + } + if ( + state.currentReasoningPartIndex >= 0 && + state.contentParts[state.currentReasoningPartIndex]?.type === "reasoning" + ) { + ( + state.contentParts[state.currentReasoningPartIndex] as { + type: "reasoning"; + text: string; + } + ).text += delta; + } else { + state.contentParts.push({ type: "reasoning", text: delta }); + state.currentReasoningPartIndex = state.contentParts.length - 1; + } +} + +export function endReasoning(state: ContentPartsState): void { + state.currentReasoningPartIndex = -1; +} + +export function addStepSeparator(state: ContentPartsState): void { + // Push a divider between consecutive model steps within a single + // assistant turn. We only emit it when the message already has + // non-step content (so the FIRST step of a turn doesn't + // generate a leading separator) and when the previous part isn't + // itself a separator (defensive against duplicate `start-step` + // events). + const hasContent = state.contentParts.some( + (p) => p.type === "text" || p.type === "reasoning" || p.type === "tool-call" + ); + if (!hasContent) return; + const last = state.contentParts[state.contentParts.length - 1]; + if (last && last.type === "data-step-separator") return; + + const stepIndex = state.contentParts.filter((p) => p.type === "data-step-separator").length; + state.contentParts.push({ type: "data-step-separator", data: { stepIndex } }); + state.currentTextPartIndex = -1; + state.currentReasoningPartIndex = -1; +} + +/** + * Allowlist of tool names that should produce a UI tool card. The + * sentinel ``"all"`` matches every tool — we dropped the legacy + * ``BASE_TOOLS_WITH_UI`` gate so that ALL tool calls render via the + * generic ``ToolFallback``. The backend's ``format_thinking_step`` + * summarisation and the defensive ``result_length``-only default for + * unknown tools keep persisted message JSON from ballooning. 
+ */ +export type ToolUIGate = Set<string> | "all"; + +function _toolPasses(gate: ToolUIGate, toolName: string): boolean { + return gate === "all" || gate.has(toolName); +} + export function addToolCall( state: ContentPartsState, - toolsWithUI: Set<string>, + toolsWithUI: ToolUIGate, toolCallId: string, toolName: string, args: Record<string, unknown>, - force = false + force = false, + langchainToolCallId?: string ): void { - if (force || toolsWithUI.has(toolName)) { + if (force || _toolPasses(toolsWithUI, toolName)) { state.contentParts.push({ type: "tool-call", toolCallId, toolName, args, + ...(langchainToolCallId ? { langchainToolCallId } : {}), }); state.toolCallIndices.set(toolCallId, state.contentParts.length - 1); state.currentTextPartIndex = -1; + state.currentReasoningPartIndex = -1; } } +/** + * Reverse-lookup helper used by the SSE ``data-action-log`` handler: + * given the LangChain ``tool_call.id`` (set on the content part as + * ``langchainToolCallId``), return the synthetic ``toolCallId`` that + * the chat tool card uses (``call_<run-id>``). Returns ``null`` when no + * matching tool card has been seen yet — the action is still recorded + * in the LC-id-keyed atom so the card can pick it up when it eventually + * arrives. 
+ */ +export function findToolCallIdByLcId( + state: ContentPartsState, + lcToolCallId: string +): string | null { + for (const part of state.contentParts) { + if (part.type === "tool-call" && part.langchainToolCallId === lcToolCallId) { + return part.toolCallId; + } + } + return null; +} + export function updateToolCall( state: ContentPartsState, toolCallId: string, - update: { args?: Record<string, unknown>; result?: unknown } + update: { args?: Record<string, unknown>; result?: unknown; langchainToolCallId?: string } ): void { const index = state.toolCallIndices.get(toolCallId); if (index !== undefined && state.contentParts[index]?.type === "tool-call") { const tc = state.contentParts[index] as ContentPart & { type: "tool-call" }; if (update.args) tc.args = update.args; if (update.result !== undefined) tc.result = update.result; + // Only backfill langchainToolCallId when it is not already set: an + // id captured earlier is kept as-is, so a later event can fill a + // missing id but never replaces one, and an absent (undefined or + // empty) late-arriving value is ignored. 
+ if (update.langchainToolCallId && !tc.langchainToolCallId) { + tc.langchainToolCallId = update.langchainToolCallId; + } } } @@ -184,13 +307,15 @@ function _hasInterruptResult(part: ContentPart): boolean { export function buildContentForUI( state: ContentPartsState, - toolsWithUI: Set<string> + toolsWithUI: ToolUIGate ): ThreadMessageLike["content"] { const filtered = state.contentParts.filter((part) => { if (part.type === "text") return part.text.length > 0; + if (part.type === "reasoning") return part.text.length > 0; if (part.type === "tool-call") - return toolsWithUI.has(part.toolName) || _hasInterruptResult(part); + return _toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part); if (part.type === "data-thinking-steps") return true; + if (part.type === "data-step-separator") return true; return false; }); return filtered.length > 0 @@ -200,20 +325,28 @@ export function buildContentForUI( export function buildContentForPersistence( state: ContentPartsState, - toolsWithUI: Set<string> + toolsWithUI: ToolUIGate ): unknown[] { const parts: unknown[] = []; for (const part of state.contentParts) { if (part.type === "text" && part.text.length > 0) { parts.push(part); + } else if (part.type === "reasoning" && part.text.length > 0) { + // Persist reasoning blocks so a chat reload re-renders the + // collapsed thinking section instead of + // silently dropping it (mirrors the data-thinking-steps + // branch below). 
+ parts.push(part); } else if ( part.type === "tool-call" && - (toolsWithUI.has(part.toolName) || _hasInterruptResult(part)) + (_toolPasses(toolsWithUI, part.toolName) || _hasInterruptResult(part)) ) { parts.push(part); } else if (part.type === "data-thinking-steps") { parts.push(part); + } else if (part.type === "data-step-separator") { + parts.push(part); } } @@ -221,23 +354,122 @@ export function buildContentForPersistence( } export type SSEEvent = - | { type: "text-delta"; delta: string } - | { type: "tool-input-start"; toolCallId: string; toolName: string } + | { type: "start"; messageId?: string } + | { type: "finish" } + | { type: "start-step" } + | { type: "finish-step" } + | { type: "text-start"; id: string } + | { type: "text-delta"; id?: string; delta: string } + | { type: "text-end"; id: string } + | { type: "reasoning-start"; id: string } + | { type: "reasoning-delta"; id?: string; delta: string } + | { type: "reasoning-end"; id: string } + | { + type: "tool-input-start"; + toolCallId: string; + toolName: string; + /** Authoritative LangChain ``tool_call.id``. Optional. */ + langchainToolCallId?: string; + } | { type: "tool-input-available"; toolCallId: string; toolName: string; input: Record<string, unknown>; + langchainToolCallId?: string; } | { type: "tool-output-available"; toolCallId: string; output: Record<string, unknown>; + /** Authoritative LangChain ``tool_call.id`` extracted from + * ``ToolMessage.tool_call_id`` at on_tool_end. Backfills cards + * that didn't get the id at tool-input-start time. */ + langchainToolCallId?: string; } | { type: "data-thinking-step"; data: ThinkingStepData } | { type: "data-thread-title-update"; data: { threadId: number; title: string } } | { type: "data-interrupt-request"; data: Record<string, unknown> } | { type: "data-documents-updated"; data: Record<string, unknown> } + | { + /** + * A freshly committed AgentActionLog row. 
Frontend stores + * this in a Map keyed off ``lc_tool_call_id`` so the chat + * tool card can light up its Revert button. + */ + type: "data-action-log"; + data: { + id: number; + lc_tool_call_id: string | null; + chat_turn_id: string | null; + tool_name: string; + reversible: boolean; + reverse_descriptor_present: boolean; + created_at: string | null; + error: boolean; + }; + } + | { + /** + * Reversibility flipped (filesystem op SAVEPOINT committed; + * cf. ``kb_persistence._dispatch_reversibility_update``). + */ + type: "data-action-log-updated"; + data: { id: number; reversible: boolean }; + } + | { + /** + * Emitted at the start of every stream so the frontend can + * stamp the per-turn correlation id onto the in-flight + * assistant message and replay it via + * ``appendMessage``. Pure-text turns never produce + * action-log events; this event guarantees the frontend + * always learns the turn id. + */ + type: "data-turn-info"; + data: { chat_turn_id: string }; + } + | { + /** + * Best-effort revert pass that ran BEFORE this regeneration. + * Per-action results are forwarded to the UI so the user + * can see which downstream actions were rolled + * back vs which couldn't be undone. + */ + type: "data-revert-results"; + data: { + status: "ok" | "partial"; + chat_turn_ids: string[]; + total: number; + reverted: number; + already_reverted: number; + not_reversible: number; + /** + * ``permission_denied`` and ``skipped`` are first-class + * counters so the response invariant + * ``total === sum(counters)`` always holds. Optional + * for forward compatibility with older backends; the + * frontend treats missing values as ``0``. 
+ */ + permission_denied?: number; + failed: number; + skipped?: number; + results: Array<{ + action_id: number; + tool_name: string; + status: + | "reverted" + | "already_reverted" + | "not_reversible" + | "permission_denied" + | "failed" + | "skipped"; + message?: string | null; + new_action_id?: number | null; + error?: string | null; + }>; + }; + } | { type: "data-token-usage"; data: { diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index b5c5899b4..fc970c26e 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -46,6 +46,11 @@ export interface MessageRecord { author_display_name?: string | null; author_avatar_url?: string | null; token_usage?: TokenUsageSummary | null; + // Per-turn correlation id from ``configurable.turn_id`` at streaming + // time (added in migration 136). Used by the per-turn revert + // endpoint and edit-from-arbitrary-position. Nullable on legacy + // rows that predate the column. + turn_id?: string | null; } export interface ThreadListResponse { @@ -123,10 +128,20 @@ export async function getThreadMessages(threadId: number): Promise<ThreadHistory /** * Append a message to a thread. + * + * ``turn_id`` is the per-turn correlation id streamed by the backend + * via ``data-turn-info``. Persisting it lets later edits locate the + * matching LangGraph checkpoint without HumanMessage scanning. Older + * callers can still omit it for back-compat. */ export async function appendMessage( threadId: number, - message: { role: "user" | "assistant" | "system"; content: unknown; token_usage?: unknown } + message: { + role: "user" | "assistant" | "system"; + content: unknown; + token_usage?: unknown; + turn_id?: string | null; + } ): Promise<MessageRecord> { return baseApiService.post<MessageRecord>(`/api/v1/threads/${threadId}/messages`, undefined, { body: message,