"""Revert service for the SurfSense agent action log.

Implements the actual revert workflow used by
``POST /api/threads/{thread_id}/revert/{action_id}``. The route handler is a
thin auth + flag wrapper around the functions defined here.

Operation outcomes mirror the plan:

* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
  :class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
  written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
  row from the snapshot; a ``write_file`` / ``create_note`` that created a
  document (``content_before IS NULL``) and ``mkdir`` DELETE the row that was
  created; everything else is an in-place restore.
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
  the inverse tool through the agent's normal permission stack (NOT
  bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
* **Anything else** (deprecated tool / no descriptor / schema drift): returns
  ``NOT_REVERSIBLE`` and the route surfaces it as 409.

A successful revert appends a NEW row to ``agent_action_log`` with
``reverse_of=`` and the requesting user's ``user_id``, preserving an
auditable chain.

Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
``"rmdir".startswith("rm")`` would otherwise mis-route directory revert to
the document branch (and ``delete_note`` vs ``delete_folder`` is the same
trap waiting to happen).
"""

from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Literal

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.path_resolver import (
    DOCUMENTS_ROOT,
    safe_filename,
    safe_folder_segment,
)
from app.db import (
    AgentActionLog,
    Chunk,
    Document,
    DocumentRevision,
    DocumentType,
    Folder,
    FolderRevision,
    NewChatThread,
)
from app.utils.document_converters import (
    embed_texts,
    generate_content_hash,
    generate_unique_identifier_hash,
)

logger = logging.getLogger(__name__)

RevertOutcomeStatus = Literal[
    "ok",
    "not_reversible",
    "not_found",
    "permission_denied",
    "tool_unavailable",
    "reverse_not_implemented",
]


@dataclass
class RevertOutcome:
    """Structured result of :func:`revert_action`."""

    status: RevertOutcomeStatus
    message: str
    # Id of the audit row appended on success; ``None`` otherwise.
    new_action_id: int | None = None


# ---------------------------------------------------------------------------
# Lookup helpers
# ---------------------------------------------------------------------------


async def load_action(
    session: AsyncSession,
    *,
    action_id: int,
    thread_id: int,
) -> AgentActionLog | None:
    """Load the action_log row for ``action_id`` if it belongs to the thread.

    The query is scoped by ``thread_id``, so an id belonging to another
    thread yields ``None``.
    """
    stmt = select(AgentActionLog).where(
        AgentActionLog.id == action_id,
        AgentActionLog.thread_id == thread_id,
    )
    result = await session.execute(stmt)
    return result.scalars().first()


async def load_thread(session: AsyncSession, *, thread_id: int) -> NewChatThread | None:
    """Load the :class:`NewChatThread` with id ``thread_id``, or ``None``."""
    stmt = select(NewChatThread).where(NewChatThread.id == thread_id)
    result = await session.execute(stmt)
    return result.scalars().first()


# ---------------------------------------------------------------------------
# Authorization
# ---------------------------------------------------------------------------


def can_revert(
    *,
    requester_user_id: str | None,
    action: AgentActionLog,
    is_admin: bool,
) -> bool:
    """Return True iff the requester is allowed to revert this action.

    The plan's rule: "requester must be the original `user_id` on the
    action, or hold the search-space admin role." Anonymous actions
    (``action.user_id is None``) can only be reverted by admins, and an
    anonymous requester can never match an owner.
    """
    if is_admin:
        return True
    if action.user_id is None or requester_user_id is None:
        return False
    # Compare as strings so UUID objects and their string form match.
    return str(action.user_id) == str(requester_user_id)


# ---------------------------------------------------------------------------
# Helper: reconstruct virtual path from a snapshot
# ---------------------------------------------------------------------------


