mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-20 21:18:13 +02:00
document_converters, the github size-fallback chunker, revert_service restores, and the kb-persistence middleware now write explicit positions (the middleware read path also orders by position).
634 lines
20 KiB
Python
634 lines
20 KiB
Python
"""Revert service for the SurfSense agent action log.
|
|
|
|
Implements the actual revert workflow used by
|
|
``POST /api/threads/{thread_id}/revert/{action_id}``. The route handler is a
|
|
thin auth + flag wrapper around the functions defined here.
|
|
|
|
Operation outcomes mirror the plan:
|
|
|
|
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
|
|
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
|
|
written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
|
|
row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row
|
|
that was created; everything else is an in-place restore.
|
|
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
|
|
the inverse tool through the agent's normal permission stack (NOT
|
|
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
|
|
* **Anything else** (deprecated tool / no descriptor / schema drift):
|
|
returns ``NOT_REVERSIBLE`` and the route surfaces it as 409.
|
|
|
|
A successful revert appends a NEW row to ``agent_action_log`` with
|
|
``reverse_of=<original_action_id>`` and the requesting user's
|
|
``user_id``, preserving an auditable chain.
|
|
|
|
Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
|
|
``"rmdir".startswith("rm")`` would otherwise mis-route directory revert
|
|
to the document branch (and ``delete_note`` vs ``delete_folder`` is the
|
|
same trap waiting to happen).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from typing import Any, Literal
|
|
|
|
from sqlalchemy import delete, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.agents.chat.runtime.path_resolver import (
|
|
DOCUMENTS_ROOT,
|
|
safe_filename,
|
|
safe_folder_segment,
|
|
)
|
|
from app.db import (
|
|
AgentActionLog,
|
|
Chunk,
|
|
Document,
|
|
DocumentRevision,
|
|
DocumentType,
|
|
Folder,
|
|
FolderRevision,
|
|
NewChatThread,
|
|
)
|
|
from app.utils.document_converters import (
|
|
embed_texts,
|
|
generate_content_hash,
|
|
generate_unique_identifier_hash,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Closed set of status strings carried by ``RevertOutcome.status``.
RevertOutcomeStatus = Literal[
    # The revert was applied.
    "ok",
    # No usable snapshot / descriptor exists for the action.
    "not_reversible",
    # NOTE(review): "not_found" and "permission_denied" are never produced in
    # this module — presumably they are set by the route layer; confirm.
    "not_found",
    "permission_denied",
    # The live row is gone or the target location is already occupied.
    "tool_unavailable",
    # Connector-owned action whose inverse replay is not built yet.
    "reverse_not_implemented",
]
|
|
|
|
|
|
@dataclass
class RevertOutcome:
    """Result object handed back by :func:`revert_action` and its helpers."""

    # One of the ``RevertOutcomeStatus`` literals.
    status: RevertOutcomeStatus
    # Human-readable explanation of the outcome.
    message: str
    # Id of the audit row appended by a successful revert; ``None`` otherwise.
    new_action_id: int | None = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Lookup helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def load_action(
    session: AsyncSession,
    *,
    action_id: int,
    thread_id: int,
) -> AgentActionLog | None:
    """Fetch a single action-log row, scoped to one thread.

    Args:
        session: Async SQLAlchemy session used to run the query.
        action_id: Primary key of the ``AgentActionLog`` row.
        thread_id: Thread the row must belong to.

    Returns:
        The matching row, or ``None`` when no row has both this id and this
        thread id.
    """
    query = (
        select(AgentActionLog)
        .where(AgentActionLog.id == action_id)
        .where(AgentActionLog.thread_id == thread_id)
    )
    return (await session.execute(query)).scalars().first()
|
|
|
|
|
|
async def load_thread(session: AsyncSession, *, thread_id: int) -> NewChatThread | None:
    """Return the ``NewChatThread`` with id ``thread_id``, or ``None`` if absent."""
    query = select(NewChatThread).where(NewChatThread.id == thread_id)
    return (await session.execute(query)).scalars().first()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Authorization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def can_revert(
    *,
    requester_user_id: str | None,
    action: AgentActionLog,
    is_admin: bool,
) -> bool:
    """Return True iff the requester is allowed to revert this action.

    The plan's rule: "requester must be the original `user_id` on the
    action, or hold the search-space admin role." Anonymous actions
    (``action.user_id is None``) can only be reverted by admins, and an
    unauthenticated requester (``requester_user_id is None``) can never
    pass the ownership check.

    Args:
        requester_user_id: Id of the user asking for the revert, if any.
        action: The action-log row being reverted; only ``user_id`` is read.
        is_admin: Whether the requester holds the admin role.

    Returns:
        ``True`` when the revert is permitted, ``False`` otherwise.
    """
    if is_admin:
        return True
    # Without an identity on either side there is no ownership to compare.
    # The explicit guard also keeps ``str(None) == "None"`` from matching a
    # row whose ``user_id`` happens to stringify to ``"None"``.
    if action.user_id is None or requester_user_id is None:
        return False
    # Compare as strings so UUID objects and their text form are equal.
    return str(action.user_id) == str(requester_user_id)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helper: reconstruct virtual path from a snapshot
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _virtual_path_from_snapshot(
    session: AsyncSession,
    revision: DocumentRevision,
) -> str | None:
    """Work out where the document lived before the mutation.

    Two sources are tried, in order:

    1. ``metadata_before["virtual_path"]``, accepted only when it is a string
       starting with ``DOCUMENTS_ROOT``.
    2. A path composed from ``folder_id_before`` and ``title_before`` by
       following each folder's ``parent_id`` up to the root.

    Returns:
        The reconstructed path, or ``None`` when the snapshot has no usable
        title or a folder in the chain can no longer be loaded.
    """
    snapshot_meta = revision.metadata_before or {}
    if isinstance(snapshot_meta, dict):
        recorded = snapshot_meta.get("virtual_path")
        if isinstance(recorded, str) and recorded.startswith(DOCUMENTS_ROOT):
            return recorded

    title = revision.title_before
    if not (isinstance(title, str) and title):
        return None

    # Collected leaf-first while climbing; flipped when the path is joined.
    segments: list[str] = []
    seen: set[int] = set()
    folder_id: int | None = revision.folder_id_before
    # ``seen`` stops the climb if the parent chain ever loops back on itself.
    while folder_id is not None and folder_id not in seen:
        seen.add(folder_id)
        folder = await session.get(Folder, folder_id)
        if folder is None:
            return None
        segments.append(safe_folder_segment(str(folder.name or "")))
        folder_id = folder.parent_id

    if segments:
        base = f"{DOCUMENTS_ROOT}/" + "/".join(reversed(segments))
    else:
        base = DOCUMENTS_ROOT
    return f"{base}/{safe_filename(title)}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Document revision restore (write/edit/move/rm)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _set_field(target: Any, field: str, value: Any) -> None:
    """Assign ``value`` to ``target.<field>``, leaving it alone for ``None``."""
    if value is None:
        return
    setattr(target, field, value)
|
|
|
|
|
|
async def _restore_in_place_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Roll an existing :class:`Document` back to its snapshot values.

    Restores content, title, folder and metadata, recomputes the content and
    unique-identifier hashes, and — when the snapshot carries a chunk list —
    replaces the document's chunks with freshly embedded copies.

    Returns:
        ``"tool_unavailable"`` when the live row is gone, ``"ok"`` otherwise.
    """
    document_id = revision.document_id
    if document_id is None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Original document was hard-deleted; in-place restore is not possible."
            ),
        )
    doc = await session.get(Document, document_id)
    if doc is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original document has been deleted; revert cannot proceed.",
        )

    previous_content = revision.content_before
    has_text_content = isinstance(previous_content, str)

    # NOTE(review): ``_set_field`` skips ``None``, so a snapshot whose
    # ``folder_id_before`` is ``None`` leaves the current folder untouched —
    # confirm that is intended for documents that used to sit at the root.
    for attr, old_value in (
        ("content", previous_content),
        ("source_markdown", previous_content),
        ("title", revision.title_before),
        ("folder_id", revision.folder_id_before),
    ):
        _set_field(doc, attr, old_value)

    snapshot_meta = revision.metadata_before
    if isinstance(snapshot_meta, dict) and snapshot_meta:
        # Copy so the live document does not share the snapshot's dict.
        doc.document_metadata = dict(snapshot_meta)

    if has_text_content:
        doc.content_hash = generate_content_hash(previous_content, doc.search_space_id)

    restored_path = await _virtual_path_from_snapshot(session, revision)
    if restored_path:
        doc.unique_identifier_hash = generate_unique_identifier_hash(
            DocumentType.NOTE,
            restored_path,
            doc.search_space_id,
        )

    snapshot_chunks = revision.chunks_before
    if isinstance(snapshot_chunks, list):
        # Current chunks are dropped even when the snapshot list is empty.
        await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
        texts = [
            str(entry.get("content"))
            for entry in snapshot_chunks
            if isinstance(entry, dict) and isinstance(entry.get("content"), str)
        ]
        if texts:
            # ``embed_texts`` is synchronous; run it off the event loop.
            vectors = await asyncio.to_thread(embed_texts, texts)
            restored_chunks = []
            for index, (text, vector) in enumerate(zip(texts, vectors, strict=True)):
                restored_chunks.append(
                    Chunk(
                        document_id=doc.id,
                        content=text,
                        embedding=vector,
                        position=index,
                    )
                )
            session.add_all(restored_chunks)
    if has_text_content:
        doc_vectors = await asyncio.to_thread(embed_texts, [previous_content])
        doc.embedding = doc_vectors[0]

    doc.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
