mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-13 17:52:38 +02:00
feat: updated agent harness
This commit is contained in:
parent
9ec9b64348
commit
31a372bb84
139 changed files with 12583 additions and 1111 deletions
279
surfsense_backend/app/services/revert_service.py
Normal file
279
surfsense_backend/app/services/revert_service.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""Revert service for the SurfSense agent action log.
|
||||
|
||||
Implements the actual revert workflow used by
|
||||
``POST /api/threads/{thread_id}/revert/{action_id}``. The route handler is a
|
||||
thin auth + flag wrapper around the functions defined here.
|
||||
|
||||
Operation outcomes mirror the plan:
|
||||
|
||||
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
|
||||
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
|
||||
written before the original mutation.
|
||||
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
|
||||
the inverse tool through the agent's normal permission stack (NOT
|
||||
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
|
||||
* **Anything else** (deprecated tool / no descriptor / schema drift):
|
||||
returns ``NOT_REVERSIBLE`` and the route surfaces it as 409.
|
||||
|
||||
A successful revert appends a NEW row to ``agent_action_log`` with
|
||||
``reverse_of=<original_action_id>`` and the requesting user's
|
||||
``user_id``, preserving an auditable chain.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
DocumentRevision,
|
||||
FolderRevision,
|
||||
NewChatThread,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
RevertOutcomeStatus = Literal[
|
||||
"ok",
|
||||
"not_reversible",
|
||||
"not_found",
|
||||
"permission_denied",
|
||||
"tool_unavailable",
|
||||
"reverse_not_implemented",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RevertOutcome:
|
||||
"""Structured result of :func:`revert_action`."""
|
||||
|
||||
status: RevertOutcomeStatus
|
||||
message: str
|
||||
new_action_id: int | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookup helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def load_action(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
action_id: int,
|
||||
thread_id: int,
|
||||
) -> AgentActionLog | None:
|
||||
"""Load the action_log row for ``action_id`` if it belongs to the thread."""
|
||||
stmt = select(AgentActionLog).where(
|
||||
AgentActionLog.id == action_id,
|
||||
AgentActionLog.thread_id == thread_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def load_thread(
|
||||
session: AsyncSession, *, thread_id: int
|
||||
) -> NewChatThread | None:
|
||||
stmt = select(NewChatThread).where(NewChatThread.id == thread_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def can_revert(
|
||||
*,
|
||||
requester_user_id: str | None,
|
||||
action: AgentActionLog,
|
||||
is_admin: bool,
|
||||
) -> bool:
|
||||
"""Return True iff the requester is allowed to revert this action.
|
||||
|
||||
The plan's rule: "requester must be the original `user_id` on the
|
||||
action, or hold the search-space admin role." Anonymous actions
|
||||
(``action.user_id is None``) can only be reverted by admins.
|
||||
"""
|
||||
if is_admin:
|
||||
return True
|
||||
if action.user_id is None:
|
||||
return False
|
||||
return str(action.user_id) == str(requester_user_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Revert paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _restore_document_revision(
|
||||
session: AsyncSession, *, action: AgentActionLog
|
||||
) -> RevertOutcome:
|
||||
"""Restore the most recent :class:`DocumentRevision` for ``action``."""
|
||||
stmt = (
|
||||
select(DocumentRevision)
|
||||
.where(DocumentRevision.agent_action_id == action.id)
|
||||
.order_by(DocumentRevision.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
revision = result.scalars().first()
|
||||
if revision is None:
|
||||
return RevertOutcome(
|
||||
status="not_reversible",
|
||||
message="No document_revisions row tied to this action.",
|
||||
)
|
||||
|
||||
from app.db import Document # late import to avoid cycles at module load
|
||||
|
||||
doc = await session.get(Document, revision.document_id)
|
||||
if doc is None:
|
||||
return RevertOutcome(
|
||||
status="tool_unavailable",
|
||||
message="Original document has been deleted; revert cannot proceed.",
|
||||
)
|
||||
|
||||
if revision.content_before is not None:
|
||||
doc.content = revision.content_before
|
||||
if revision.title_before is not None:
|
||||
doc.title = revision.title_before
|
||||
if revision.folder_id_before is not None:
|
||||
doc.folder_id = revision.folder_id_before
|
||||
doc.updated_at = datetime.now(UTC)
|
||||
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
||||
|
||||
|
||||
async def _restore_folder_revision(
|
||||
session: AsyncSession, *, action: AgentActionLog
|
||||
) -> RevertOutcome:
|
||||
stmt = (
|
||||
select(FolderRevision)
|
||||
.where(FolderRevision.agent_action_id == action.id)
|
||||
.order_by(FolderRevision.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
revision = result.scalars().first()
|
||||
if revision is None:
|
||||
return RevertOutcome(
|
||||
status="not_reversible",
|
||||
message="No folder_revisions row tied to this action.",
|
||||
)
|
||||
|
||||
from app.db import Folder
|
||||
|
||||
folder = await session.get(Folder, revision.folder_id)
|
||||
if folder is None:
|
||||
return RevertOutcome(
|
||||
status="tool_unavailable",
|
||||
message="Original folder has been deleted; revert cannot proceed.",
|
||||
)
|
||||
|
||||
if revision.name_before is not None:
|
||||
folder.name = revision.name_before
|
||||
if revision.parent_id_before is not None:
|
||||
folder.parent_id = revision.parent_id_before
|
||||
if revision.position_before is not None:
|
||||
folder.position = revision.position_before
|
||||
folder.updated_at = datetime.now(UTC)
|
||||
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
||||
|
||||
|
||||
# Tool-name prefixes that route to KB document / folder revert paths. Kept
|
||||
# as data so a future PR adding new KB-owned tools doesn't have to touch
|
||||
# this module's control flow.
|
||||
_DOC_TOOL_PREFIXES: tuple[str, ...] = (
|
||||
"edit_file",
|
||||
"write_file",
|
||||
"update_memory",
|
||||
"create_note",
|
||||
"update_note",
|
||||
"delete_note",
|
||||
)
|
||||
_FOLDER_TOOL_PREFIXES: tuple[str, ...] = (
|
||||
"mkdir",
|
||||
"move_file",
|
||||
"rename_folder",
|
||||
"delete_folder",
|
||||
)
|
||||
|
||||
|
||||
async def revert_action(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
action: AgentActionLog,
|
||||
requester_user_id: str | None,
|
||||
) -> RevertOutcome:
|
||||
"""Execute the revert for ``action`` and return a structured outcome.
|
||||
|
||||
The function does **not** commit — the caller is expected to commit on
|
||||
success or roll back on failure. A new ``agent_action_log`` row is
|
||||
added to the session on success with ``reverse_of=action.id``.
|
||||
"""
|
||||
tool_name = (action.tool_name or "").lower()
|
||||
|
||||
if tool_name.startswith(_DOC_TOOL_PREFIXES):
|
||||
outcome = await _restore_document_revision(session, action=action)
|
||||
elif tool_name.startswith(_FOLDER_TOOL_PREFIXES):
|
||||
outcome = await _restore_folder_revision(session, action=action)
|
||||
elif action.reverse_descriptor:
|
||||
# Connector-owned reversibles run through the normal permission
|
||||
# stack; out of scope for this PR — the route returns 503 anyway
|
||||
# until UI ships, so 501-style "not implemented" is fine.
|
||||
return RevertOutcome(
|
||||
status="reverse_not_implemented",
|
||||
message=(
|
||||
"Connector-action revert is not yet implemented. The "
|
||||
"reverse_descriptor is stored; future work will replay it "
|
||||
"through PermissionMiddleware."
|
||||
),
|
||||
)
|
||||
else:
|
||||
return RevertOutcome(
|
||||
status="not_reversible",
|
||||
message=(
|
||||
f"Tool {action.tool_name!r} is not reversible: no document "
|
||||
"revision and no reverse_descriptor."
|
||||
),
|
||||
)
|
||||
|
||||
if outcome.status != "ok":
|
||||
return outcome
|
||||
|
||||
new_row = AgentActionLog(
|
||||
thread_id=action.thread_id,
|
||||
user_id=requester_user_id,
|
||||
search_space_id=action.search_space_id,
|
||||
turn_id=None,
|
||||
message_id=None,
|
||||
tool_name=f"_revert:{action.tool_name}",
|
||||
args={"reverted_action_id": action.id},
|
||||
result_id=None,
|
||||
reversible=False,
|
||||
reverse_descriptor=None,
|
||||
error=None,
|
||||
reverse_of=action.id,
|
||||
)
|
||||
session.add(new_row)
|
||||
await session.flush()
|
||||
outcome.new_action_id = new_row.id
|
||||
return outcome
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RevertOutcome",
|
||||
"can_revert",
|
||||
"load_action",
|
||||
"load_thread",
|
||||
"revert_action",
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue