mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): move filesystem_state, path_resolver, sandbox to app/agents/shared (slice 3b)
Relocate three leaf filesystem-cluster modules to the shared kernel and flip all 38 importers. No re-export shims needed (no frozen single-agent importer). This also resolves the pre-existing shared->new_chat back-edge from shared/receipt_command.py onto filesystem_state. filesystem_backends is intentionally deferred to slice 5: it depends on new_chat middleware (kb_postgres_backend, multi_root_local_folder_backend) that have not yet moved, so relocating it now would create a shared->new_chat edge.
This commit is contained in:
parent
1b536b8aee
commit
3efe51e6ec
41 changed files with 55 additions and 55 deletions
213
surfsense_backend/app/agents/shared/filesystem_state.py
Normal file
213
surfsense_backend/app/agents/shared/filesystem_state.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
"""LangGraph state schema additions used by the SurfSense filesystem agent.
|
||||
|
||||
This schema extends deepagents' upstream :class:`FilesystemState` with the
|
||||
extra fields needed to implement Postgres-backed virtual filesystem semantics:
|
||||
|
||||
* ``cwd`` — current working directory (per-thread checkpointed).
|
||||
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
||||
* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs.
|
||||
* ``pending_moves`` — pending move_file requests (cloud only).
|
||||
* ``pending_deletes`` — pending ``rm`` requests (cloud only).
|
||||
* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only).
|
||||
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
||||
* ``dirty_paths`` — paths whose state file content differs from DB.
|
||||
* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for
|
||||
dirty paths; used to bind the per-path snapshot to an action_id.
|
||||
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
||||
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
||||
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
||||
* ``tree_version`` — bumped by persistence; invalidates the tree render cache.
|
||||
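* ``workspace_tree_text`` — pre-rendered ``<workspace_tree>`` body for the turn.
* ``billable_calls`` — per-subagent ``task(...)`` invocation counter for the turn.
* ``receipts`` — structured ``Receipt`` handles emitted by mutating tools this turn.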
|
||||
|
||||
Tools mutate these fields ONLY via ``Command(update=...)`` returns; the
|
||||
reducers in :mod:`app.agents.shared.state_reducers` handle merging.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any, NotRequired
|
||||
|
||||
from deepagents.middleware.filesystem import FilesystemState
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.shared.state_reducers import (
|
||||
_add_unique_reducer,
|
||||
_dict_merge_with_tombstones_reducer,
|
||||
_int_counter_merge_reducer,
|
||||
_list_append_reducer,
|
||||
_replace_reducer,
|
||||
)
|
||||
from app.agents.shared.receipt import Receipt
|
||||
|
||||
|
||||
class PendingMove(TypedDict, total=False):
    """One ``move_file`` request held back until the end-of-turn commit.

    All keys are optional (``total=False``). ``tool_call_id`` may be absent
    in checkpoints written before the snapshot/revert pipeline existed;
    entries created today always carry it so the persistence body can
    resolve an action_id.
    """

    # Path the file is moved from.
    source: str
    # Path the file is moved to.
    dest: str
    # Overwrite flag recorded with the move request.
    overwrite: bool
    # Id of the tool call that staged this move.
    tool_call_id: str
|
||||
|
||||
|
||||
class PendingDelete(TypedDict, total=False):
    """One ``rm`` / ``rmdir`` request held back until the end-of-turn commit.

    New entries must include ``tool_call_id``: it is the binding key
    :class:`KnowledgeBasePersistenceMiddleware` uses to locate the matching
    :class:`AgentActionLog` row and attach the snapshot to it. The type is
    declared ``total=False`` solely so that older checkpoint payloads that
    lack the key still validate.
    """

    # Path of the file or directory to delete.
    path: str
    # Id of the tool call that staged this delete.
    tool_call_id: str
|
||||
|
||||
|
||||
class KbPriorityEntry(TypedDict, total=False):
|
||||
path: str
|
||||
score: float
|
||||
document_id: int | None
|
||||
title: str
|
||||
mentioned: bool
|
||||
|
||||
|
||||
class KbAnonDoc(TypedDict, total=False):
    """Anonymous-session document held in memory after being loaded from Redis."""

    # Path under which the document is exposed.
    path: str
    # Title of the document.
    title: str
    # Document content as a single string.
    content: str
    # Chunk records of the document; each chunk is a free-form dict.
    chunks: list[dict[str, Any]]
|
||||
|
||||
|
||||
class SurfSenseFilesystemState(FilesystemState):
    """State schema of the SurfSense filesystem agent (cloud and desktop).

    Builds on deepagents' :class:`FilesystemState` (the source of ``files``)
    and adds the cloud-mode staging fields plus search-priority hints. Each
    additional field is bound to a reducer, so ``Command(update=...)``
    payloads merge cleanly from step to step and across checkpoints.
    """

    cwd: NotRequired[Annotated[str, _replace_reducer]]
    """Current working directory.

    ``"/documents"`` by default in cloud mode, ``"/"`` (or the first mount)
    in desktop mode. ``KnowledgeTreeMiddleware`` initialises it once per
    thread.
    """

    staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
    """Cloud only: ``mkdir`` paths waiting for end-of-turn folder creation."""

    staged_dir_tool_calls: NotRequired[
        Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
    ]
    """Sidecar map ``path -> tool_call_id`` accompanying ``staged_dirs``.

    :class:`KnowledgeBasePersistenceMiddleware` uses it to bind the
    :class:`FolderRevision` snapshot to the ``mkdir`` action that caused it.
    It lives apart from ``staged_dirs`` (still a list of unique strings) so
    that the semantics of ``_add_unique_reducer`` stay intact.
    """

    pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
    """Cloud only: ``move_file`` operations waiting for the end-of-turn commit."""

    pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]]
    """Cloud only: ``rm`` operations waiting for end-of-turn ``DELETE FROM documents``.

    Entries are dicts shaped ``{"path": ..., "tool_call_id": ...}``. The
    reducer does not de-duplicate by path; the commit body does, because
    every occurrence must keep its own ``tool_call_id`` for snapshot
    binding.
    """

    pending_dir_deletes: NotRequired[
        Annotated[list[PendingDelete], _list_append_reducer]
    ]
    """Cloud only: ``rmdir`` operations waiting for end-of-turn ``DELETE FROM folders``.

    Shaped like :data:`pending_deletes`. Before the DELETE is issued the
    commit body checks again that the folder is empty, both in the DB and
    after applying this turn's pending changes.
    """

    doc_id_by_path: NotRequired[
        Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
    ]
    """Map of virtual_path -> ``Document.id`` for files that were lazily loaded.

    Filled in the first time a KB document is read; edit_file/move_file/
    aafter_agent consult it to get from a path back to the real DB row.
    A ``None`` value is a tombstone and removes the key.
    """

    dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
    """Paths whose ``state["files"]`` content was modified during this turn."""

    dirty_path_tool_calls: NotRequired[
        Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
    ]
    """Sidecar map ``path -> latest tool_call_id`` accompanying ``dirty_paths``.

    Several writes/edits of one path are coalesced by the persistence body
    into a single snapshot per turn. Keeping the most recent
    ``tool_call_id`` here binds the resulting :class:`DocumentRevision` to
    the latest action_id (the one the user is most likely to revert).
    """

    kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
    """Top-K priority hints, rendered as a system message ahead of the user turn."""

    kb_matched_chunk_ids: NotRequired[Annotated[dict[int, list[int]], _replace_reducer]]
    """Internal: ``Document.id`` -> matched chunk IDs produced by hybrid search."""

    kb_anon_doc: NotRequired[Annotated[KbAnonDoc | None, _replace_reducer]]
    """Redis-loaded anonymous-session document (read-only, backed by no DB row)."""

    tree_version: NotRequired[Annotated[int, _replace_reducer]]
    """Monotonically increasing counter, bumped whenever a commit changes the KB tree."""

    workspace_tree_text: NotRequired[Annotated[str, _replace_reducer]]
    """Pre-rendered ``<workspace_tree>`` body, shared with subagents so they skip re-rendering."""

    billable_calls: NotRequired[Annotated[dict[str, int], _int_counter_merge_reducer]]
    """Counter of ``task(...)`` invocations per subagent, summed over the turn.

    ``task_tool.py`` increments it every time a subagent invocation
    finishes (single- or batch-mode). The orchestrator can consult the map
    to self-limit when a runaway loop sends the same specialist 20 calls in
    a row; once the cumulative count crosses
    :data:`DEFAULT_SUBAGENT_BILLABLE_THRESHOLD` the runtime emits a soft
    warning ToolMessage. Cleared by checkpoint rollover (i.e. per turn).
    """

    receipts: NotRequired[Annotated[list[Receipt], _list_append_reducer]]
    """Structured Receipt handles produced by mutating subagent tools this turn.

    Every mutating tool (deliverables, every connector, KB writes via the
    persistence middleware) wraps its native return value in a
    :class:`~app.agents.shared.receipt.Receipt` and returns it under the
    ``"receipt"`` key next to its existing payload. The subagent's
    tool-call middleware folds that receipt into this list, and
    ``_return_command_with_state_update`` in
    ``checkpointed_subagent_middleware/task_tool.py`` propagates the list
    to the parent automatically (``"receipts"`` is not in
    ``EXCLUDED_STATE_KEYS``).

    The list is append-only for the duration of the turn and cleared by
    checkpoint rollover. The orchestrator reads it through the
    ``<verification>`` teaching to confirm side-effecting subagent claims
    (see ``shared/snippets/verifiable_handle.md``).
    """
|
||||
|
||||
|
||||
# Names exported by ``from <this module> import *``: the staging/hint
# TypedDicts and the agent state schema defined in this module.
__all__ = [
    "KbAnonDoc",
    "KbPriorityEntry",
    "PendingDelete",
    "PendingMove",
    "SurfSenseFilesystemState",
]
|
||||
351
surfsense_backend/app/agents/shared/path_resolver.py
Normal file
351
surfsense_backend/app/agents/shared/path_resolver.py
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
"""Canonical virtual-path resolver for SurfSense knowledge-base documents.
|
||||
|
||||
This module is the single source of truth for mapping ``Document`` rows to
|
||||
virtual paths under ``/documents/`` and back. It is used by:
|
||||
|
||||
* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree)
|
||||
* :class:`KnowledgePriorityMiddleware` (computing priority paths)
|
||||
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations)
|
||||
* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates)
|
||||
|
||||
Centralising the logic ensures that title-collision suffixes, folder paths,
|
||||
and ``unique_identifier_hash`` lookups never drift between renders and
|
||||
commits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType, Folder
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
DOCUMENTS_ROOT = "/documents"
"""Root virtual folder for all KB documents."""

# One or more consecutive characters that must not appear in a file or
# folder name; each run is replaced by a single underscore.
_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+")
# Any run of whitespace; collapsed to one space.
_WHITESPACE_RUN = re.compile(r"\s+")


def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
    """Turn arbitrary text into a filesystem-safe filename ending in ``.xml``.

    Runs of invalid characters become ``_``, surrounding whitespace is
    stripped and inner whitespace runs collapse to one space. An empty
    result is replaced by *fallback*. Names longer than 180 characters are
    cut to 180 (then right-stripped) before the extension check, and
    ``.xml`` is appended unless the name already ends with it
    (case-insensitive).
    """
    without_invalid = _INVALID_FILENAME_CHARS.sub("_", value).strip()
    candidate = _WHITESPACE_RUN.sub(" ", without_invalid) or fallback
    if len(candidate) > 180:
        candidate = candidate[:180].rstrip()
    already_xml = candidate.lower().endswith(".xml")
    return candidate if already_xml else f"{candidate}.xml"
|
||||
|
||||
|
||||
def safe_folder_segment(value: str, *, fallback: str = "folder") -> str:
    """Sanitise one folder name so it can be used as a single path segment.

    Runs of invalid characters become ``_``, surrounding whitespace is
    stripped and inner whitespace runs collapse to one space. An empty
    result yields *fallback* unchanged; anything longer than 180 characters
    is cut to 180 and right-stripped.
    """
    without_invalid = _INVALID_FILENAME_CHARS.sub("_", value).strip()
    segment = _WHITESPACE_RUN.sub(" ", without_invalid)
    if not segment:
        return fallback
    return segment if len(segment) <= 180 else segment[:180].rstrip()
|
||||
|
||||
|
||||
def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str:
|
||||
if doc_id is None:
|
||||
return filename
|
||||
if not filename.lower().endswith(".xml"):
|
||||
return f"{filename} ({doc_id}).xml"
|
||||
stem = filename[:-4]
|
||||
return f"{stem} ({doc_id}).xml"
|
||||
|
||||
|
||||
_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE)
|
||||
|
||||
|
||||
def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]:
|
||||
"""Strip a trailing ``" (<doc_id>).xml"`` suffix; return ``(stem, doc_id)``.
|
||||
|
||||
If no suffix is present, returns ``(stem_without_xml_extension, None)``.
|
||||
"""
|
||||
match = _SUFFIX_PATTERN.search(filename)
|
||||
if match:
|
||||
doc_id = int(match.group(1))
|
||||
stem = filename[: match.start()]
|
||||
return stem, doc_id
|
||||
if filename.lower().endswith(".xml"):
|
||||
return filename[:-4], None
|
||||
return filename, None
|
||||
|
||||
|
||||
@dataclass
class PathIndex:
    """Occupancy snapshot, held in memory, that :func:`doc_to_virtual_path` consults.

    A call site builds it once; that keeps collision handling
    deterministic and avoids doing N folder lookups for a single render.
    """

    folder_paths: dict[int, str] = field(default_factory=dict)
    """Maps ``Folder.id`` to the folder's absolute virtual path under ``/documents``."""

    occupants: dict[str, int] = field(default_factory=dict)
    """Maps a virtual path to the ``Document.id`` that already holds it in this render."""
|
||||
|
||||
|
||||
async def _build_folder_paths(
    session: AsyncSession,
    search_space_id: int,
) -> dict[int, str]:
    """Map every ``Folder.id`` of a search space to its absolute virtual path.

    All folders of the space are fetched in one query; each folder's path
    is then assembled in memory by walking ``parent_id`` links up to the
    root. The walk stops at an unknown parent or when a folder is seen a
    second time, so a cyclic parent chain cannot loop forever.
    """
    stmt = select(Folder.id, Folder.name, Folder.parent_id).where(
        Folder.search_space_id == search_space_id
    )
    fetched = await session.execute(stmt)
    folders = {row.id: (row.name, row.parent_id) for row in fetched.all()}

    resolved: dict[int, str] = {}
    for folder_id in folders:
        segments: list[str] = []
        seen: set[int] = set()
        current: int | None = folder_id
        # Collect names leaf-first while climbing towards the root.
        while current is not None and current in folders and current not in seen:
            seen.add(current)
            name, parent = folders[current]
            segments.append(safe_folder_segment(str(name)))
            current = parent
        resolved[folder_id] = "/".join([DOCUMENTS_ROOT, *reversed(segments)])
    return resolved
|
||||
|
||||
|
||||
async def build_path_index(
    session: AsyncSession,
    search_space_id: int,
    *,
    populate_occupants: bool = True,
) -> PathIndex:
    """Build a :class:`PathIndex` for a search space.

    Args:
        session: Async SQLAlchemy session used for the folder and document
            queries.
        search_space_id: Search space whose folders/documents are indexed.
        populate_occupants: Whether the occupancy map is pre-seeded from
            existing ``Document`` rows. Most callers want this so that
            :func:`doc_to_virtual_path` can detect collisions across the
            whole space; the persistence middleware sets this to ``False``
            when it is iterating to decide where to place fresh documents.

    Returns:
        A :class:`PathIndex` holding the folder paths and, when requested,
        one occupancy entry per existing document.
    """
    folder_paths = await _build_folder_paths(session, search_space_id)
    occupants: dict[str, int] = {}
    if populate_occupants:
        rows = await session.execute(
            select(Document.id, Document.title, Document.folder_id)
            .where(Document.search_space_id == search_space_id)
            # Without an explicit order the database may return rows in any
            # order, so which of two same-titled documents keeps the bare
            # filename could change between renders. Ordering by id makes
            # the lowest id win every time.
            .order_by(Document.id)
        )
        for row in rows.all():
            # Documents whose folder is unknown are placed at the root.
            base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT)
            filename = safe_filename(str(row.title or "untitled"))
            path = f"{base}/{filename}"
            if path in occupants and occupants[path] != row.id:
                # Path already taken by another document: disambiguate
                # with this document's id.
                path = f"{base}/{_suffix_with_doc_id(filename, row.id)}"
            occupants[path] = row.id
    return PathIndex(folder_paths=folder_paths, occupants=occupants)
|
||||
|
||||
|
||||
def doc_to_virtual_path(
    *,
    doc_id: int | None,
    title: str,
    folder_id: int | None,
    index: PathIndex,
) -> str:
    """Compute the canonical virtual path of a document.

    The path is ``<folder path>/<safe filename>``; an unknown ``folder_id``
    maps to the documents root. If that path is already held by a different
    document in ``index.occupants``, the filename gets a doc-id suffix
    instead. When ``doc_id`` is not ``None`` the chosen path is recorded in
    ``index.occupants``, so later calls see the assignment and
    deterministically choose a different suffix for the next colliding
    document.
    """
    folder_path = index.folder_paths.get(folder_id, DOCUMENTS_ROOT)
    filename = safe_filename(str(title or "untitled"))
    candidate = f"{folder_path}/{filename}"

    holder = index.occupants.get(candidate)
    taken_by_other = holder is not None and holder != doc_id
    if taken_by_other:
        candidate = f"{folder_path}/{_suffix_with_doc_id(filename, doc_id)}"

    if doc_id is None:
        return candidate
    index.occupants[candidate] = doc_id
    return candidate
|
||||
|
||||
|
||||
async def virtual_path_to_doc(
    session: AsyncSession,
    *,
    search_space_id: int,
    virtual_path: str,
) -> Document | None:
    """Resolve a virtual path back to a ``Document`` row.

    Resolution order:

    1. ``Document.unique_identifier_hash`` lookup (fast path for paths
       created by SurfSense itself — every NOTE write goes through this
       hash).
    2. If the basename carries a ``" (<doc_id>).xml"`` suffix, look the id
       up within the search space. The hit is returned immediately only
       when the document's own suffixed filename equals the basename;
       otherwise the digits may simply be part of a title such as
       ``"Budget (2023)"``, so the hit is kept as a last-resort answer and
       resolution continues.
    3. Exact title lookup inside the resolved folder.
    4. Scan of the resolved folder comparing ``safe_filename(title)`` with
       the basename (recovers titles that were lossily sanitised).
    5. The id hit remembered in step 2, if any.

    Args:
        session: Async SQLAlchemy session used for all lookups.
        search_space_id: Search space every lookup is constrained to.
        virtual_path: Path to resolve; must lie under ``DOCUMENTS_ROOT``.

    Returns:
        The matching ``Document`` or ``None`` when nothing matches.
    """
    if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
        return None

    unique_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    result = await session.execute(
        select(Document).where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_hash,
        )
    )
    document = result.scalar_one_or_none()
    if document is not None:
        return document

    remainder = virtual_path[len(DOCUMENTS_ROOT) :]
    if remainder and not remainder.startswith("/"):
        # Shares the textual prefix but is not below the root (for example
        # "/documents_old/x.xml"); do not treat "_old" as a folder name.
        return None
    parts = [p for p in remainder.split("/") if p]
    if not parts:
        return None
    *folder_parts, basename = parts

    stem, suffix_doc_id = parse_doc_id_suffix(basename)
    id_match: Document | None = None
    if suffix_doc_id is not None:
        result = await session.execute(
            select(Document).where(
                Document.search_space_id == search_space_id,
                Document.id == suffix_doc_id,
            )
        )
        id_match = result.scalar_one_or_none()
        if id_match is not None:
            canonical = _suffix_with_doc_id(
                safe_filename(str(id_match.title or "untitled")), id_match.id
            )
            if canonical == basename:
                return id_match
            # Otherwise keep id_match only as a fallback: the "(N)" may be
            # part of another document's real title.

    folder_id = await _resolve_folder_id(
        session, search_space_id=search_space_id, folder_parts=folder_parts
    )
    if folder_parts and folder_id is None:
        # The folder chain does not exist. ``None`` also means "root", so
        # continuing would wrongly search root-level documents.
        return id_match

    folder_filter = (
        Document.folder_id.is_(None)
        if folder_id is None
        else Document.folder_id == folder_id
    )

    title_candidates: list[str] = []
    if suffix_doc_id is not None:
        # Literal reading first: the whole basename minus ".xml" as the
        # title (handles titles that themselves end in "(digits)").
        title_candidates.append(basename[:-4])
    title_candidates.append(stem)
    if stem.endswith(".xml"):
        title_candidates.append(stem[:-4])

    # dict.fromkeys de-duplicates while keeping the priority order.
    for candidate in dict.fromkeys(title_candidates):
        if not candidate:
            continue
        result = await session.execute(
            select(Document).where(
                Document.search_space_id == search_space_id,
                Document.title == candidate,
                folder_filter,
            )
        )
        document = result.scalars().first()
        if document is not None:
            return document

    # Fallback: title-as-string lookup misses when the real DB title contains
    # characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``,
    # etc.) — common for connector-imported docs (Google Calendar/Drive etc.).
    # The workspace tree shows the lossy filename, so the agent passes that
    # filename back here. Scan all documents in the resolved folder and match
    # by ``safe_filename(title)`` to recover the original document.
    result = await session.execute(
        select(Document).where(
            Document.search_space_id == search_space_id,
            folder_filter,
        )
    )
    for candidate_doc in result.scalars().all():
        encoded = safe_filename(str(candidate_doc.title or "untitled"))
        if encoded == basename:
            return candidate_doc

    # Nothing matched by name: fall back to the id hit (stale-path
    # tolerance), which is ``None`` when there was no suffix or no such id.
    return id_match
|
||||
|
||||
|
||||
async def _resolve_folder_id(
    session: AsyncSession,
    *,
    search_space_id: int,
    folder_parts: list[str],
) -> int | None:
    """Walk a chain of folder names and return the id of the last folder.

    Each name is sanitised with :func:`safe_folder_segment` and looked up
    as a child of the folder found for the previous name (the first name
    must have no parent). Returns ``None`` when *folder_parts* is empty or
    any level has no matching folder.
    """
    current: int | None = None
    for segment in folder_parts:
        stmt = select(Folder.id).where(
            Folder.search_space_id == search_space_id,
            Folder.name == safe_folder_segment(segment),
            Folder.parent_id.is_(None)
            if current is None
            else Folder.parent_id == current,
        )
        match = (await session.execute(stmt)).first()
        if match is None:
            return None
        current = match[0]
    return current
|
||||
|
||||
|
||||
def parse_documents_path(virtual_path: str) -> tuple[list[str], str]:
    """Parse a ``/documents/...`` path into ``(folder_parts, document_title)``.

    The title has any ``.xml`` extension and trailing ``" (<doc_id>)"``
    disambiguation suffix stripped.

    Args:
        virtual_path: Absolute virtual path to split.

    Returns:
        ``(folder_parts, title)``; ``([], "")`` when the path is empty, is
        not located under ``DOCUMENTS_ROOT``, or names the root itself.
    """
    if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
        return [], ""
    remainder = virtual_path[len(DOCUMENTS_ROOT) :]
    if remainder and not remainder.startswith("/"):
        # A bare prefix match is not enough: "/documents_old/a.xml" starts
        # with "/documents" but is not inside the documents root.
        return [], ""
    # Empty segments (leading, trailing or doubled slashes) are ignored.
    parts = [p for p in remainder.split("/") if p]
    if not parts:
        return [], ""
    *folder_parts, basename = parts
    title, _ = parse_doc_id_suffix(basename)
    if title.endswith(".xml"):
        title = title[:-4]
    return folder_parts, title
|
||||
|
||||
|
||||
# Names exported by ``from <this module> import *``. Underscore-prefixed
# helpers such as ``_suffix_with_doc_id`` are deliberately not listed.
__all__ = [
    "DOCUMENTS_ROOT",
    "PathIndex",
    "build_path_index",
    "doc_to_virtual_path",
    "parse_doc_id_suffix",
    "parse_documents_path",
    "safe_filename",
    "safe_folder_segment",
    "virtual_path_to_doc",
]
|
||||
|
|
@ -6,7 +6,7 @@ participate in the verification teaching from
|
|||
``multi_agent_chat/subagents/shared/snippets/verifiable_handle.md`` those
|
||||
tools now also need to write a :class:`Receipt` into the parent's
|
||||
``state['receipts']`` list (declared on
|
||||
:class:`~app.agents.new_chat.filesystem_state.SurfSenseFilesystemState`
|
||||
:class:`~app.agents.shared.filesystem_state.SurfSenseFilesystemState`
|
||||
and backed by the append reducer).
|
||||
|
||||
:func:`with_receipt` wraps both behaviours: it returns the tool payload as
|
||||
|
|
@ -51,7 +51,7 @@ def with_receipt(
|
|||
"""Return a Command that ships ``payload`` as a ToolMessage AND appends ``receipt``.
|
||||
|
||||
The append happens via the ``_list_append_reducer`` on the ``receipts``
|
||||
field of :class:`~app.agents.new_chat.filesystem_state.SurfSenseFilesystemState`,
|
||||
field of :class:`~app.agents.shared.filesystem_state.SurfSenseFilesystemState`,
|
||||
so concurrent subagent batches (item 4 in the plan) won't clobber each
|
||||
other's receipts.
|
||||
"""
|
||||
|
|
|
|||
401
surfsense_backend/app/agents/shared/sandbox.py
Normal file
401
surfsense_backend/app/agents/shared/sandbox.py
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
"""
|
||||
Daytona sandbox provider for SurfSense deep agent.
|
||||
|
||||
Manages the lifecycle of sandboxed code execution environments.
|
||||
Each conversation thread gets its own isolated sandbox instance
|
||||
via the Daytona cloud API, identified by labels.
|
||||
|
||||
Files created during a session are persisted to local storage before
|
||||
the sandbox is deleted so they remain downloadable after cleanup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from daytona import (
|
||||
CreateSandboxFromSnapshotParams,
|
||||
Daytona,
|
||||
DaytonaConfig,
|
||||
SandboxState,
|
||||
)
|
||||
from daytona.common.errors import DaytonaError
|
||||
from deepagents.backends.protocol import ExecuteResponse
|
||||
from langchain_daytona import DaytonaSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _TimeoutAwareSandbox(DaytonaSandbox):
    """``DaytonaSandbox`` variant whose ``execute`` takes a per-command *timeout*.

    The deepagents middleware passes a *timeout* keyword. The upstream
    ``langchain-daytona`` ``execute()`` ignores timeout, which makes
    deepagents raise *"This sandbox backend does not support per-command
    timeout overrides"* on every first call. This subclass hands the value
    on to the Daytona SDK instead.
    """

    def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
        """Run *command* in the sandbox and return its output and exit code.

        ``timeout=None`` means "use ``self._default_timeout``".
        """
        effective_timeout = self._default_timeout if timeout is None else timeout
        completed = self._sandbox.process.exec(command, timeout=effective_timeout)
        return ExecuteResponse(
            output=completed.result,
            exit_code=completed.exit_code,
            truncated=False,
        )

    async def aexecute(
        self, command: str, *, timeout: int | None = None
    ) -> ExecuteResponse:  # type: ignore[override]
        """Async wrapper: runs :meth:`execute` in a worker thread."""
        return await asyncio.to_thread(self.execute, command, timeout=timeout)

    def download_file(self, path: str) -> bytes:
        """Return the bytes of the sandbox file at *path*."""
        return self._sandbox.fs.download_file(path)
|
||||
|
||||
|
||||
# Lazily created Daytona client shared by the whole process; built on
# first use by ``_get_client`` while holding ``_client_lock``.
_daytona_client: Daytona | None = None
_client_lock = threading.Lock()
# thread key (``str(thread_id)``) -> sandbox wrapper already obtained for it.
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
# thread key -> asyncio lock serialising sandbox lookup/creation for that
# thread; ``_sandbox_locks_mu`` guards the dict itself.
_sandbox_locks: dict[str, asyncio.Lock] = {}
_sandbox_locks_mu = asyncio.Lock()
# thread key -> {virtual path: ``modified_at`` value at last upload}.
_seeded_files: dict[str, dict[str, str]] = {}
# Once the cache holds more entries than this, the first entry is evicted.
_SANDBOX_CACHE_MAX_SIZE = 20
# Sandbox label whose value is the thread key; used to find a thread's sandbox.
THREAD_LABEL_KEY = "surfsense_thread"
# Directory inside the sandbox that virtual paths are uploaded under.
SANDBOX_DOCUMENTS_ROOT = "/home/daytona/documents"
|
||||
|
||||
|
||||
def is_sandbox_enabled() -> bool:
    """Report whether ``DAYTONA_SANDBOX_ENABLED`` is set to ``TRUE`` (any case).

    An unset variable counts as disabled.
    """
    flag = os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE")
    return flag.upper() == "TRUE"
|
||||
|
||||
|
||||
def _get_client() -> Daytona:
    """Return the process-wide Daytona client, creating it on first use.

    Creation happens under ``_client_lock``. The configuration is read from
    ``DAYTONA_API_KEY``, ``DAYTONA_API_URL`` and ``DAYTONA_TARGET``, each
    with the default shown in the code below.
    """
    global _daytona_client
    with _client_lock:
        if _daytona_client is not None:
            return _daytona_client
        env = os.environ
        _daytona_client = Daytona(
            DaytonaConfig(
                api_key=env.get("DAYTONA_API_KEY", ""),
                api_url=env.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
                target=env.get("DAYTONA_TARGET", "us"),
            )
        )
        return _daytona_client
|
||||
|
||||
|
||||
def _sandbox_create_params(
    labels: dict[str, str],
) -> CreateSandboxFromSnapshotParams:
    """Assemble the creation parameters for a sandbox carrying *labels*.

    The snapshot comes from ``DAYTONA_SNAPSHOT_ID``; an unset or empty
    variable is passed on as ``None``.
    """
    # ``or None`` turns an empty string into None as well.
    snapshot = os.environ.get("DAYTONA_SNAPSHOT_ID") or None
    params = CreateSandboxFromSnapshotParams(
        language="python",
        labels=labels,
        snapshot=snapshot,
        network_block_all=True,
        auto_stop_interval=10,
        auto_delete_interval=60,
    )
    return params
|
||||
|
||||
|
||||
def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
    """Find an existing sandbox for *thread_id*, or create a new one.

    The lookup, the start-up of an existing sandbox and the creation of a
    replacement are handled separately, so that a failure while starting
    is not mistaken for "no sandbox exists" and a stale sandbox carrying
    the same thread label is not left behind next to its replacement.

    Args:
        thread_id: Thread key stored under ``THREAD_LABEL_KEY``.

    Returns:
        A tuple of (sandbox, is_new) where *is_new* is True when a fresh
        sandbox was created (first time or replacement after failure).

    Raises:
        DaytonaError: If creating a sandbox fails.
    """
    client = _get_client()
    labels = {THREAD_LABEL_KEY: thread_id}

    def _create(kind: str) -> tuple[_TimeoutAwareSandbox, bool]:
        # Create a sandbox with the thread label and report it as new.
        created = client.create(_sandbox_create_params(labels))
        logger.info("Created %s sandbox: %s", kind, created.id)
        return _TimeoutAwareSandbox(sandbox=created), True

    def _discard(stale) -> None:
        # Best-effort removal; a failure here must not block the replacement.
        try:
            client.delete(stale)
        except Exception:
            logger.debug(
                "Could not delete broken sandbox %s", stale.id, exc_info=True
            )

    # Only the lookup is treated as "not found" when it raises.
    try:
        sandbox = client.find_one(labels=labels)
    except DaytonaError:
        logger.info("No existing sandbox for thread %s — creating one", thread_id)
        return _create("new")

    logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)
    state = sandbox.state

    if state in (
        SandboxState.ERROR,
        SandboxState.BUILD_FAILED,
        SandboxState.DESTROYED,
    ):
        logger.warning(
            "Sandbox %s in unrecoverable state %s — creating a new one",
            sandbox.id,
            state,
        )
        _discard(sandbox)
        return _create("replacement")

    try:
        if state in (
            SandboxState.STOPPED,
            SandboxState.STOPPING,
            SandboxState.ARCHIVED,
        ):
            logger.info("Starting stopped sandbox %s …", sandbox.id)
            sandbox.start(timeout=60)
            logger.info("Sandbox %s is now started", sandbox.id)
        elif state != SandboxState.STARTED:
            # Any other non-started state: wait for the start to complete.
            sandbox.wait_for_sandbox_start(timeout=60)
    except DaytonaError:
        # Starting failed: replace the sandbox, as before, but remove the
        # stale one first so two sandboxes never share the thread label.
        logger.warning(
            "Sandbox %s could not be started — creating a new one",
            sandbox.id,
            exc_info=True,
        )
        _discard(sandbox)
        return _create("replacement")

    return _TimeoutAwareSandbox(sandbox=sandbox), False
|
||||
|
||||
|
||||
async def _get_thread_lock(key: str) -> asyncio.Lock:
    """Fetch the asyncio lock registered for *key*, registering one on first use.

    The registry is only touched while ``_sandbox_locks_mu`` is held.
    """
    async with _sandbox_locks_mu:
        try:
            return _sandbox_locks[key]
        except KeyError:
            fresh = asyncio.Lock()
            _sandbox_locks[key] = fresh
            return fresh
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
    thread_id: int | str,
) -> tuple[_TimeoutAwareSandbox, bool]:
    """Get or create a sandbox for a conversation thread.

    Uses an in-process cache keyed by thread_id so subsequent messages
    in the same conversation reuse the sandbox object without an API call.
    A per-thread async lock prevents duplicate sandbox creation from
    concurrent requests.

    The cache is kept in least-recently-used order: a hit moves the entry
    to the end, so eviction (which also deletes the sandbox) removes the
    thread that has been idle longest, not the one that was created first.

    Args:
        thread_id: Conversation thread identifier; converted with ``str``.

    Returns:
        Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox
        was created, signalling that file tracking should be reset.
    """
    key = str(thread_id)
    lock = await _get_thread_lock(key)

    async with lock:
        # pop + re-insert moves the entry to the end of the dict's
        # insertion order; there is no await in between.
        cached = _sandbox_cache.pop(key, None)
        if cached is not None:
            _sandbox_cache[key] = cached
            logger.info("Reusing cached sandbox for thread %s", key)
            return cached, False

        # The Daytona SDK calls are blocking; run them off the event loop.
        sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
        _sandbox_cache[key] = sandbox

        if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
            # First key == least recently used entry.
            oldest_key = next(iter(_sandbox_cache))
            if oldest_key != key:
                evicted = _sandbox_cache.pop(oldest_key, None)
                _seeded_files.pop(oldest_key, None)
                logger.debug("Evicted sandbox cache entry: %s", oldest_key)
                if evicted is not None:
                    _schedule_sandbox_delete(evicted)

        return sandbox, is_new
|
||||
|
||||
|
||||
def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None:
    """Best-effort background deletion of an evicted sandbox.

    The deletion is submitted to the running event loop's default executor
    and is not awaited. Failures never propagate to the caller: an error
    while deleting, or while scheduling the deletion (for example when no
    event loop is running), is logged at debug level.

    Args:
        sandbox: Wrapper whose ``_sandbox`` attribute is passed to the
            client's ``delete``.
    """

    def _delete() -> None:
        try:
            client = _get_client()
            client.delete(sandbox._sandbox)
            logger.info("Deleted evicted sandbox: %s", sandbox._sandbox.id)
        except Exception:
            logger.debug("Could not delete evicted sandbox", exc_info=True)

    try:
        asyncio.get_running_loop().run_in_executor(None, _delete)
    except RuntimeError:
        # Still best-effort, but leave a trace: the evicted sandbox stays
        # alive remotely when its deletion could not be scheduled.
        logger.debug(
            "Could not schedule deletion of evicted sandbox", exc_info=True
        )
|
||||
|
||||
|
||||
async def sync_files_to_sandbox(
    thread_id: int | str,
    files: dict[str, dict],
    sandbox: _TimeoutAwareSandbox,
    is_new: bool,
) -> None:
    """Upload new or changed virtual-filesystem files to the sandbox.

    Compares *files* (from ``state["files"]``) against the ``_seeded_files``
    tracking dict and uploads only entries whose ``modified_at`` differs from
    the recorded value. When *is_new* is True the tracking for the thread is
    dropped first, so every file is uploaded again.

    Args:
        thread_id: Conversation thread identifier; ``str(thread_id)`` is the
            tracking key.
        files: Mapping of virtual path to a file record. ``modified_at`` and
            ``content`` are read from each record; ``content`` may be a
            sequence of lines (joined with ``"\\n"``) or an already-assembled
            string.
        sandbox: Sandbox wrapper whose ``upload_files`` receives a list of
            ``(sandbox_path, bytes)`` pairs.
        is_new: Whether the sandbox was freshly created.
    """
    key = str(thread_id)
    if is_new:
        _seeded_files.pop(key, None)

    tracked = _seeded_files.get(key, {})
    to_upload: list[tuple[str, bytes]] = []

    for vpath, fdata in files.items():
        if tracked.get(vpath) == fdata.get("modified_at", ""):
            continue
        content = fdata.get("content", [])
        # A plain string must not go through join(): iterating it yields
        # single characters, which would put a newline between each of them.
        text = content if isinstance(content, str) else "\n".join(content)
        sandbox_path = f"{SANDBOX_DOCUMENTS_ROOT}{vpath}"
        to_upload.append((sandbox_path, text.encode("utf-8")))

    if not to_upload:
        return

    # The upload is blocking, so it runs in a worker thread.
    await asyncio.to_thread(sandbox.upload_files, to_upload)

    # Record the timestamps only after the upload succeeded, so a failed
    # upload is retried on the next call.
    new_tracked = dict(tracked)
    new_tracked.update(
        (vpath, fdata.get("modified_at", "")) for vpath, fdata in files.items()
    )
    _seeded_files[key] = new_tracked
    logger.info("Synced %d file(s) to sandbox for thread %s", len(to_upload), key)
|
||||
|
||||
|
||||
def _evict_sandbox_cache(thread_id: int | str) -> None:
    """Drop the thread's entries from ``_sandbox_cache`` and ``_seeded_files``.

    Missing entries are ignored. Only the in-process bookkeeping is touched.
    """
    key = str(thread_id)
    for registry in (_sandbox_cache, _seeded_files):
        registry.pop(key, None)
|
||||
|
||||
|
||||
async def delete_sandbox(thread_id: int | str) -> None:
    """Delete the sandbox belonging to a conversation thread.

    The thread's in-process cache entries are dropped first. The sandbox is
    then looked up by its thread label and deleted in a worker thread. A
    ``DaytonaError`` from the lookup is treated as "already removed", and a
    failure of the delete call is logged as a warning; neither is raised.

    Args:
        thread_id: Conversation thread identifier.
    """
    _evict_sandbox_cache(thread_id)

    def _remove_remote() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}

        try:
            sandbox = client.find_one(labels=labels)
        except DaytonaError:
            logger.debug(
                "No sandbox to delete for thread %s (already removed)", thread_id
            )
            return

        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox for thread %s",
                thread_id,
                exc_info=True,
            )

    await asyncio.to_thread(_remove_remote)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_sandbox_files_dir() -> Path:
    """Return the root directory under which sandbox files are persisted.

    The ``SANDBOX_FILES_DIR`` environment variable is read on every call;
    when it is unset the relative path ``sandbox_files`` is used.
    """
    configured = os.environ.get("SANDBOX_FILES_DIR", "sandbox_files")
    return Path(configured)
|
||||
|
||||
|
||||
def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
    """Translate a sandbox-internal absolute path into a local storage path.

    Leading slashes are stripped from *sandbox_path* and the remainder is
    joined onto the thread's directory inside the sandbox files root. Both
    paths are resolved before the containment check.

    Args:
        thread_id: Conversation thread identifier (names the sub-directory).
        sandbox_path: Path of the file inside the sandbox.

    Returns:
        The resolved local path.

    Raises:
        ValueError: If the resolved path falls outside the thread directory.
    """
    thread_root = (_get_sandbox_files_dir() / str(thread_id)).resolve()
    candidate = (thread_root / sandbox_path.lstrip("/")).resolve()
    if candidate.is_relative_to(thread_root):
        return candidate
    raise ValueError(f"Path traversal blocked: {sandbox_path}")
|
||||
|
||||
|
||||
def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
    """Load a sandbox file that was previously persisted to local storage.

    Args:
        thread_id: Conversation thread identifier.
        sandbox_path: Path of the file inside the sandbox.

    Returns:
        The file's bytes, or *None* when no regular file exists at the
        mapped local path.

    Raises:
        ValueError: Propagated from ``_local_path_for`` when *sandbox_path*
            escapes the thread directory.
    """
    stored = _local_path_for(thread_id, sandbox_path)
    return stored.read_bytes() if stored.is_file() else None
|
||||
|
||||
|
||||
def delete_local_sandbox_files(thread_id: int | str) -> None:
    """Delete every locally persisted sandbox file of a thread.

    Does nothing when the thread's directory does not exist. Errors raised
    while removing the tree are ignored.

    Args:
        thread_id: Conversation thread identifier.
    """
    thread_dir = _get_sandbox_files_dir() / str(thread_id)
    if not thread_dir.is_dir():
        return
    shutil.rmtree(thread_dir, ignore_errors=True)
    logger.info("Deleted local sandbox files for thread %s", thread_id)
|
||||
|
||||
|
||||
async def persist_and_delete_sandbox(
    thread_id: int | str,
    sandbox_file_paths: list[str],
) -> None:
    """Download sandbox files to local storage, then delete the sandbox.

    Each file in *sandbox_file_paths* is downloaded from the Daytona
    sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/…``.
    Per-file errors are logged but do **not** prevent the sandbox from
    being deleted — freeing Daytona storage is the priority.

    The thread's in-process cache entries are dropped first; the remote work
    runs in a worker thread. Nothing is raised for a failed lookup, start,
    download or delete — each is logged instead.

    Args:
        thread_id: Conversation thread identifier.
        sandbox_file_paths: Sandbox-internal paths of the files to keep.
            When empty, the sandbox is deleted without being started.
    """
    _evict_sandbox_cache(thread_id)

    def _persist_and_delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}

        try:
            sandbox = client.find_one(labels=labels)
        except Exception:
            logger.info(
                "No sandbox found for thread %s — nothing to persist", thread_id
            )
            return

        # The sandbox only needs to be running when there is something to
        # download; with no files requested, skip the start and just delete.
        if sandbox_file_paths and sandbox.state != SandboxState.STARTED:
            try:
                sandbox.start(timeout=60)
            except Exception:
                logger.warning(
                    "Could not start sandbox %s for file download — deleting anyway",
                    sandbox.id,
                    exc_info=True,
                )
                with contextlib.suppress(Exception):
                    client.delete(sandbox)
                return

        for path in sandbox_file_paths:
            try:
                # Resolve (and validate) the destination first so a path that
                # would be rejected does not cost a download.
                local = _local_path_for(thread_id, path)
                content: bytes = sandbox.fs.download_file(path)
                local.parent.mkdir(parents=True, exist_ok=True)
                local.write_bytes(content)
                logger.info("Persisted sandbox file %s → %s", path, local)
            except Exception:
                logger.warning(
                    "Failed to persist sandbox file %s for thread %s",
                    path,
                    thread_id,
                    exc_info=True,
                )

        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox %s after persistence",
                sandbox.id,
                exc_info=True,
            )

    await asyncio.to_thread(_persist_and_delete)
|
||||
Loading…
Add table
Add a link
Reference in a new issue