SurfSense/surfsense_backend/app/services/revert_service.py
CREDO23 a8de98895a perf(revert-service): offload sync embed_texts to thread
_restore_in_place_document and _reinsert_document_from_revision are
async helpers invoked by the synchronous-feeling POST /api/threads/.../revert
route; both ran embed_texts inline, blocking the event loop while the
HTTP client waited.
2026-05-20 10:04:26 +02:00

622 lines
20 KiB
Python

"""Revert service for the SurfSense agent action log.
Implements the actual revert workflow used by
``POST /api/threads/{thread_id}/revert/{action_id}``. The route handler is a
thin auth + flag wrapper around the functions defined here.
Operation outcomes mirror the plan:
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row
that was created; everything else is an in-place restore.
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
the inverse tool through the agent's normal permission stack (NOT
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
* **Anything else** (deprecated tool / no descriptor / schema drift):
returns ``NOT_REVERSIBLE`` and the route surfaces it as 409.
A successful revert appends a NEW row to ``agent_action_log`` with
``reverse_of=<original_action_id>`` and the requesting user's
``user_id``, preserving an auditable chain.
Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
``"rmdir".startswith("rm")`` would otherwise mis-route directory revert
to the document branch (and ``delete_note`` vs ``delete_folder`` is the
same trap waiting to happen).
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Literal
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
safe_filename,
safe_folder_segment,
)
from app.db import (
AgentActionLog,
Chunk,
Document,
DocumentRevision,
DocumentType,
Folder,
FolderRevision,
NewChatThread,
)
from app.utils.document_converters import (
embed_texts,
generate_content_hash,
generate_unique_identifier_hash,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Closed set of status strings a revert can report; this is the declared type
# of ``RevertOutcome.status``. The functions in this module produce "ok",
# "not_reversible", "tool_unavailable" and "reverse_not_implemented";
# "not_found" and "permission_denied" are presumably set by the route
# handler — TODO confirm against the caller.
RevertOutcomeStatus = Literal[
    "ok",
    "not_reversible",
    "not_found",
    "permission_denied",
    "tool_unavailable",
    "reverse_not_implemented",
]
@dataclass
class RevertOutcome:
    """Result record handed back by :func:`revert_action` and its helpers.

    ``status`` is one of the :data:`RevertOutcomeStatus` strings, ``message``
    is a human-readable explanation, and ``new_action_id`` is filled in by
    :func:`revert_action` with the id of the audit row it appends after a
    successful revert (it stays ``None`` otherwise).
    """

    # Machine-readable outcome category.
    status: RevertOutcomeStatus
    # Human-readable explanation of the outcome.
    message: str
    # Id of the audit row written for a successful revert, if any.
    new_action_id: int | None = None
# ---------------------------------------------------------------------------
# Lookup helpers
# ---------------------------------------------------------------------------
async def load_action(
    session: AsyncSession,
    *,
    action_id: int,
    thread_id: int,
) -> AgentActionLog | None:
    """Fetch one action-log row, scoped to a thread.

    Returns the :class:`AgentActionLog` whose id is ``action_id`` *and* whose
    ``thread_id`` matches ``thread_id``; returns ``None`` when no such row
    exists (including when the action belongs to a different thread).
    """
    # Both conditions are ANDed by ``where``; the thread filter is what stops
    # an id from one thread being used through another thread's URL.
    scoped_to_thread = (
        AgentActionLog.id == action_id,
        AgentActionLog.thread_id == thread_id,
    )
    rows = await session.execute(select(AgentActionLog).where(*scoped_to_thread))
    return rows.scalars().first()
async def load_thread(session: AsyncSession, *, thread_id: int) -> NewChatThread | None:
    """Fetch the :class:`NewChatThread` with id ``thread_id``, or ``None``."""
    rows = await session.execute(
        select(NewChatThread).where(NewChatThread.id == thread_id)
    )
    return rows.scalars().first()
# ---------------------------------------------------------------------------
# Authorization
# ---------------------------------------------------------------------------
def can_revert(
    *,
    requester_user_id: str | None,
    action: AgentActionLog,
    is_admin: bool,
) -> bool:
    """Decide whether the requester may revert ``action``.

    Rule: admins may revert anything; otherwise the requester must be the
    user recorded on the action. Actions with no recorded user
    (``action.user_id is None``) are therefore admin-only.

    Ids are compared by their ``str()`` form, so a UUID object and its string
    representation are treated as the same user.
    """
    if is_admin:
        return True
    owner_id = action.user_id
    return owner_id is not None and str(owner_id) == str(requester_user_id)
# ---------------------------------------------------------------------------
# Helper: reconstruct virtual path from a snapshot
# ---------------------------------------------------------------------------
async def _virtual_path_from_snapshot(
    session: AsyncSession,
    revision: DocumentRevision,
) -> str | None:
    """Work out the virtual path a document had before it was mutated.

    Two sources are tried, in order:

    1. ``revision.metadata_before["virtual_path"]``, accepted only when it is
       a string starting with ``DOCUMENTS_ROOT``.
    2. A path composed from the folder chain starting at
       ``revision.folder_id_before`` (followed upwards via ``parent_id``)
       plus ``revision.title_before``.

    Returns ``None`` when the title is missing/empty or when a folder in the
    chain no longer exists. If the parent chain loops back on itself the walk
    stops at the first repeated id and the path is built from the folders
    seen up to that point.
    """
    snapshot_meta = revision.metadata_before or {}
    if isinstance(snapshot_meta, dict):
        recorded = snapshot_meta.get("virtual_path")
        if isinstance(recorded, str) and recorded.startswith(DOCUMENTS_ROOT):
            return recorded

    title = revision.title_before
    if not (isinstance(title, str) and title):
        return None

    # Collected leaf-first while walking up; reversed when the path is built.
    segments: list[str] = []
    seen: set[int] = set()
    folder_id: int | None = revision.folder_id_before
    while folder_id is not None and folder_id not in seen:
        seen.add(folder_id)
        folder = await session.get(Folder, folder_id)
        if folder is None:
            return None
        segments.append(safe_folder_segment(str(folder.name or "")))
        folder_id = folder.parent_id

    if segments:
        base = f"{DOCUMENTS_ROOT}/" + "/".join(reversed(segments))
    else:
        base = DOCUMENTS_ROOT
    return f"{base}/{safe_filename(title)}"
# ---------------------------------------------------------------------------
# Document revision restore (write/edit/move/rm)
# ---------------------------------------------------------------------------
def _set_field(target: Any, field: str, value: Any) -> None:
if value is not None:
setattr(target, field, value)
async def _restore_in_place_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Roll an existing :class:`Document` back to the state in ``revision``.

    Restores content / source markdown / title / folder (each only when the
    snapshot value is not ``None``), metadata (only when the snapshot holds a
    non-empty dict), the content hash, the unique-identifier hash (when a
    virtual path can be reconstructed), the chunks (when ``chunks_before`` is
    a list) and the document embedding. Nothing is committed here.

    Returns ``tool_unavailable`` when the revision no longer points at a
    document or the document row is gone; ``ok`` otherwise.

    NOTE(review): unlike ``_reinsert_document_from_revision`` this path does
    not check whether another document already holds the recomputed
    ``unique_identifier_hash`` — confirm whether that can collide.
    """
    if revision.document_id is None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Original document was hard-deleted; in-place restore is not possible."
            ),
        )

    doc = await session.get(Document, revision.document_id)
    if doc is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original document has been deleted; revert cannot proceed.",
        )

    snapshot_content = revision.content_before
    has_content = isinstance(snapshot_content, str)

    # ``_set_field`` skips ``None`` values, so absent snapshot fields leave
    # the live document's current values in place.
    for attr, before in (
        ("content", snapshot_content),
        ("source_markdown", snapshot_content),
        ("title", revision.title_before),
        ("folder_id", revision.folder_id_before),
    ):
        _set_field(doc, attr, before)

    snapshot_meta = revision.metadata_before
    if isinstance(snapshot_meta, dict) and snapshot_meta:
        # Copy so the document and the revision do not share one dict.
        doc.document_metadata = dict(snapshot_meta)

    if has_content:
        doc.content_hash = generate_content_hash(
            snapshot_content, doc.search_space_id
        )

    restored_path = await _virtual_path_from_snapshot(session, revision)
    if restored_path:
        doc.unique_identifier_hash = generate_unique_identifier_hash(
            DocumentType.NOTE,
            restored_path,
            doc.search_space_id,
        )

    snapshot_chunks = revision.chunks_before
    if isinstance(snapshot_chunks, list):
        # Replace the live chunks wholesale with the snapshotted ones.
        await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
        texts = [
            str(entry["content"])
            for entry in snapshot_chunks
            if isinstance(entry, dict) and isinstance(entry.get("content"), str)
        ]
        if texts:
            # ``embed_texts`` is synchronous; run it in a worker thread so the
            # event loop is not blocked.
            vectors = await asyncio.to_thread(embed_texts, texts)
            session.add_all(
                [
                    Chunk(document_id=doc.id, content=text, embedding=vector)
                    for text, vector in zip(texts, vectors, strict=True)
                ]
            )

    if has_content:
        content_vectors = await asyncio.to_thread(embed_texts, [snapshot_content])
        doc.embedding = content_vectors[0]

    doc.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Document restored from snapshot.")