async def _virtual_path_from_snapshot(
    session: AsyncSession,
    revision: DocumentRevision,
) -> str | None:
    """Reconstruct the virtual_path the document was at before mutation.

    Preference order:

    1. ``metadata_before["virtual_path"]`` when it is a string rooted at
       ``DOCUMENTS_ROOT`` (written by the snapshot helpers).
    2. Compose ``"<DOCUMENTS_ROOT>/<folder>/.../<title>"`` from
       ``folder_id_before`` + ``title_before``, walking the folder chain via
       ``parent_id`` (``folder_id_before is None`` means the root).

    Returns ``None`` when no path can be built: the snapshot has no usable
    title, a folder in the chain no longer exists, or the chain is cyclic.
    """
    metadata = revision.metadata_before or {}
    candidate = metadata.get("virtual_path") if isinstance(metadata, dict) else None
    if isinstance(candidate, str) and candidate.startswith(DOCUMENTS_ROOT):
        return candidate

    title = revision.title_before
    if not isinstance(title, str) or not title:
        return None

    parts: list[str] = []
    cursor: int | None = revision.folder_id_before
    visited: set[int] = set()
    while cursor is not None:
        if cursor in visited:
            # Corrupt parent chain: a truncated path would point at the wrong
            # location, so refuse rather than guess.
            return None
        visited.add(cursor)
        folder = await session.get(Folder, cursor)
        if folder is None:
            return None
        parts.append(safe_folder_segment(str(folder.name or "")))
        cursor = folder.parent_id

    parts.reverse()
    base = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT
    return f"{base}/{safe_filename(title)}"


# ---------------------------------------------------------------------------
# Document revision restore (write/edit/move/rm)
# ---------------------------------------------------------------------------


def _set_field(target: Any, field: str, value: Any) -> None:
    """Assign ``value`` to ``target.<field>`` unless ``value`` is ``None``.

    A ``None`` snapshot column is treated as "not captured", leaving the live
    value untouched.
    """
    # NOTE(review): for nullable structural columns (folder_id / parent_id)
    # ``None`` may also mean "root"; such values are not restored here.
    if value is not None:
        setattr(target, field, value)


def _snapshot_chunk_texts(chunks_before: Any) -> list[str]:
    """Return the string ``content`` of every dict entry in a chunk snapshot.

    Non-list snapshots and malformed entries are ignored.
    """
    if not isinstance(chunks_before, list):
        return []
    return [
        c["content"]
        for c in chunks_before
        if isinstance(c, dict) and isinstance(c.get("content"), str)
    ]


async def _add_chunks(
    session: AsyncSession,
    *,
    document_id: int,
    chunk_texts: list[str],
) -> None:
    """Embed ``chunk_texts`` in a worker thread and stage new Chunk rows."""
    if not chunk_texts:
        return
    # embed_texts is synchronous; keep it off the event loop.
    embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
    session.add_all(
        [
            Chunk(document_id=document_id, content=text, embedding=embedding)
            for text, embedding in zip(chunk_texts, embeddings, strict=True)
        ]
    )


async def _path_is_taken(
    session: AsyncSession,
    *,
    search_space_id: int,
    unique_identifier_hash: str,
    exclude_document_id: int | None = None,
) -> bool:
    """Return True if a document (other than ``exclude_document_id``) owns the hash."""
    stmt = select(Document.id).where(
        Document.search_space_id == search_space_id,
        Document.unique_identifier_hash == unique_identifier_hash,
    )
    if exclude_document_id is not None:
        stmt = stmt.where(Document.id != exclude_document_id)
    result = await session.execute(stmt.limit(1))
    return result.scalars().first() is not None


def _collision_outcome(virtual_path: str) -> RevertOutcome:
    """Outcome returned when the restored path is occupied by a live document."""
    return RevertOutcome(
        status="tool_unavailable",
        message=(
            f"A document already exists at '{virtual_path}'; revert would "
            "collide. Move the live doc out of the way first."
        ),
    )