|
|
|
|
|
async def _reinsert_document_from_revision(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Recreate a removed :class:`Document` from its snapshot (``rm`` revert).

    A brand-new row is inserted (the old primary key is not reused), its
    chunks are rebuilt from ``chunks_before``, and the revision is re-pointed
    at the new row.

    Returns:
        ``"not_reversible"`` when the snapshot lacks a title, content or a
        resolvable path; ``"tool_unavailable"`` when another document already
        occupies the path; ``"ok"`` after a successful insert.
    """
    title = revision.title_before
    if not (isinstance(title, str) and title):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks title_before; cannot recreate document.",
        )
    body = revision.content_before
    if not isinstance(body, str):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks content_before; cannot recreate document.",
        )

    virtual_path = await _virtual_path_from_snapshot(session, revision)
    if not virtual_path:
        return RevertOutcome(
            status="not_reversible",
            message=(
                "Snapshot is missing both metadata_before['virtual_path'] AND "
                "a resolvable (folder_id_before, title_before) pair."
            ),
        )

    space_id = revision.search_space_id
    identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        space_id,
    )
    # Refuse to recreate the row on top of a live document at the same path.
    occupied = await session.execute(
        select(Document.id).where(
            Document.search_space_id == space_id,
            Document.unique_identifier_hash == identifier_hash,
        )
    )
    if occupied.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                f"A document already exists at '{virtual_path}'; revert would "
                "collide. Move the live doc out of the way first."
            ),
        )

    raw_meta = revision.metadata_before
    restored_meta = dict(raw_meta) if isinstance(raw_meta, dict) else {}
    restored_meta["virtual_path"] = virtual_path

    recreated = Document(
        title=title,
        document_type=DocumentType.NOTE,
        document_metadata=restored_meta,
        content=body,
        content_hash=generate_content_hash(body, space_id),
        unique_identifier_hash=identifier_hash,
        source_markdown=body,
        search_space_id=space_id,
        folder_id=revision.folder_id_before,
        updated_at=datetime.now(UTC),
    )
    session.add(recreated)
    # Flush so the new primary key is available for the chunk rows below.
    await session.flush()

    # ``embed_texts`` is synchronous; run it off the event loop.
    recreated.embedding = (await asyncio.to_thread(embed_texts, [body]))[0]

    snapshot_chunks = revision.chunks_before
    texts = (
        [
            str(entry.get("content"))
            for entry in snapshot_chunks
            if isinstance(entry, dict) and isinstance(entry.get("content"), str)
        ]
        if isinstance(snapshot_chunks, list)
        else []
    )
    if texts:
        vectors = await asyncio.to_thread(embed_texts, texts)
        restored_chunks = []
        for index, (text, vector) in enumerate(zip(texts, vectors, strict=True)):
            restored_chunks.append(
                Chunk(
                    document_id=recreated.id,
                    content=text,
                    embedding=vector,
                    position=index,
                )
            )
        session.add_all(restored_chunks)

    # Repoint the snapshot at the recreated row so a follow-up revert of
    # the same row works as expected.
    revision.document_id = recreated.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted document '{revision.title_before}' from snapshot.",
    )
|
|
|
|
|
|
async def _delete_created_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Remove the document a ``write_file`` create produced.

    Selected by the dispatcher when ``content_before`` is ``None``. A revision
    that no longer points at a document is treated as already reverted and
    still reports ``"ok"``.
    """
    created_id = revision.document_id
    if created_id is None:
        return RevertOutcome(
            status="ok",
            message="No live row to delete (already removed elsewhere).",
        )

    removal = delete(Document).where(Document.id == created_id)
    await session.execute(removal)
    return RevertOutcome(
        status="ok",
        message="Deleted the document that was created by this action.",
    )
|
|
|
|
|
|
async def _restore_document_revision(
    session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
    """Pick and run the document revert strategy for ``action``.

    Loads the most recent ``DocumentRevision`` tied to the action, then
    routes on the lower-cased tool name using exact comparison:

    * ``rm`` re-inserts the document from the snapshot;
    * ``write_file`` with no ``content_before`` deletes the created document;
    * anything else is restored in place.

    Returns:
        ``"not_reversible"`` when no revision row exists, otherwise the
        outcome of the selected strategy.
    """
    latest_revision_query = (
        select(DocumentRevision)
        .where(DocumentRevision.agent_action_id == action.id)
        .order_by(DocumentRevision.created_at.desc())
        .limit(1)
    )
    revision = (await session.execute(latest_revision_query)).scalars().first()
    if revision is None:
        return RevertOutcome(
            status="not_reversible",
            message="No document_revisions row tied to this action.",
        )

    normalized_tool = (action.tool_name or "").lower()

    if normalized_tool == "rm":
        handler = _reinsert_document_from_revision
    elif normalized_tool == "write_file" and revision.content_before is None:
        # No prior content means the write created the file.
        handler = _delete_created_document
    else:
        handler = _restore_in_place_document
    return await handler(session, revision=revision)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Folder revision restore (mkdir/rmdir/rename/move)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _restore_in_place_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Roll an existing :class:`Folder` back to its snapshot values.

    Restores ``name``, ``parent_id`` and ``position``; snapshot values that
    are ``None`` are skipped by ``_set_field``.

    Returns:
        ``"tool_unavailable"`` when the live folder is gone, ``"ok"`` otherwise.
    """
    folder_id = revision.folder_id
    if folder_id is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder was hard-deleted; in-place restore is impossible.",
        )

    folder = await session.get(Folder, folder_id)
    if folder is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder has been deleted; revert cannot proceed.",
        )

    for attr, old_value in (
        ("name", revision.name_before),
        ("parent_id", revision.parent_id_before),
        ("position", revision.position_before),
    ):
        _set_field(folder, attr, old_value)
    folder.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
|
|
|
|
|
async def _reinsert_folder_from_revision(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Recreate a removed :class:`Folder` from its snapshot (``rmdir`` revert).

    Inserts a new row and re-points the revision at it.

    Returns:
        ``"not_reversible"`` when the snapshot has no usable ``name_before``,
        ``"ok"`` after a successful insert.
    """
    name = revision.name_before
    if not (isinstance(name, str) and name):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks name_before; cannot recreate folder.",
        )

    recreated = Folder(
        name=name,
        parent_id=revision.parent_id_before,
        position=revision.position_before,
        search_space_id=revision.search_space_id,
        updated_at=datetime.now(UTC),
    )
    session.add(recreated)
    # Flush so the new primary key exists before it is copied to the revision.
    await session.flush()
    revision.folder_id = recreated.id

    return RevertOutcome(
        status="ok",
        message=f"Re-inserted folder '{revision.name_before}' from snapshot.",
    )
|
|
|
|
|
|
async def _delete_created_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Remove the folder a ``mkdir`` produced, provided it is still empty.

    Returns:
        ``"ok"`` when the folder was deleted or the revision no longer points
        at one; ``"tool_unavailable"`` when the folder now holds documents or
        sub-folders.
    """
    target_id = revision.folder_id
    if target_id is None:
        return RevertOutcome(
            status="ok",
            message="No live folder row to delete (already removed elsewhere).",
        )

    # Each probe asks for at most one row; any hit means the folder is in use.
    # Documents are checked before sub-folders.
    occupancy_checks = (
        (
            select(Document.id).where(Document.folder_id == target_id).limit(1),
            "Folder is no longer empty (documents have been added since "
            "mkdir); cannot revert.",
        ),
        (
            select(Folder.id).where(Folder.parent_id == target_id).limit(1),
            "Folder is no longer empty (sub-folders have been added "
            "since mkdir); cannot revert.",
        ),
    )
    for probe, refusal in occupancy_checks:
        found = await session.execute(probe)
        if found.scalar_one_or_none() is not None:
            return RevertOutcome(status="tool_unavailable", message=refusal)

    await session.execute(delete(Folder).where(Folder.id == target_id))
    return RevertOutcome(
        status="ok",
        message="Deleted the folder that was created by this action.",
    )
|
|
|
|
|
|
async def _restore_folder_revision(
    session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
    """Pick and run the folder revert strategy for ``action``.

    Loads the most recent ``FolderRevision`` tied to the action, then routes
    on the lower-cased tool name using exact comparison: ``rmdir`` re-inserts
    the folder, ``mkdir`` deletes the created folder, and anything else is
    restored in place.

    Returns:
        ``"not_reversible"`` when no revision row exists, otherwise the
        outcome of the selected strategy.
    """
    latest_revision_query = (
        select(FolderRevision)
        .where(FolderRevision.agent_action_id == action.id)
        .order_by(FolderRevision.created_at.desc())
        .limit(1)
    )
    revision = (await session.execute(latest_revision_query)).scalars().first()
    if revision is None:
        return RevertOutcome(
            status="not_reversible",
            message="No folder_revisions row tied to this action.",
        )

    normalized_tool = (action.tool_name or "").lower()

    # Exact-name lookup; unknown names fall back to the in-place restore.
    if normalized_tool == "rmdir":
        handler = _reinsert_folder_from_revision
    elif normalized_tool == "mkdir":
        handler = _delete_created_folder
    else:
        handler = _restore_in_place_folder
    return await handler(session, revision=revision)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dispatch
|
|
# ---------------------------------------------------------------------------
|
|
#
|
|
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
|
|
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
|
|
# ``delete_note``/``delete_folder``.
|
|
|
|
# Tool names whose revert is served from ``DocumentRevision`` snapshots.
# Membership is tested with exact (lower-cased) names, never prefixes.
_DOC_TOOLS: frozenset[str] = frozenset(
    (
        "edit_file",
        "write_file",
        "move_file",
        "rm",
        "update_memory",
        "create_note",
        "update_note",
        "delete_note",
    )
)
|
|
# Tool names whose revert is served from ``FolderRevision`` snapshots.
_FOLDER_TOOLS: frozenset[str] = frozenset(
    (
        "mkdir",
        "rmdir",
        "rename_folder",
        "delete_folder",
    )
)
|
|
|
|
|
|
async def revert_action(
    session: AsyncSession,
    *,
    action: AgentActionLog,
    requester_user_id: str | None,
) -> RevertOutcome:
    """Revert ``action`` and report what happened.

    Nothing is committed here: the caller commits on success or rolls back
    on failure. When the revert succeeds, an audit row with
    ``reverse_of=action.id`` is added to the session and flushed, and its id
    is returned in ``new_action_id``.

    Args:
        session: Async SQLAlchemy session the revert runs in.
        action: The action-log row to undo.
        requester_user_id: Recorded as ``user_id`` on the audit row.

    Returns:
        The outcome of the document / folder restore, or a
        ``"reverse_not_implemented"`` / ``"not_reversible"`` outcome when the
        tool is not KB-owned.
    """
    normalized_tool = (action.tool_name or "").lower()

    if normalized_tool in _DOC_TOOLS:
        restore = _restore_document_revision
    elif normalized_tool in _FOLDER_TOOLS:
        restore = _restore_folder_revision
    elif action.reverse_descriptor:
        # Connector-owned reversibles run through the normal permission
        # stack; out of scope for this PR — the route returns 503 anyway
        # until UI ships, so 501-style "not implemented" is fine.
        return RevertOutcome(
            status="reverse_not_implemented",
            message=(
                "Connector-action revert is not yet implemented. The "
                "reverse_descriptor is stored; future work will replay it "
                "through PermissionMiddleware."
            ),
        )
    else:
        return RevertOutcome(
            status="not_reversible",
            message=(
                f"Tool {action.tool_name!r} is not reversible: no document "
                "revision and no reverse_descriptor."
            ),
        )

    outcome = await restore(session, action=action)
    if outcome.status != "ok":
        # Failed restores leave no audit row behind.
        return outcome

    audit_row = AgentActionLog(
        thread_id=action.thread_id,
        user_id=requester_user_id,
        search_space_id=action.search_space_id,
        turn_id=None,
        message_id=None,
        tool_name=f"_revert:{action.tool_name}",
        args={"reverted_action_id": action.id},
        result_id=None,
        reversible=False,
        reverse_descriptor=None,
        error=None,
        reverse_of=action.id,
    )
    session.add(audit_row)
    # Flush so the audit row's primary key can be reported to the caller.
    await session.flush()
    outcome.new_action_id = audit_row.id
    return outcome
|
|
|
|
|
|
# Public API of this module; underscore-prefixed helpers stay private.
__all__ = ["RevertOutcome", "can_revert", "load_action", "load_thread", "revert_action"]
|