async def _reinsert_document_from_revision(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Re-INSERT a deleted :class:`Document` from a snapshot row (``rm`` revert).

    Steps, in order:

    1. Validate that the snapshot carries a non-empty ``title_before`` and a
       string ``content_before`` and that a virtual path can be reconstructed;
       otherwise return ``not_reversible``.
    2. Refuse with ``tool_unavailable`` if a document with the same
       unique-identifier hash already exists in the search space.
    3. Insert a new NOTE document, flush to obtain its id, then attach the
       document embedding and re-created chunks (with fresh embeddings).
    4. Point ``revision.document_id`` at the new row.

    Nothing is committed here; the session is only flushed.

    Args:
        session: Async session the insert is performed on.
        revision: Snapshot row written before the original deletion.

    Returns:
        A :class:`RevertOutcome` with status ``ok``, ``not_reversible`` or
        ``tool_unavailable``.
    """
    if not isinstance(revision.title_before, str) or not revision.title_before:
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks title_before; cannot recreate document.",
        )
    if not isinstance(revision.content_before, str):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks content_before; cannot recreate document.",
        )

    virtual_path = await _virtual_path_from_snapshot(session, revision)
    if not virtual_path:
        return RevertOutcome(
            status="not_reversible",
            message=(
                "Snapshot is missing both metadata_before['virtual_path'] AND "
                "a resolvable (folder_id_before, title_before) pair."
            ),
        )

    search_space_id = revision.search_space_id
    unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )

    # Existence probe only: ``scalar_one_or_none`` raises when more than one
    # row comes back, so cap the query at one row (same shape as the
    # emptiness probes in ``_delete_created_folder``).
    collision = await session.execute(
        select(Document.id)
        .where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_identifier_hash,
        )
        .limit(1)
    )
    if collision.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                f"A document already exists at '{virtual_path}'; revert would "
                "collide. Move the live doc out of the way first."
            ),
        )

    # Work on a copy so the revision's own metadata dict is not mutated, and
    # make sure the recreated row records the path it was restored to.
    metadata = revision.metadata_before or {}
    if not isinstance(metadata, dict):
        metadata = {}
    metadata = dict(metadata)
    metadata["virtual_path"] = virtual_path

    content = revision.content_before
    new_doc = Document(
        title=revision.title_before,
        document_type=DocumentType.NOTE,
        document_metadata=metadata,
        content=content,
        content_hash=generate_content_hash(content, search_space_id),
        unique_identifier_hash=unique_identifier_hash,
        source_markdown=content,
        search_space_id=search_space_id,
        folder_id=revision.folder_id_before,
        updated_at=datetime.now(UTC),
    )
    session.add(new_doc)
    # Flush so ``new_doc.id`` is available for the chunk rows below.
    await session.flush()

    # ``embed_texts`` is synchronous; run it in a worker thread so the event
    # loop is not blocked while embeddings are computed.
    new_doc.embedding = (await asyncio.to_thread(embed_texts, [content]))[0]

    chunk_texts: list[str] = []
    chunks_before = revision.chunks_before
    if isinstance(chunks_before, list):
        # Keep only well-formed snapshot entries: dicts with string content.
        chunk_texts = [
            str(c.get("content"))
            for c in chunks_before
            if isinstance(c, dict) and isinstance(c.get("content"), str)
        ]
    if chunk_texts:
        chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts)
        session.add_all(
            [
                Chunk(document_id=new_doc.id, content=text, embedding=embedding)
                for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True)
            ]
        )

    # Repoint the snapshot at the recreated row so a follow-up revert of
    # the same row works as expected.
    revision.document_id = new_doc.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted document '{revision.title_before}' from snapshot.",
    )
async def _delete_created_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Undo a ``write_file`` create by deleting the row it produced.

    When the revision no longer references a document there is nothing to
    delete and the revert is still reported as ``ok``. Otherwise a DELETE for
    that document id is issued and ``ok`` is returned.
    """
    target_id = revision.document_id
    if target_id is None:
        return RevertOutcome(
            status="ok",
            message="No live row to delete (already removed elsewhere).",
        )

    await session.execute(delete(Document).where(Document.id == target_id))
    return RevertOutcome(
        status="ok",
        message="Deleted the document that was created by this action.",
    )
async def _restore_document_revision(
    session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
    """Pick and run the document revert strategy for ``action``.

    Loads the most recently created :class:`DocumentRevision` tied to the
    action, then dispatches on the lower-cased tool name (exact match):

    * ``rm`` -> re-insert the document from the snapshot;
    * ``write_file`` with no ``content_before`` -> delete the created row;
    * anything else -> in-place restore.

    Returns ``not_reversible`` when the action has no revision row.
    """
    newest_first = (
        select(DocumentRevision)
        .where(DocumentRevision.agent_action_id == action.id)
        .order_by(DocumentRevision.created_at.desc())
        .limit(1)
    )
    rows = await session.execute(newest_first)
    revision = rows.scalars().first()
    if revision is None:
        return RevertOutcome(
            status="not_reversible",
            message="No document_revisions row tied to this action.",
        )

    name = (action.tool_name or "").lower()
    if name == "rm":
        handler = _reinsert_document_from_revision
    elif name == "write_file" and revision.content_before is None:
        # A create has no prior content to restore; undoing it means deleting.
        handler = _delete_created_document
    else:
        handler = _restore_in_place_document
    return await handler(session, revision=revision)
# ---------------------------------------------------------------------------
# Folder revision restore (mkdir/rmdir/rename/move)
# ---------------------------------------------------------------------------
async def _restore_in_place_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Roll an existing :class:`Folder` back to the state in ``revision``.

    Restores name, parent and position (each only when the snapshot value is
    not ``None``) and bumps ``updated_at``. Returns ``tool_unavailable`` when
    the revision no longer references a folder or the folder row is gone.
    """
    if revision.folder_id is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder was hard-deleted; in-place restore is impossible.",
        )

    folder = await session.get(Folder, revision.folder_id)
    if folder is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder has been deleted; revert cannot proceed.",
        )

    # ``_set_field`` ignores ``None``, so a ``None`` snapshot value leaves the
    # folder's current value in place.
    for attr, before in (
        ("name", revision.name_before),
        ("parent_id", revision.parent_id_before),
        ("position", revision.position_before),
    ):
        _set_field(folder, attr, before)

    folder.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Folder restored from snapshot.")
async def _reinsert_folder_from_revision(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Recreate a deleted :class:`Folder` from its snapshot (``rmdir`` revert).

    Returns ``not_reversible`` when the snapshot has no usable name.
    Otherwise inserts a new folder with the snapshotted name / parent /
    position, flushes to obtain its id, and points the revision at the new
    row.
    """
    name = revision.name_before
    if not (isinstance(name, str) and name):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks name_before; cannot recreate folder.",
        )

    recreated = Folder(
        name=name,
        parent_id=revision.parent_id_before,
        position=revision.position_before,
        search_space_id=revision.search_space_id,
        updated_at=datetime.now(UTC),
    )
    session.add(recreated)
    # Flush so the new row's id exists before the revision is repointed.
    await session.flush()
    revision.folder_id = recreated.id

    return RevertOutcome(
        status="ok",
        message=f"Re-inserted folder '{name}' from snapshot.",
    )
async def _delete_created_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Undo a ``mkdir`` by deleting the folder it created, if still empty.

    Reports ``ok`` without doing anything when the revision no longer
    references a folder. Refuses with ``tool_unavailable`` when the folder
    now contains a document or a sub-folder. Otherwise issues the DELETE and
    returns ``ok``.
    """
    folder_id = revision.folder_id
    if folder_id is None:
        return RevertOutcome(
            status="ok",
            message="No live folder row to delete (already removed elsewhere).",
        )

    # Each probe asks for at most one row: existence is all that matters.
    # Documents are checked first, then sub-folders.
    emptiness_probes = (
        (
            select(Document.id).where(Document.folder_id == folder_id).limit(1),
            (
                "Folder is no longer empty (documents have been added since "
                "mkdir); cannot revert."
            ),
        ),
        (
            select(Folder.id).where(Folder.parent_id == folder_id).limit(1),
            (
                "Folder is no longer empty (sub-folders have been added "
                "since mkdir); cannot revert."
            ),
        ),
    )
    for probe, refusal in emptiness_probes:
        found = await session.execute(probe)
        if found.scalar_one_or_none() is not None:
            return RevertOutcome(status="tool_unavailable", message=refusal)

    await session.execute(delete(Folder).where(Folder.id == folder_id))
    return RevertOutcome(
        status="ok",
        message="Deleted the folder that was created by this action.",
    )
async def _restore_folder_revision(
    session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
    """Pick and run the folder revert strategy for ``action``.

    Loads the most recently created :class:`FolderRevision` tied to the
    action, then dispatches on the lower-cased tool name (exact match):
    ``rmdir`` re-inserts the folder, ``mkdir`` deletes the created folder,
    and every other tool gets an in-place restore.

    Returns ``not_reversible`` when the action has no revision row.
    """
    newest_first = (
        select(FolderRevision)
        .where(FolderRevision.agent_action_id == action.id)
        .order_by(FolderRevision.created_at.desc())
        .limit(1)
    )
    rows = await session.execute(newest_first)
    revision = rows.scalars().first()
    if revision is None:
        return RevertOutcome(
            status="not_reversible",
            message="No folder_revisions row tied to this action.",
        )

    # Exact-name lookup; anything not listed falls back to in-place restore.
    special_cases = {
        "rmdir": _reinsert_folder_from_revision,
        "mkdir": _delete_created_folder,
    }
    name = (action.tool_name or "").lower()
    handler = special_cases.get(name, _restore_in_place_folder)
    return await handler(session, revision=revision)
# ---------------------------------------------------------------------------
# Dispatch
# ---------------------------------------------------------------------------
#
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
# ``delete_note``/``delete_folder``.
# Tool names that ``revert_action`` routes to the document-revision restore
# path. Membership is tested against the lower-cased ``action.tool_name``.
_DOC_TOOLS: frozenset[str] = frozenset(
    {
        "edit_file",
        "write_file",
        "move_file",
        "rm",
        "update_memory",
        "create_note",
        "update_note",
        "delete_note",
    }
)
# Tool names that ``revert_action`` routes to the folder-revision restore
# path. Checked only after ``_DOC_TOOLS``.
_FOLDER_TOOLS: frozenset[str] = frozenset(
    {
        "mkdir",
        "rmdir",
        "rename_folder",
        "delete_folder",
    }
)
async def revert_action(
    session: AsyncSession,
    *,
    action: AgentActionLog,
    requester_user_id: str | None,
) -> RevertOutcome:
    """Revert ``action`` and report what happened.

    The lower-cased tool name selects the strategy by exact membership:
    document tools and folder tools are restored from their revision rows;
    any other tool yields ``reverse_not_implemented`` when the action carries
    a ``reverse_descriptor`` and ``not_reversible`` when it does not.

    On success a new ``agent_action_log`` row with ``reverse_of=action.id``
    and the requester's ``user_id`` is added and flushed, and its id is
    stored on the returned outcome. This function never commits — the caller
    is expected to commit on success or roll back on failure.
    """
    name = (action.tool_name or "").lower()
    if name in _DOC_TOOLS:
        restore = _restore_document_revision
    elif name in _FOLDER_TOOLS:
        restore = _restore_folder_revision
    else:
        restore = None

    if restore is None:
        if action.reverse_descriptor:
            # Connector-owned reversibles run through the normal permission
            # stack; out of scope for this PR — the route returns 503 anyway
            # until UI ships, so 501-style "not implemented" is fine.
            return RevertOutcome(
                status="reverse_not_implemented",
                message=(
                    "Connector-action revert is not yet implemented. The "
                    "reverse_descriptor is stored; future work will replay it "
                    "through PermissionMiddleware."
                ),
            )
        return RevertOutcome(
            status="not_reversible",
            message=(
                f"Tool {action.tool_name!r} is not reversible: no document "
                "revision and no reverse_descriptor."
            ),
        )

    outcome = await restore(session, action=action)
    if outcome.status != "ok":
        return outcome

    # Audit trail: record the revert itself as a new, non-reversible action
    # that points back at the action it undid.
    audit_row = AgentActionLog(
        thread_id=action.thread_id,
        user_id=requester_user_id,
        search_space_id=action.search_space_id,
        turn_id=None,
        message_id=None,
        tool_name=f"_revert:{action.tool_name}",
        args={"reverted_action_id": action.id},
        result_id=None,
        reversible=False,
        reverse_descriptor=None,
        error=None,
        reverse_of=action.id,
    )
    session.add(audit_row)
    # Flush so the new row's id is available to hand back to the caller.
    await session.flush()
    outcome.new_action_id = audit_row.id
    return outcome
# Public API of this module; the underscore-prefixed restore helpers are
# intentionally left out.
__all__ = [
    "RevertOutcome",
    "can_revert",
    "load_action",
    "load_thread",
    "revert_action",
]