async def _restore_in_place_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Apply an in-place restore to an existing :class:`Document`.

    Returns ``tool_unavailable`` when the live row is gone, or when the
    restored path is already owned by a different document. The collision is
    checked before ``doc`` is mutated, so no conflicting hash is ever flushed.
    """
    if revision.document_id is None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Original document was hard-deleted; in-place restore is not possible."
            ),
        )
    doc = await session.get(Document, revision.document_id)
    if doc is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original document has been deleted; revert cannot proceed.",
        )

    # Resolve the restored identity first: the collision SELECT autoflushes
    # pending changes, so it must run before any field of ``doc`` changes.
    restored_hash: str | None = None
    virtual_path = await _virtual_path_from_snapshot(session, revision)
    if virtual_path:
        # NOTE(review): always hashed as DocumentType.NOTE regardless of
        # doc.document_type — confirm this matches the filesystem tools.
        restored_hash = generate_unique_identifier_hash(
            DocumentType.NOTE,
            virtual_path,
            doc.search_space_id,
        )
        if await _path_is_taken(
            session,
            search_space_id=doc.search_space_id,
            unique_identifier_hash=restored_hash,
            exclude_document_id=doc.id,
        ):
            return _collision_outcome(virtual_path)

    _set_field(doc, "content", revision.content_before)
    _set_field(doc, "source_markdown", revision.content_before)
    _set_field(doc, "title", revision.title_before)
    _set_field(doc, "folder_id", revision.folder_id_before)

    metadata_before = revision.metadata_before or {}
    if isinstance(metadata_before, dict) and metadata_before:
        doc.document_metadata = dict(metadata_before)

    if isinstance(revision.content_before, str):
        doc.content_hash = generate_content_hash(
            revision.content_before, doc.search_space_id
        )

    if restored_hash is not None:
        doc.unique_identifier_hash = restored_hash

    # A list snapshot (even an empty one) replaces the live chunk set.
    if isinstance(revision.chunks_before, list):
        await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
        await _add_chunks(
            session,
            document_id=doc.id,
            chunk_texts=_snapshot_chunk_texts(revision.chunks_before),
        )

    if isinstance(revision.content_before, str):
        doc.embedding = (
            await asyncio.to_thread(embed_texts, [revision.content_before])
        )[0]

    doc.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Document restored from snapshot.")


async def _reinsert_document_from_revision(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Re-INSERT a deleted :class:`Document` from a snapshot row (``rm`` revert).

    Refuses (``not_reversible``) when the snapshot lacks a title, content or a
    resolvable path, and (``tool_unavailable``) when the path is occupied.
    """
    if not isinstance(revision.title_before, str) or not revision.title_before:
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks title_before; cannot recreate document.",
        )
    if not isinstance(revision.content_before, str):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks content_before; cannot recreate document.",
        )

    virtual_path = await _virtual_path_from_snapshot(session, revision)
    if not virtual_path:
        return RevertOutcome(
            status="not_reversible",
            message=(
                "Snapshot is missing both metadata_before['virtual_path'] AND "
                "a resolvable (folder_id_before, title_before) pair."
            ),
        )

    search_space_id = revision.search_space_id
    unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    if await _path_is_taken(
        session,
        search_space_id=search_space_id,
        unique_identifier_hash=unique_identifier_hash,
    ):
        return _collision_outcome(virtual_path)

    metadata = revision.metadata_before or {}
    metadata = dict(metadata) if isinstance(metadata, dict) else {}
    metadata["virtual_path"] = virtual_path

    content = revision.content_before
    new_doc = Document(
        title=revision.title_before,
        document_type=DocumentType.NOTE,
        document_metadata=metadata,
        content=content,
        content_hash=generate_content_hash(content, search_space_id),
        unique_identifier_hash=unique_identifier_hash,
        source_markdown=content,
        search_space_id=search_space_id,
        folder_id=revision.folder_id_before,
        updated_at=datetime.now(UTC),
    )
    session.add(new_doc)
    # Flush to obtain new_doc.id for the chunk foreign keys.
    await session.flush()

    new_doc.embedding = (await asyncio.to_thread(embed_texts, [content]))[0]
    await _add_chunks(
        session,
        document_id=new_doc.id,
        chunk_texts=_snapshot_chunk_texts(revision.chunks_before),
    )

    # Repoint the snapshot at the recreated row so a follow-up revert of
    # the same row works as expected.
    revision.document_id = new_doc.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted document '{revision.title_before}' from snapshot.",
    )


async def _delete_created_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Delete the document a create action produced (``content_before IS NULL``)."""
    if revision.document_id is None:
        return RevertOutcome(
            status="ok",
            message="No live row to delete (already removed elsewhere).",
        )
    await session.execute(delete(Document).where(Document.id == revision.document_id))
    return RevertOutcome(
        status="ok",
        message="Deleted the document that was created by this action.",
    )


# Tools that can create a brand-new document. When their snapshot has
# ``content_before IS NULL`` nothing existed before, so revert deletes the
# created row instead of "restoring" an empty snapshot in place.
_DOC_CREATE_TOOLS: frozenset[str] = frozenset({"write_file", "create_note"})


async def _restore_document_revision(
    session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
    """Dispatch document-level revert based on ``action.tool_name`` (exact match)."""
    stmt = (
        select(DocumentRevision)
        .where(DocumentRevision.agent_action_id == action.id)
        .order_by(DocumentRevision.created_at.desc())
        .limit(1)
    )
    result = await session.execute(stmt)
    revision = result.scalars().first()
    if revision is None:
        return RevertOutcome(
            status="not_reversible",
            message="No document_revisions row tied to this action.",
        )

    tool_name = (action.tool_name or "").lower()
    if tool_name == "rm":
        return await _reinsert_document_from_revision(session, revision=revision)
    if tool_name in _DOC_CREATE_TOOLS and revision.content_before is None:
        return await _delete_created_document(session, revision=revision)
    return await _restore_in_place_document(session, revision=revision)


# ---------------------------------------------------------------------------
# Folder revision restore (mkdir/rmdir/rename/move)
# ---------------------------------------------------------------------------


async def _restore_in_place_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Restore name / parent / position of an existing :class:`Folder`."""
    if revision.folder_id is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder was hard-deleted; in-place restore is impossible.",
        )
    folder = await session.get(Folder, revision.folder_id)
    if folder is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder has been deleted; revert cannot proceed.",
        )
    _set_field(folder, "name", revision.name_before)
    _set_field(folder, "parent_id", revision.parent_id_before)
    _set_field(folder, "position", revision.position_before)
    folder.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Folder restored from snapshot.")


async def _reinsert_folder_from_revision(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Re-INSERT a removed :class:`Folder` from its snapshot (``rmdir`` revert)."""
    if not isinstance(revision.name_before, str) or not revision.name_before:
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks name_before; cannot recreate folder.",
        )
    new_folder = Folder(
        name=revision.name_before,
        parent_id=revision.parent_id_before,
        position=revision.position_before,
        search_space_id=revision.search_space_id,
        updated_at=datetime.now(UTC),
    )
    session.add(new_folder)
    await session.flush()
    # Repoint the snapshot at the recreated folder.
    revision.folder_id = new_folder.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted folder '{revision.name_before}' from snapshot.",
    )


async def _delete_created_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Delete the folder ``mkdir`` created, but only while it is still empty."""
    if revision.folder_id is None:
        return RevertOutcome(
            status="ok",
            message="No live folder row to delete (already removed elsewhere).",
        )
    folder_id = revision.folder_id

    has_doc = await session.execute(
        select(Document.id).where(Document.folder_id == folder_id).limit(1)
    )
    if has_doc.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Folder is no longer empty (documents have been added since "
                "mkdir); cannot revert."
            ),
        )

    has_child = await session.execute(
        select(Folder.id).where(Folder.parent_id == folder_id).limit(1)
    )
    if has_child.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Folder is no longer empty (sub-folders have been added "
                "since mkdir); cannot revert."
            ),
        )

    await session.execute(delete(Folder).where(Folder.id == folder_id))
    return RevertOutcome(
        status="ok",
        message="Deleted the folder that was created by this action.",
    )


async def _restore_folder_revision(
    session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
    """Dispatch folder-level revert based on ``action.tool_name`` (exact match)."""
    stmt = (
        select(FolderRevision)
        .where(FolderRevision.agent_action_id == action.id)
        .order_by(FolderRevision.created_at.desc())
        .limit(1)
    )
    result = await session.execute(stmt)
    revision = result.scalars().first()
    if revision is None:
        return RevertOutcome(
            status="not_reversible",
            message="No folder_revisions row tied to this action.",
        )

    tool_name = (action.tool_name or "").lower()
    if tool_name == "rmdir":
        return await _reinsert_folder_from_revision(session, revision=revision)
    if tool_name == "mkdir":
        return await _delete_created_folder(session, revision=revision)
    return await _restore_in_place_folder(session, revision=revision)


# ---------------------------------------------------------------------------
# Dispatch
# ---------------------------------------------------------------------------
#
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
# ``delete_note``/``delete_folder``.
# Tool names whose revert is backed by a ``document_revisions`` snapshot.
_DOC_TOOLS: frozenset[str] = frozenset(
    {
        "create_note",
        "delete_note",
        "edit_file",
        "move_file",
        "rm",
        "update_memory",
        "update_note",
        "write_file",
    }
)
# Tool names whose revert is backed by a ``folder_revisions`` snapshot.
_FOLDER_TOOLS: frozenset[str] = frozenset(
    {
        "delete_folder",
        "mkdir",
        "rename_folder",
        "rmdir",
    }
)


async def revert_action(
    session: AsyncSession,
    *,
    action: AgentActionLog,
    requester_user_id: str | None,
) -> RevertOutcome:
    """Revert ``action`` and describe what happened.

    KB-owned tools are restored from their revision snapshots; connector
    tools with a ``reverse_descriptor`` report ``reverse_not_implemented``;
    everything else is ``not_reversible``.

    Nothing is committed here: the caller commits on success or rolls back
    on failure. On success an audit row with ``reverse_of=action.id`` is
    added and flushed, and its id is stored on ``outcome.new_action_id``.
    """
    normalized_name = (action.tool_name or "").lower()
    handles_document = normalized_name in _DOC_TOOLS
    handles_folder = normalized_name in _FOLDER_TOOLS

    if not handles_document and not handles_folder:
        if not action.reverse_descriptor:
            return RevertOutcome(
                status="not_reversible",
                message=(
                    f"Tool {action.tool_name!r} is not reversible: no document "
                    "revision and no reverse_descriptor."
                ),
            )
        # Connector-owned reversibles will run through the normal permission
        # stack; until then the route reports this as not implemented.
        return RevertOutcome(
            status="reverse_not_implemented",
            message=(
                "Connector-action revert is not yet implemented. The "
                "reverse_descriptor is stored; future work will replay it "
                "through PermissionMiddleware."
            ),
        )

    restorer = (
        _restore_document_revision if handles_document else _restore_folder_revision
    )
    outcome = await restorer(session, action=action)
    if outcome.status != "ok":
        return outcome

    audit_row = AgentActionLog(
        thread_id=action.thread_id,
        user_id=requester_user_id,
        search_space_id=action.search_space_id,
        turn_id=None,
        message_id=None,
        tool_name=f"_revert:{action.tool_name}",
        args={"reverted_action_id": action.id},
        result_id=None,
        reversible=False,
        reverse_descriptor=None,
        error=None,
        reverse_of=action.id,
    )
    session.add(audit_row)
    # Flush so the audit row has a primary key to report back.
    await session.flush()
    outcome.new_action_id = audit_row.id
    return outcome


__all__ = [
    "RevertOutcome",
    "can_revert",
    "load_action",
    "load_thread",
    "revert_action",
]