diff --git a/.gitignore b/.gitignore
index b45b1961c..2e6ed14e8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,5 @@ node_modules/
 .pnpm-store
 .DS_Store
 deepagents/
-debug.log
\ No newline at end of file
+debug.log
+opencode/
\ No newline at end of file
diff --git a/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py b/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py
index 2d8f05fd3..890b3e06e 100644
--- a/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py
+++ b/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py
@@ -28,13 +28,76 @@ from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage, ToolMessage
 
+from app.agents.new_chat.document_xml import build_document_xml
 from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
 from app.agents.new_chat.middleware.knowledge_search import (
-    build_scoped_filesystem,
     search_knowledge_base,
 )
+from app.agents.new_chat.path_resolver import (
+    DOCUMENTS_ROOT,
+    build_path_index,
+    doc_to_virtual_path,
+)
+from app.db import shielded_async_session
 from app.services.new_streaming_service import VercelStreamingService
 
+try:
+    from deepagents.backends.utils import create_file_data
+except Exception:  # pragma: no cover - defensive
+
+    def create_file_data(content: str) -> dict[str, Any]:
+        return {"content": content.split("\n")}
+
+
+async def _build_autocomplete_filesystem(
+    *,
+    documents: Any,
+    search_space_id: int,
+) -> tuple[dict[str, Any], dict[int, str]]:
+    """Build a ``state['files']``-shaped dict from KB search results.
+
+    This is the autocomplete-specific replacement for the previous
+    ``build_scoped_filesystem`` helper. It uses the canonical path resolver
+    so paths line up with the rest of the system, including collision
+    suffixes for duplicate titles.
+    """
+    files: dict[str, Any] = {}
+    doc_id_to_path: dict[int, str] = {}
+
+    if not documents:
+        return files, doc_id_to_path
+
+    async with shielded_async_session() as session:
+        index = await build_path_index(session, search_space_id)
+
+    for document in documents:
+        if not isinstance(document, dict):
+            continue
+        meta = document.get("document") or {}
+        doc_id = meta.get("id")
+        if not isinstance(doc_id, int):
+            continue
+        title = str(meta.get("title") or "untitled")
+        folder_id = meta.get("folder_id")
+        path = doc_to_virtual_path(
+            doc_id=doc_id, title=title, folder_id=folder_id, index=index
+        )
+        chunk_ids = document.get("matched_chunk_ids") or []
+        try:
+            matched_set = {int(c) for c in chunk_ids}
+        except (TypeError, ValueError):
+            matched_set = set()
+        xml = build_document_xml(document, matched_chunk_ids=matched_set)
+        files[path] = create_file_data(xml)
+        doc_id_to_path[doc_id] = path
+
+    if not files:
+        # Ensure the synthetic /documents folder is visible even when empty.
+        files.setdefault(f"{DOCUMENTS_ROOT}/.placeholder", create_file_data(""))
+
+    return files, doc_id_to_path
+
+
 logger = logging.getLogger(__name__)
 
 KB_TOP_K = 10
@@ -174,7 +237,7 @@ async def precompute_kb_filesystem(
     if not search_results:
         return _KBResult()
 
-    new_files, _ = await build_scoped_filesystem(
+    new_files, _ = await _build_autocomplete_filesystem(
         documents=search_results,
         search_space_id=search_space_id,
     )
@@ -215,13 +278,12 @@ async def precompute_kb_filesystem(
 class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
     """Filesystem middleware for autocomplete — read-only exploration only.
 
-    Strips ``save_document`` (permanent KB persistence) and passes
-    ``search_space_id=None`` so ``write_file`` / ``edit_file`` stay ephemeral.
+    Passes ``search_space_id=None`` so the new persistence pipeline is
+    bypassed; the autocomplete flow only reads, never commits to Postgres.
     """
 
     def __init__(self) -> None:
        super().__init__(search_space_id=None, created_by_id=None)
-        self.tools = [t for t in self.tools if t.name != "save_document"]
 
 
 # ---------------------------------------------------------------------------
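Reviewer note: a minimal sketch of the shapes involved, assuming the hybrid-search payload inferred from the helper above. Field names outside that code are assumptions, and a live database is needed for the path index.

```python
# Sketch only; requires a running DB for build_path_index, and the exact
# virtual path depends on doc_to_virtual_path's collision suffixing.
import asyncio

search_results = [
    {
        "document": {"id": 42, "title": "Quarterly Report", "folder_id": None},
        "matched_chunk_ids": [7, 9],
        "chunks": [
            {"chunk_id": 7, "content": "Revenue grew 12%."},
            {"chunk_id": 9, "content": "Headcount stayed flat."},
        ],
    }
]

files, doc_id_to_path = asyncio.run(
    _build_autocomplete_filesystem(documents=search_results, search_space_id=1)
)
# files maps a virtual path (e.g. "/documents/Quarterly Report") to
# create_file_data(xml), i.e. {"content": ["<line>", ...]};
# doc_id_to_path maps 42 back to that same path.
```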
if "search_knowledge_base" not in modified_disabled_tools: modified_disabled_tools.append("search_knowledge_base") @@ -365,6 +375,11 @@ async def create_surfsense_deep_agent( ) # General-purpose subagent middleware + # Subagent omits AnonymousDocumentMiddleware, KnowledgeTreeMiddleware, + # KnowledgePriorityMiddleware, and KnowledgeBasePersistenceMiddleware - it + # inherits state and tools from the parent, but should not (a) re-load + # anon docs / re-render the tree / re-run hybrid search, or (b) commit at + # its own completion (only the top-level agent's aafter_agent commits). gp_middleware = [ TodoListMiddleware(), _memory_middleware, @@ -389,19 +404,35 @@ async def create_surfsense_deep_agent( } # Main agent middleware + # Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ... + # before_agent hooks run in declared order; later injections sit closer to + # the latest human turn. Tree (large + cacheable) is injected earliest so + # provider-side prefix caching has more material to hit; FileIntent (most + # actionable per-turn contract) is injected closest to the user message. deepagent_middleware = [ TodoListMiddleware(), _memory_middleware, - FileIntentMiddleware(llm=llm), - KnowledgeBaseSearchMiddleware( + AnonymousDocumentMiddleware( + anon_session_id=anon_session_id, + ) + if filesystem_selection.mode == FilesystemMode.CLOUD + else None, + KnowledgeTreeMiddleware( + search_space_id=search_space_id, + filesystem_mode=filesystem_selection.mode, + llm=llm, + ) + if filesystem_selection.mode == FilesystemMode.CLOUD + else None, + KnowledgePriorityMiddleware( llm=llm, search_space_id=search_space_id, filesystem_mode=filesystem_selection.mode, available_connectors=available_connectors, available_document_types=available_document_types, mentioned_document_ids=mentioned_document_ids, - anon_session_id=anon_session_id, ), + FileIntentMiddleware(llm=llm), SurfSenseFilesystemMiddleware( backend=backend_resolver, filesystem_mode=filesystem_selection.mode, @@ -409,12 +440,20 @@ async def create_surfsense_deep_agent( created_by_id=user_id, thread_id=thread_id, ), + KnowledgeBasePersistenceMiddleware( + search_space_id=search_space_id, + created_by_id=user_id, + filesystem_mode=filesystem_selection.mode, + ) + if filesystem_selection.mode == FilesystemMode.CLOUD + else None, SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]), create_safe_summarization_middleware(llm, StateBackend), PatchToolCallsMiddleware(), DedupHITLToolCallsMiddleware(agent_tools=tools), AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] + deepagent_middleware = [m for m in deepagent_middleware if m is not None] # Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent) final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT diff --git a/surfsense_backend/app/agents/new_chat/document_xml.py b/surfsense_backend/app/agents/new_chat/document_xml.py new file mode 100644 index 000000000..60e586ae1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/document_xml.py @@ -0,0 +1,103 @@ +"""Shared XML builder for KB documents. + +Produces the citation-friendly XML used by every read of a knowledge-base +document (lazy-loaded by :class:`KBPostgresBackend` and synthetic anonymous +files). The XML carries a ```` near the top so the LLM can jump +directly to matched-chunk line ranges via ``read_file(offset=…, limit=…)``. 
diff --git a/surfsense_backend/app/agents/new_chat/document_xml.py b/surfsense_backend/app/agents/new_chat/document_xml.py
new file mode 100644
index 000000000..60e586ae1
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/document_xml.py
@@ -0,0 +1,103 @@
+"""Shared XML builder for KB documents.
+
+Produces the citation-friendly XML used by every read of a knowledge-base
+document (lazy-loaded by :class:`KBPostgresBackend` and synthetic anonymous
+files). The XML carries a ``<chunk_index>`` near the top so the LLM can jump
+directly to matched-chunk line ranges via ``read_file(offset=…, limit=…)``.
+
+Extracted from the original ``knowledge_search.py`` so the backend, the
+priority middleware, and any future renderer share a single implementation.
+"""
+
+from __future__ import annotations
+
+import json
+from typing import Any
+
+
+def build_document_xml(
+    document: dict[str, Any],
+    matched_chunk_ids: set[int] | None = None,
+) -> str:
+    """Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
+
+    Args:
+        document: Dict shape produced by hybrid search / lazy-load helpers.
+            Expected keys: ``document`` (with ``id``, ``title``,
+            ``document_type``, ``metadata``) and ``chunks``
+            (list of ``{chunk_id, content}``).
+        matched_chunk_ids: Optional set of chunk IDs to flag as
+            ``matched="true"`` in the chunk index.
+    """
+    matched = matched_chunk_ids or set()
+
+    doc_meta = document.get("document") or {}
+    metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
+    document_id = doc_meta.get("id", document.get("document_id", "unknown"))
+    document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
+    title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
+    url = (
+        metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
+    )
+    metadata_json = json.dumps(metadata, ensure_ascii=False)
+
+    metadata_lines: list[str] = [
+        "<document>",
+        "<metadata>",
+        f"  <document_id>{document_id}</document_id>",
+        f"  <document_type>{document_type}</document_type>",
+        f"  <title><![CDATA[{title}]]></title>",
+        f"  <url><![CDATA[{url}]]></url>",
+        f"  <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
+        "</metadata>",
+        "",
+    ]
+
+    chunks = document.get("chunks") or []
+    chunk_entries: list[tuple[int | None, str]] = []
+    if isinstance(chunks, list):
+        for chunk in chunks:
+            if not isinstance(chunk, dict):
+                continue
+            chunk_id = chunk.get("chunk_id") or chunk.get("id")
+            chunk_content = str(chunk.get("content", "")).strip()
+            if not chunk_content:
+                continue
+            if chunk_id is None:
+                xml = f"  <chunk><![CDATA[{chunk_content}]]></chunk>"
+            else:
+                xml = f'  <chunk id="{chunk_id}"><![CDATA[{chunk_content}]]></chunk>'
+            chunk_entries.append((chunk_id, xml))
+
+    index_overhead = 1 + len(chunk_entries) + 1 + 1 + 1
+    first_chunk_line = len(metadata_lines) + index_overhead + 1
+
+    current_line = first_chunk_line
+    index_entry_lines: list[str] = []
+    for cid, xml_str in chunk_entries:
+        num_lines = xml_str.count("\n") + 1
+        end_line = current_line + num_lines - 1
+        matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
+        if cid is not None:
+            index_entry_lines.append(
+                f'  <chunk id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
+            )
+        else:
+            index_entry_lines.append(
+                f'  <chunk lines="{current_line}-{end_line}"/>'
+            )
+        current_line = end_line + 1
+
+    lines = metadata_lines.copy()
+    lines.append("<chunk_index>")
+    lines.extend(index_entry_lines)
+    lines.append("</chunk_index>")
+    lines.append("")
+    lines.append("<content>")
+    for _, xml_str in chunk_entries:
+        lines.append(xml_str)
+    lines.extend(["</content>", "</document>"])
+    return "\n".join(lines)
+
+
+__all__ = ["build_document_xml"]
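A hedged usage sketch of `build_document_xml` for reviewers. The payload field names mirror the docstring above; the tag names in the comments follow the reconstruction in this diff and may differ cosmetically from the shipped renderer.

```python
# Hypothetical invocation; the dict shape matches the documented contract.
doc = {
    "document": {
        "id": 42,
        "title": "Quarterly Report",
        "document_type": "NOTE",
        "metadata": {"url": "https://example.com/report"},
    },
    "chunks": [
        {"chunk_id": 7, "content": "Revenue grew 12%."},
        {"chunk_id": 9, "content": "Headcount stayed flat."},
    ],
}

xml = build_document_xml(doc, matched_chunk_ids={9})
# The output opens with <metadata>, then a <chunk_index> whose entries carry
# lines="start-end" ranges (with matched="true" on chunk 9), followed by the
# <content> block -- so an agent can read_file(offset=..., limit=...) straight
# to a matched chunk without scanning the whole file.
```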
diff --git a/surfsense_backend/app/agents/new_chat/filesystem_backends.py b/surfsense_backend/app/agents/new_chat/filesystem_backends.py
index 85ed5f801..c8288be71 100644
--- a/surfsense_backend/app/agents/new_chat/filesystem_backends.py
+++ b/surfsense_backend/app/agents/new_chat/filesystem_backends.py
@@ -5,10 +5,12 @@ from __future__ import annotations
 from collections.abc import Callable
 from functools import lru_cache
 
+from deepagents.backends.protocol import BackendProtocol
 from deepagents.backends.state import StateBackend
 from langgraph.prebuilt.tool_node import ToolRuntime
 
 from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
+from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
 from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
     MultiRootLocalFolderBackend,
 )
@@ -23,8 +25,20 @@ def _cached_multi_root_backend(
 
 def build_backend_resolver(
     selection: FilesystemSelection,
-) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]:
-    """Create deepagents backend resolver for the selected filesystem mode."""
+    *,
+    search_space_id: int | None = None,
+) -> Callable[[ToolRuntime], BackendProtocol]:
+    """Create deepagents backend resolver for the selected filesystem mode.
+
+    In cloud mode the resolver returns a fresh :class:`KBPostgresBackend`
+    bound to the current ``runtime`` so the backend can read staging state
+    (``staged_dirs``, ``pending_moves``, ``files`` cache, ``kb_anon_doc``,
+    ``kb_matched_chunk_ids``) for each tool call. When no ``search_space_id``
+    is provided, the resolver falls back to :class:`StateBackend` (used by
+    sub-agents and tests that don't need DB-backed reads).
+
+    Desktop-local mode is unchanged.
+    """
 
     if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:
@@ -36,7 +50,14 @@ def build_backend_resolver(
 
         return _resolve_local
 
-    def _resolve_cloud(runtime: ToolRuntime) -> StateBackend:
+    if search_space_id is not None:
+
+        def _resolve_kb(runtime: ToolRuntime) -> BackendProtocol:
+            return KBPostgresBackend(search_space_id, runtime)
+
+        return _resolve_kb
+
+    def _resolve_state(runtime: ToolRuntime) -> StateBackend:
         return StateBackend(runtime)
 
-    return _resolve_cloud
+    return _resolve_state
+""" + +from __future__ import annotations + +from typing import Annotated, Any, NotRequired + +from deepagents.middleware.filesystem import FilesystemState +from typing_extensions import TypedDict + +from app.agents.new_chat.state_reducers import ( + _add_unique_reducer, + _dict_merge_with_tombstones_reducer, + _list_append_reducer, + _replace_reducer, +) + + +class PendingMove(TypedDict): + """A staged move_file operation pending end-of-turn commit.""" + + source: str + dest: str + overwrite: bool + + +class KbPriorityEntry(TypedDict, total=False): + path: str + score: float + document_id: int | None + title: str + mentioned: bool + + +class KbAnonDoc(TypedDict, total=False): + """In-memory anonymous-session document loaded from Redis.""" + + path: str + title: str + content: str + chunks: list[dict[str, Any]] + + +class SurfSenseFilesystemState(FilesystemState): + """Filesystem state used by the SurfSense agent (cloud + desktop). + + Extends deepagents' :class:`FilesystemState` (which provides ``files``) + with cloud-mode staging fields and search-priority hints. All extra fields + are reducer-backed so that ``Command(update=...)`` payloads merge cleanly + across agent steps and across checkpoints. + """ + + cwd: NotRequired[Annotated[str, _replace_reducer]] + """Current working directory. + + Defaults to ``"/documents"`` in cloud mode and ``"/"`` (or first mount) in + desktop mode. Initialized once per thread by ``KnowledgeTreeMiddleware``. + """ + + staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]] + """mkdir paths staged for end-of-turn folder creation (cloud only).""" + + pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]] + """move_file ops staged for end-of-turn commit (cloud only).""" + + doc_id_by_path: NotRequired[ + Annotated[dict[str, int], _dict_merge_with_tombstones_reducer] + ] + """virtual_path -> ``Document.id`` for lazily loaded files. + + Populated on first read of a KB document. Used by edit_file/move_file/ + aafter_agent to map paths back to a real DB row. ``None`` values delete + the key (tombstones). 
+ """ + + dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]] + """Paths whose ``state["files"]`` content has been modified this turn.""" + + kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]] + """Top-K priority hints rendered as a system message before the user turn.""" + + kb_matched_chunk_ids: NotRequired[Annotated[dict[int, list[int]], _replace_reducer]] + """Internal: ``Document.id`` -> list of matched chunk IDs from hybrid search.""" + + kb_anon_doc: NotRequired[Annotated[KbAnonDoc | None, _replace_reducer]] + """Anonymous-session document loaded from Redis (read-only, no DB row).""" + + tree_version: NotRequired[Annotated[int, _replace_reducer]] + """Monotonically increasing counter; bumped when commits change the KB tree.""" + + +__all__ = [ + "KbAnonDoc", + "KbPriorityEntry", + "PendingMove", + "SurfSenseFilesystemState", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py index 6e4542e1a..e885d9e6b 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py @@ -1,5 +1,8 @@ """Middleware components for the SurfSense new chat agent.""" +from app.agents.new_chat.middleware.anonymous_document import ( + AnonymousDocumentMiddleware, +) from app.agents.new_chat.middleware.dedup_tool_calls import ( DedupHITLToolCallsMiddleware, ) @@ -9,17 +12,30 @@ from app.agents.new_chat.middleware.file_intent import ( from app.agents.new_chat.middleware.filesystem import ( SurfSenseFilesystemMiddleware, ) +from app.agents.new_chat.middleware.kb_persistence import ( + KnowledgeBasePersistenceMiddleware, + commit_staged_filesystem_state, +) from app.agents.new_chat.middleware.knowledge_search import ( KnowledgeBaseSearchMiddleware, + KnowledgePriorityMiddleware, +) +from app.agents.new_chat.middleware.knowledge_tree import ( + KnowledgeTreeMiddleware, ) from app.agents.new_chat.middleware.memory_injection import ( MemoryInjectionMiddleware, ) __all__ = [ + "AnonymousDocumentMiddleware", "DedupHITLToolCallsMiddleware", "FileIntentMiddleware", + "KnowledgeBasePersistenceMiddleware", "KnowledgeBaseSearchMiddleware", + "KnowledgePriorityMiddleware", + "KnowledgeTreeMiddleware", "MemoryInjectionMiddleware", "SurfSenseFilesystemMiddleware", + "commit_staged_filesystem_state", ] diff --git a/surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py b/surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py new file mode 100644 index 000000000..2893d2e11 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py @@ -0,0 +1,91 @@ +"""Lightweight middleware that loads the anonymous-session document into state. + +Anonymous chats receive a single uploaded document via Redis (no DB row, +read-only). This middleware loads it once on the first turn into +``state['kb_anon_doc']`` so: + +* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents`` + view without touching the DB. +* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a + degenerate priority list. +* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``) + recognises the synthetic path. + +The middleware is a no-op when ``anon_session_id`` is not provided or when +the document is already cached in state. 
+""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from langchain.agents.middleware import AgentMiddleware, AgentState +from langgraph.runtime import Runtime + +from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, safe_filename + +logger = logging.getLogger(__name__) + + +class AnonymousDocumentMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Load the anonymous user's uploaded document from Redis into state.""" + + tools = () + state_schema = SurfSenseFilesystemState + + def __init__(self, *, anon_session_id: str | None) -> None: + self.anon_session_id = anon_session_id + + async def abefore_agent( # type: ignore[override] + self, + state: AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + del runtime + if not self.anon_session_id: + return None + if state.get("kb_anon_doc"): + return None + + anon_doc = await self._load_anon_document() + if anon_doc is None: + return None + return {"kb_anon_doc": anon_doc} + + async def _load_anon_document(self) -> dict[str, Any] | None: + """Read ``anon:doc:`` from Redis.""" + try: + import redis.asyncio as aioredis # local import to keep cold paths cheap + + from app.config import config + + redis_client = aioredis.from_url( + config.REDIS_APP_URL, decode_responses=True + ) + try: + redis_key = f"anon:doc:{self.anon_session_id}" + data = await redis_client.get(redis_key) + if not data: + return None + payload = json.loads(data) + finally: + await redis_client.aclose() + except Exception as exc: + logger.warning("Failed to load anonymous document from Redis: %s", exc) + return None + + title = str(payload.get("filename") or "uploaded_document") + content = str(payload.get("content") or "") + path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}" + return { + "path": path, + "title": title, + "content": content, + "chunks": [{"chunk_id": -1, "content": content}] if content else [], + } + + +__all__ = ["AnonymousDocumentMiddleware"] diff --git a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py index 05cb230ce..7897e13d6 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py +++ b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py @@ -21,7 +21,7 @@ from typing import Any from langchain.agents.middleware import AgentMiddleware, AgentState from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langgraph.runtime import Runtime from pydantic import BaseModel, Field, ValidationError @@ -217,8 +217,19 @@ def _build_recent_conversation( messages: list[BaseMessage], *, max_messages: int = 6 ) -> str: rows: list[str] = [] - for msg in messages[-max_messages:]: - role = "user" if isinstance(msg, HumanMessage) else "assistant" + filtered: list[tuple[str, BaseMessage]] = [] + for msg in messages: + role: str | None = None + if isinstance(msg, HumanMessage): + role = "user" + elif isinstance(msg, AIMessage): + if getattr(msg, "tool_calls", None): + continue + role = "assistant" + else: + continue + filtered.append((role, msg)) + for role, msg in filtered[-max_messages:]: text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip() if text: rows.append(f"{role}: {text[:280]}") diff --git 
diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py
index cb50693f1..62316d69e 100644
--- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py
+++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py
@@ -1,7 +1,26 @@
 """Custom filesystem middleware for the SurfSense agent.
 
-This middleware customizes prompts and persists write/edit operations for
-`/documents/*` files into SurfSense's `Document`/`Chunk` tables.
+This middleware fully overrides every deepagents filesystem tool so that the
+``Command(update=...)`` payload can carry SurfSense-specific state fields
+(``cwd``, ``staged_dirs``, ``pending_moves``, ``doc_id_by_path``,
+``dirty_paths``) atomically alongside the standard ``files`` update.
+
+In CLOUD mode the backend is :class:`KBPostgresBackend` (lazy DB reads, no DB
+writes). End-of-turn persistence is handled by
+:class:`KnowledgeBasePersistenceMiddleware`. In DESKTOP_LOCAL_FOLDER mode the
+backend is :class:`MultiRootLocalFolderBackend` and writes go straight to disk.
+
+New tools introduced here:
+
+* ``mkdir`` — cloud-only: stages folder paths to ``state['staged_dirs']``;
+  desktop creates real directories.
+* ``cd`` / ``pwd`` — manage ``state['cwd']`` (per-thread).
+* ``move_file`` — staged commit in cloud, real disk move in desktop.
+* ``list_tree`` — works in both modes (cloud uses
+  :func:`KBPostgresBackend.alist_tree_listing`).
+
+The middleware no longer ships ``save_document``; persistence is inferred
+from ``write_file`` / ``edit_file`` against ``/documents/*`` paths.
 """
 
 from __future__ import annotations
@@ -9,66 +28,92 @@ from __future__ import annotations
 import asyncio
 import json
 import logging
+import posixpath
 import re
 import secrets
-from datetime import UTC, datetime
 from typing import Annotated, Any
 
 from daytona.common.errors import DaytonaError
 from deepagents import FilesystemMiddleware
 from deepagents.backends.protocol import EditResult, WriteResult
-from deepagents.backends.utils import validate_path
-from deepagents.middleware.filesystem import FilesystemState
-from fractional_indexing import generate_key_between
+from deepagents.backends.utils import (
+    create_file_data,
+    format_read_response,
+    validate_path,
+)
 from langchain.tools import ToolRuntime
-from langchain_core.callbacks import dispatch_custom_event
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import BaseTool, StructuredTool
 from langgraph.types import Command
-from sqlalchemy import delete, select
 
 from app.agents.new_chat.filesystem_selection import FilesystemMode
+from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
+from app.agents.new_chat.middleware.kb_postgres_backend import (
+    KBPostgresBackend,
+    paginate_listing,
+)
 from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
     MultiRootLocalFolderBackend,
 )
+from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT
 from app.agents.new_chat.sandbox import (
     _evict_sandbox_cache,
     delete_sandbox,
     get_or_create_sandbox,
     is_sandbox_enabled,
 )
-from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session
-from app.indexing_pipeline.document_chunker import chunk_text
-from app.utils.document_converters import (
-    embed_texts,
-    generate_content_hash,
-    generate_unique_identifier_hash,
-)
+from app.agents.new_chat.state_reducers import _CLEAR
 
 logger = logging.getLogger(__name__)
 
-# =============================================================================
-# System Prompt (injected into every model call by wrap_model_call)
-# =============================================================================
-SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
+# =============================================================================
+# System Prompt (built per-session based on filesystem_mode)
+# =============================================================================
+#
+# Each chat session runs in exactly one filesystem mode. Including rules for
+# the OTHER mode just wastes tokens and confuses the model, so we build the
+# prompt + tool descriptions for the active mode only.
+
+_COMMON_PROMPT_HEADER = """## Following Conventions
 
 - Read files before editing — understand existing content before making changes.
 - Mimic existing style, naming conventions, and patterns.
 - Never claim a file was created/updated unless filesystem tool output confirms success.
 - If a file write/edit fails, explicitly report the failure.
+"""
 
+_CLOUD_SYSTEM_PROMPT = (
+    _COMMON_PROMPT_HEADER
+    + """
 ## Filesystem Tools
 
-All file paths must start with a `/`.
-- ls: list files and directories at a given path.
-- read_file: read a file from the filesystem.
-- write_file: create a temporary file in the session (not persisted).
-- edit_file: edit a file in the session (not persisted for /documents/ files).
-- glob: find files matching a pattern (e.g., "**/*.xml").
-- grep: search for text within files.
-- save_document: **permanently** save a new document to the user's knowledge
-  base. Use only when the user explicitly asks to save/create a document.
+All file paths must start with `/`. Relative paths resolve against the
+current working directory (`cwd`, default `/documents`).
+
+- ls(path, offset=0, limit=200): list files and directories at the given path.
+- read_file(path, offset, limit): read a file (paginated) from the filesystem.
+- write_file(path, content): create a new text file in the workspace.
+- edit_file(path, old, new): exact string-replacement edit (lazy-loads KB
+  documents on first edit).
+- glob(pattern, path): find files matching a glob pattern.
+- grep(pattern, path, glob): substring search across files.
+- mkdir(path): create a folder under `/documents/` (committed at end of turn).
+- cd(path): change the current working directory.
+- pwd(): print the current working directory.
+- move_file(source, dest): move/rename a file under `/documents/`.
+- list_tree(path, max_depth, page_size): recursively list files/folders.
+
+## Persistence Rules
+
+- Files written under `/documents/<...>` are **persisted** at end of turn as
+  Documents in the user's knowledge base.
+- Files whose **basename** starts with `temp_` (e.g. `temp_plan.md` or
+  `/documents/temp_scratch.md`) are **discarded** at end of turn — use this
+  prefix for any scratch/working content you do NOT want saved.
+- All other paths (outside `/documents/` and not `temp_*`) are rejected.
+- mkdir/move_file are staged this turn and committed at end of turn alongside
+  any new/edited documents.
 
 ## Reading Documents Efficiently
 
@@ -85,23 +130,107 @@ those sections instead of reading the entire file sequentially.
 
 Use `<chunk id>` values as citation IDs in your answers.
 
-## User-Mentioned Documents
+## Priority List
 
-When the `ls` output tags a file with `[MENTIONED BY USER — read deeply]`,
-the user **explicitly selected** that document. These files are your highest-
-priority sources:
-1. **Always read them thoroughly** — scan the full `<chunk_index>`, then read
-   all major sections, not just matched chunks.
-2. **Prefer their content** over other search results when answering.
-3. **Cite from them first** whenever applicable.
+You receive a `<priority_list>` system message each turn listing the
+top-K paths most relevant to the user's query (by hybrid search). Read those
+first — matched sections are flagged inside each document's `<chunk_index>`.
+
+## Workspace Tree
+
+You receive a `<workspace_tree>` system message each turn with the current
+folder/document layout. The tree may be truncated past a hard cap; in that
+case, drill into specific folders with `ls(...)` or `list_tree(...)`.
+
+## grep Line Numbers
+
+`grep` searches across both your in-memory edits and the indexed chunks in
+Postgres. State-cached files return real line numbers; database hits return
+`line=0` because their position depends on per-document XML layout — call
+`read_file(path)` to find the exact line.
 """
+)
+
+_DESKTOP_SYSTEM_PROMPT = (
+    _COMMON_PROMPT_HEADER
+    + """
+## Local Folder Mode
+
+This chat operates directly on the user's local folders. Writes and edits
+hit disk immediately — there is no end-of-turn staging, no `/documents/`
+namespace, and no `temp_` semantics.
+
+## Filesystem Tools
+
+All file paths must start with `/` and use mount-prefixed absolute paths
+like `/<mount>/file.ext`. Relative paths resolve against the current working
+directory (`cwd`).
+
+- ls(path, offset=0, limit=200): list files and directories at the given path.
+- read_file(path, offset, limit): read a file (paginated) from disk.
+- write_file(path, content): write a file to disk.
+- edit_file(path, old, new): exact string-replacement edit on disk.
+- glob(pattern, path): find files matching a glob pattern.
+- grep(pattern, path, glob): substring search across files.
+- mkdir(path): create a directory on disk.
+- cd(path): change the current working directory.
+- pwd(): print the current working directory.
+- move_file(source, dest): move/rename a file.
+- list_tree(path, max_depth, page_size): recursively list files/folders.
+
+## Workflow Tips
+
+- If you are unsure which mounts are available, call `ls('/')` first.
+- For large trees, prefer `list_tree` then `grep` then `read_file` over
+  brute-force directory traversal.
+- Cross-mount moves are not supported.
+"""
+)
+
+_SANDBOX_PROMPT_ADDENDUM = (
+    "\n- execute_code: run Python code in an isolated sandbox."
+    "\n\n## Code Execution"
+    "\n\nUse execute_code whenever a task benefits from running code."
+    " Never perform arithmetic manually."
+    "\n\nDocuments here are XML-wrapped markdown, not raw data files."
+    " To work with them programmatically, read the document first,"
+    " extract the data, write it as a clean file (CSV, JSON, etc.),"
+    " and then run your code against it."
+)
+
+
+def _build_filesystem_system_prompt(
+    filesystem_mode: FilesystemMode,
+    *,
+    sandbox_available: bool,
+) -> str:
+    """Build the filesystem system prompt for a given session mode.
+
+    The prompt only describes rules and tools that actually apply in the
+    chosen mode — there is no cross-mode noise.
+    """
+    base = (
+        _CLOUD_SYSTEM_PROMPT
+        if filesystem_mode == FilesystemMode.CLOUD
+        else _DESKTOP_SYSTEM_PROMPT
+    )
+    if sandbox_available:
+        base += _SANDBOX_PROMPT_ADDENDUM
+    return base
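A small sketch of how the builder above composes, using section names taken directly from the prompt constants:

```python
# The selection logic is the point; the prompt text itself is abbreviated.
prompt = _build_filesystem_system_prompt(
    FilesystemMode.CLOUD,
    sandbox_available=True,
)
assert prompt.startswith("## Following Conventions")
assert "## Persistence Rules" in prompt      # cloud-only section is present
assert "## Local Folder Mode" not in prompt  # desktop rules stay out
assert "## Code Execution" in prompt         # sandbox addendum was appended
```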
+# Backwards-compatible alias retained for any external imports.
+SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = _CLOUD_SYSTEM_PROMPT
 
 
 # =============================================================================
 # Per-Tool Descriptions (shown to the LLM as the tool's docstring)
 # =============================================================================
-SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
-"""
+# =============================================================================
+# Per-Tool Descriptions (mode-specific; injected as the tool's docstring)
+# =============================================================================
+
+# --- mode-agnostic ---------------------------------------------------------
 
 SURFSENSE_READ_FILE_TOOL_DESCRIPTION = """Reads a file from the filesystem.
 
@@ -116,105 +245,241 @@ Usage:
 - Use chunk IDs (`<chunk id="...">`) as citations in answers.
 """
 
-SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the in-memory filesystem (session-only).
-
-Use this to create scratch/working files during the conversation. Files created
-here are ephemeral and will not be saved to the user's knowledge base.
-
-To permanently save a document to the user's knowledge base, use the
-`save_document` tool instead.
-
-Supported outputs include common LLM-friendly text formats like markdown, json,
-yaml, csv, xml, html, css, sql, and code files.
-
-When creating content from open-ended prompts, produce concrete and useful text,
-not placeholders. Avoid adding dates/timestamps unless the user explicitly asks
-for them.
-"""
-
-SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
-
-IMPORTANT:
-- Read the file before editing.
-- Preserve exact indentation and formatting.
-- Edits to documents under `/documents/` are session-only (not persisted to the
-  database) because those files use an XML citation wrapper around the original
-  content.
-"""
-
-SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder.
-
-Use absolute paths for both source and destination.
-
-Notes:
-- In local-folder mode, paths should use mount prefixes (e.g., /<mount>/foo.txt).
-- Rename is a special case of move (same folder, different filename).
-- Cross-mount moves are not supported.
-"""
-
-SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call.
-
-Use this in desktop local-folder mode to discover nested files at scale.
-
-Args:
-- path: absolute mount-prefixed path (e.g., /<mount>/src) or "/" for mount roots.
-- max_depth: recursion depth limit (default 8).
-- page_size: maximum number of entries returned (max 1000).
-- include_files/include_dirs: filter returned entry types.
-
-Returns JSON with:
-- entries: [{path, is_dir, size, modified_at, depth}]
-- truncated: true when additional entries were omitted due to page_size
-"""
-
 SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern.
 
 Supports standard glob patterns: `*`, `**`, `?`. Returns absolute file paths.
 """
 
-SURFSENSE_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files.
+SURFSENSE_CD_TOOL_DESCRIPTION = """Changes the current working directory (cwd).
 
-Use this to locate relevant document files/chunks before reading full files.
+Args:
+- path: absolute or relative directory path. Relative paths resolve against
+  the current cwd.
+
+The new cwd is used by other filesystem tools whenever a relative path is
+given. Returns the resolved cwd.
""" +SURFSENSE_PWD_TOOL_DESCRIPTION = """Prints the current working directory.""" + SURFSENSE_EXECUTE_CODE_TOOL_DESCRIPTION = """Executes Python code in an isolated sandbox environment. Common data-science packages are pre-installed (pandas, numpy, matplotlib, scipy, scikit-learn). -When to use this tool: use execute_code for numerical computation, data -analysis, statistics, and any task that benefits from running Python code. -Never perform arithmetic manually when this tool is available. - Usage notes: - No outbound network access. - Returns combined stdout/stderr with exit code. - Use print() to produce output. -- You can create files, run shell commands via subprocess or os.system(), - and use any standard library module. - Use the optional timeout parameter to override the default timeout. """ -SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION = """Permanently saves a document to the user's knowledge base. +# --- cloud-only ------------------------------------------------------------ -This is an expensive operation — it creates a new Document record in the -database, chunks the content, and generates embeddings for search. +_CLOUD_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. -Use ONLY when the user explicitly asks to save/create/store a document. -Do NOT use this for scratch work; use `write_file` for temporary files. +Usage: +- Provide an absolute path under `/documents` (relative paths resolve under + the current cwd, which defaults to `/documents`). +- For very large folders, use `offset` and `limit` to paginate the listing. +- Returns one entry per line; directories end with a trailing `/`. +""" + +_CLOUD_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the workspace. + +Usage: +- Files written under `/documents/<...>` are persisted as Documents at end + of turn. +- Use a `temp_` filename prefix (e.g. `temp_plan.md` or `/documents/temp_x.md`) + for scratch/working files; they are automatically discarded at end of turn. +- Writes outside `/documents/` are rejected unless the basename starts with + `temp_`. +- Supported outputs include common LLM-friendly text formats like markdown, + json, yaml, csv, xml, html, css, sql, and code files. +- Avoid placeholders; produce concrete and useful text. +""" + +_CLOUD_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files. + +IMPORTANT: +- Read the file before editing. +- Preserve exact indentation and formatting. +- Edits to documents under `/documents/` are persisted at end of turn. +- Edits to `temp_*` files are discarded at end of turn. +""" + +_CLOUD_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder. + +Use absolute paths for both source and destination. + +Notes: +- `move_file` is staged this turn and committed at end of turn. +- The agent cannot overwrite an existing destination — pass a fresh dest + path or move the existing destination away first. +- The anonymous uploaded document is read-only and cannot be moved. +- Rename is a special case of move (same folder, different filename). +""" + +_CLOUD_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call. Args: - title: The document title (e.g., "Meeting Notes 2025-06-01"). - content: The plain-text or markdown content to save. Do NOT include XML - citation wrappers — pass only the actual document text. - folder_path: Optional folder path under /documents/ (e.g., "Work/Notes"). - Folders are created automatically if they don't exist. +- path: absolute path to start from. 
+- max_depth: recursion depth limit (default 8).
+- page_size: maximum number of entries returned (max 1000).
+- include_files / include_dirs: filter returned entry types.
+
+Returns JSON with:
+- entries: [{path, is_dir, size, modified_at, depth}]
+- truncated: true when additional entries were omitted due to page_size.
+"""
+
+_CLOUD_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files.
+
+Searches both your in-memory edits and the indexed chunks in Postgres.
+State-cached file matches include real line numbers; database hits return
+`line=0` because their position depends on per-document XML layout — call
+`read_file(path)` afterwards to find the exact line.
+"""
+
+_CLOUD_MKDIR_TOOL_DESCRIPTION = """Creates a directory under `/documents/`.
+
+Stages the folder for end-of-turn commit; the Folder row is inserted only
+after the agent's turn finishes successfully.
+
+Args:
+- path: absolute path of the new directory (must start with
+  `/documents/`).
+
+Notes:
+- Parent folders are created as needed.
+"""
+
+# --- desktop-only ----------------------------------------------------------
+
+_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
+
+Usage:
+- Provide an absolute path using a mount prefix (e.g. `/<mount>/sub/dir`).
+  Use `ls('/')` to discover available mounts.
+- For very large folders, use `offset` and `limit` to paginate the listing.
+- Returns one entry per line; directories end with a trailing `/`.
+"""
+
+_DESKTOP_WRITE_FILE_TOOL_DESCRIPTION = """Writes a text file to disk.
+
+Usage:
+- Use mount-prefixed absolute paths like `/<mount>/sub/file.ext`.
+- Writes hit disk immediately. There is no end-of-turn staging.
+- Supported outputs include common LLM-friendly text formats like markdown,
+  json, yaml, csv, xml, html, css, sql, and code files.
+- Avoid placeholders; produce concrete and useful text.
+"""
+
+_DESKTOP_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files on disk.
+
+IMPORTANT:
+- Read the file before editing.
+- Preserve exact indentation and formatting.
+- Edits hit disk immediately.
+"""
+
+_DESKTOP_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder on disk.
+
+Use mount-prefixed absolute paths for both source and destination
+(e.g. `/<mount>/old.txt` -> `/<mount>/new.txt`).
+
+Notes:
+- Cross-mount moves are not supported.
+- Rename is a special case of move (same folder, different filename).
+"""
+
+_DESKTOP_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call.
+
+Args:
+- path: absolute path to start from. Defaults to `/`.
+- max_depth: recursion depth limit (default 8).
+- page_size: maximum number of entries returned (max 1000).
+- include_files / include_dirs: filter returned entry types.
+
+Returns JSON with:
+- entries: [{path, is_dir, size, modified_at, depth}]
+- truncated: true when additional entries were omitted due to page_size.
+"""
+
+_DESKTOP_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files.
+
+Searches files on disk and any in-memory edits. Returns real line numbers.
+"""
+
+_DESKTOP_MKDIR_TOOL_DESCRIPTION = """Creates a directory on disk.
+
+Args:
+- path: absolute mount-prefixed path of the new directory.
+
+Notes:
+- Parent folders are created as needed.
""" +def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: + """Pick the active-mode description for every filesystem tool.""" + if filesystem_mode == FilesystemMode.CLOUD: + return { + "ls": _CLOUD_LIST_FILES_TOOL_DESCRIPTION, + "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION, + "write_file": _CLOUD_WRITE_FILE_TOOL_DESCRIPTION, + "edit_file": _CLOUD_EDIT_FILE_TOOL_DESCRIPTION, + "move_file": _CLOUD_MOVE_FILE_TOOL_DESCRIPTION, + "list_tree": _CLOUD_LIST_TREE_TOOL_DESCRIPTION, + "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION, + "grep": _CLOUD_GREP_TOOL_DESCRIPTION, + "mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION, + "cd": SURFSENSE_CD_TOOL_DESCRIPTION, + "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + } + return { + "ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION, + "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION, + "write_file": _DESKTOP_WRITE_FILE_TOOL_DESCRIPTION, + "edit_file": _DESKTOP_EDIT_FILE_TOOL_DESCRIPTION, + "move_file": _DESKTOP_MOVE_FILE_TOOL_DESCRIPTION, + "list_tree": _DESKTOP_LIST_TREE_TOOL_DESCRIPTION, + "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION, + "grep": _DESKTOP_GREP_TOOL_DESCRIPTION, + "mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION, + "cd": SURFSENSE_CD_TOOL_DESCRIPTION, + "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + } + + +# Backwards-compatible aliases retained for any external imports/tests that +# referenced the original CLOUD-flavoured constants. +SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = _CLOUD_LIST_FILES_TOOL_DESCRIPTION +SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = _CLOUD_WRITE_FILE_TOOL_DESCRIPTION +SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = _CLOUD_EDIT_FILE_TOOL_DESCRIPTION +SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION = _CLOUD_MOVE_FILE_TOOL_DESCRIPTION +SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = _CLOUD_LIST_TREE_TOOL_DESCRIPTION +SURFSENSE_GREP_TOOL_DESCRIPTION = _CLOUD_GREP_TOOL_DESCRIPTION +SURFSENSE_MKDIR_TOOL_DESCRIPTION = _CLOUD_MKDIR_TOOL_DESCRIPTION + + +# ============================================================================= +# Helpers +# ============================================================================= + + +_TEMP_PREFIX = "temp_" + + +def _basename(path: str) -> str: + return path.rsplit("/", 1)[-1] + + class SurfSenseFilesystemMiddleware(FilesystemMiddleware): - """SurfSense-specific filesystem middleware with DB persistence for docs.""" + """SurfSense-specific filesystem middleware (cloud + desktop).""" + + state_schema = SurfSenseFilesystemState _MAX_EXECUTE_TIMEOUT = 300 @@ -234,582 +499,45 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): self._thread_id = thread_id self._sandbox_available = is_sandbox_enabled() and thread_id is not None - system_prompt = SURFSENSE_FILESYSTEM_SYSTEM_PROMPT - if self._sandbox_available: - system_prompt += ( - "\n- execute_code: run Python code in an isolated sandbox." - "\n\n## Code Execution" - "\n\nUse execute_code whenever a task benefits from running code." - " Never perform arithmetic manually." - "\n\nDocuments here are XML-wrapped markdown, not raw data files." - " To work with them programmatically, read the document first," - " extract the data, write it as a clean file (CSV, JSON, etc.)," - " and then run your code against it." - ) - if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - system_prompt += ( - "\n- move_file: move or rename files/folders in local-folder mode." - "\n- list_tree: recursively list nested local paths in one bounded response." - "\n\n## Local Folder Mode" - "\n\nThis chat is running in desktop local-folder mode." - " Keep all file operations local. 
 class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
-    """SurfSense-specific filesystem middleware with DB persistence for docs."""
+    """SurfSense-specific filesystem middleware (cloud + desktop)."""
+
+    state_schema = SurfSenseFilesystemState
 
     _MAX_EXECUTE_TIMEOUT = 300
 
@@ -234,582 +499,45 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
         self._thread_id = thread_id
         self._sandbox_available = is_sandbox_enabled() and thread_id is not None
 
-        system_prompt = SURFSENSE_FILESYSTEM_SYSTEM_PROMPT
-        if self._sandbox_available:
-            system_prompt += (
-                "\n- execute_code: run Python code in an isolated sandbox."
-                "\n\n## Code Execution"
-                "\n\nUse execute_code whenever a task benefits from running code."
-                " Never perform arithmetic manually."
-                "\n\nDocuments here are XML-wrapped markdown, not raw data files."
-                " To work with them programmatically, read the document first,"
-                " extract the data, write it as a clean file (CSV, JSON, etc.),"
-                " and then run your code against it."
-            )
-        if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
-            system_prompt += (
-                "\n- move_file: move or rename files/folders in local-folder mode."
-                "\n- list_tree: recursively list nested local paths in one bounded response."
-                "\n\n## Local Folder Mode"
-                "\n\nThis chat is running in desktop local-folder mode."
-                " Keep all file operations local. Do not use save_document."
-                " Always use mount-prefixed absolute paths like /<mount>/file.ext."
-                " If you are unsure which mounts are available, call ls('/') first."
-                " For big trees: use list_tree, then grep, then read_file."
-            )
+        # Build the prompt + tool descriptions for the active mode only —
+        # mixing both modes wastes tokens and confuses the model with rules
+        # it can't actually use this session.
+        system_prompt = _build_filesystem_system_prompt(
+            filesystem_mode,
+            sandbox_available=self._sandbox_available,
+        )
 
         super().__init__(
             backend=backend,
             system_prompt=system_prompt,
-            custom_tool_descriptions={
-                "ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
-                "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION,
-                "write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION,
-                "edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION,
-                "move_file": SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION,
-                "list_tree": SURFSENSE_LIST_TREE_TOOL_DESCRIPTION,
-                "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION,
-                "grep": SURFSENSE_GREP_TOOL_DESCRIPTION,
-            },
+            custom_tool_descriptions=_build_tool_descriptions(filesystem_mode),
             tool_token_limit_before_evict=tool_token_limit_before_evict,
             max_execute_timeout=self._MAX_EXECUTE_TIMEOUT,
        )
 
         self.tools = [t for t in self.tools if t.name != "execute"]
-        if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
-            self.tools.append(self._create_move_file_tool())
-            self.tools.append(self._create_list_tree_tool())
-        if self._should_persist_documents():
-            self.tools.append(self._create_save_document_tool())
+        self.tools.append(self._create_mkdir_tool())
+        self.tools.append(self._create_cd_tool())
+        self.tools.append(self._create_pwd_tool())
+        self.tools.append(self._create_move_file_tool())
+        self.tools.append(self._create_list_tree_tool())
         if self._sandbox_available:
             self.tools.append(self._create_execute_code_tool())
 
+    # ------------------------------------------------------------------ helpers
+
+    def _is_cloud(self) -> bool:
+        return self._filesystem_mode == FilesystemMode.CLOUD
+
     @staticmethod
     def _run_async_blocking(coro: Any) -> Any:
-        """Run async coroutine from sync code path when no event loop is running."""
         try:
             loop = asyncio.get_running_loop()
             if loop.is_running():
-                return "Error: sync filesystem persistence not supported inside an active event loop."
+                return "Error: sync filesystem operation not supported inside an active event loop."
         except RuntimeError:
             pass
         return asyncio.run(coro)
 
-    @staticmethod
-    def _parse_virtual_path(file_path: str) -> tuple[list[str], str]:
-        """Parse /documents/... path into folder parts and a document title."""
-        if not file_path.startswith("/documents/"):
-            return [], ""
-        rel = file_path[len("/documents/") :].strip("/")
-        if not rel:
-            return [], ""
-        parts = [part for part in rel.split("/") if part]
-        file_name = parts[-1]
-        title = file_name[:-4] if file_name.lower().endswith(".xml") else file_name
-        return parts[:-1], title
-
-    async def _ensure_folder_hierarchy(
-        self,
-        *,
-        folder_parts: list[str],
-        search_space_id: int,
-    ) -> int | None:
-        """Ensure folder hierarchy exists and return leaf folder ID."""
-        if not folder_parts:
-            return None
-        async with shielded_async_session() as session:
-            parent_id: int | None = None
-            for name in folder_parts:
-                result = await session.execute(
-                    select(Folder).where(
-                        Folder.search_space_id == search_space_id,
-                        Folder.parent_id == parent_id
-                        if parent_id is not None
-                        else Folder.parent_id.is_(None),
-                        Folder.name == name,
-                    )
-                )
-                folder = result.scalar_one_or_none()
-                if folder is None:
-                    sibling_result = await session.execute(
-                        select(Folder.position)
-                        .where(
-                            Folder.search_space_id == search_space_id,
-                            Folder.parent_id == parent_id
-                            if parent_id is not None
-                            else Folder.parent_id.is_(None),
-                        )
-                        .order_by(Folder.position.desc())
-                        .limit(1)
-                    )
-                    last_position = sibling_result.scalar_one_or_none()
-                    folder = Folder(
-                        name=name,
-                        position=generate_key_between(last_position, None),
-                        parent_id=parent_id,
-                        search_space_id=search_space_id,
-                        created_by_id=self._created_by_id,
-                        updated_at=datetime.now(UTC),
-                    )
-                    session.add(folder)
-                    await session.flush()
-                parent_id = folder.id
-            await session.commit()
-            return parent_id
-
-    async def _persist_new_document(
-        self, *, file_path: str, content: str
-    ) -> dict[str, Any] | str:
-        """Persist a new NOTE document from a newly written file.
-
-        Returns a dict with document metadata on success, or an error string.
-        """
-        if self._search_space_id is None:
-            return {}
-        folder_parts, title = self._parse_virtual_path(file_path)
-        if not title:
-            return "Error: write_file for document persistence requires path under /documents/<name>.xml"
-        folder_id = await self._ensure_folder_hierarchy(
-            folder_parts=folder_parts,
-            search_space_id=self._search_space_id,
-        )
-        async with shielded_async_session() as session:
-            content_hash = generate_content_hash(content, self._search_space_id)
-            existing = await session.execute(
-                select(Document.id).where(Document.content_hash == content_hash)
-            )
-            if existing.scalar_one_or_none() is not None:
-                return "Error: A document with identical content already exists."
-            unique_identifier_hash = generate_unique_identifier_hash(
-                DocumentType.NOTE,
-                file_path,
-                self._search_space_id,
-            )
-            doc = Document(
-                title=title,
-                document_type=DocumentType.NOTE,
-                document_metadata={"virtual_path": file_path},
-                content=content,
-                content_hash=content_hash,
-                unique_identifier_hash=unique_identifier_hash,
-                source_markdown=content,
-                search_space_id=self._search_space_id,
-                folder_id=folder_id,
-                created_by_id=self._created_by_id,
-                updated_at=datetime.now(UTC),
-            )
-            session.add(doc)
-            await session.flush()
-
-            summary_embedding = embed_texts([content])[0]
-            doc.embedding = summary_embedding
-            chunk_texts = chunk_text(content)
-            if chunk_texts:
-                chunk_embeddings = embed_texts(chunk_texts)
-                chunks = [
-                    Chunk(document_id=doc.id, content=text, embedding=embedding)
-                    for text, embedding in zip(
-                        chunk_texts, chunk_embeddings, strict=True
-                    )
-                ]
-                session.add_all(chunks)
-            await session.commit()
-
-        return {
-            "id": doc.id,
-            "title": title,
-            "documentType": DocumentType.NOTE.value,
-            "searchSpaceId": self._search_space_id,
-            "folderId": folder_id,
-            "createdById": str(self._created_by_id)
-            if self._created_by_id
-            else None,
-        }
-
-    async def _persist_edited_document(
-        self, *, file_path: str, updated_content: str
-    ) -> str | None:
-        """Persist edits for an existing NOTE document and recreate chunks."""
-        if self._search_space_id is None:
-            return None
-        unique_identifier_hash = generate_unique_identifier_hash(
-            DocumentType.NOTE,
-            file_path,
-            self._search_space_id,
-        )
-        doc_id_from_xml: int | None = None
-        match = re.search(r"<document_id>\s*(\d+)\s*</document_id>", updated_content)
-        if match:
-            doc_id_from_xml = int(match.group(1))
-        async with shielded_async_session() as session:
-            doc_result = await session.execute(
-                select(Document).where(
-                    Document.search_space_id == self._search_space_id,
-                    Document.unique_identifier_hash == unique_identifier_hash,
-                )
-            )
-            document = doc_result.scalar_one_or_none()
-            if document is None and doc_id_from_xml is not None:
-                by_id_result = await session.execute(
-                    select(Document).where(
-                        Document.search_space_id == self._search_space_id,
-                        Document.id == doc_id_from_xml,
-                    )
-                )
-                document = by_id_result.scalar_one_or_none()
-            if document is None:
-                return "Error: Could not map edited file to an existing document."
-
-            document.content = updated_content
-            document.source_markdown = updated_content
-            document.content_hash = generate_content_hash(
-                updated_content, self._search_space_id
-            )
-            document.updated_at = datetime.now(UTC)
-            if not document.document_metadata:
-                document.document_metadata = {}
-            document.document_metadata["virtual_path"] = file_path
-
-            summary_embedding = embed_texts([updated_content])[0]
-            document.embedding = summary_embedding
-
-            await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
-            chunk_texts = chunk_text(updated_content)
-            if chunk_texts:
-                chunk_embeddings = embed_texts(chunk_texts)
-                session.add_all(
-                    [
-                        Chunk(
-                            document_id=document.id, content=text, embedding=embedding
-                        )
-                        for text, embedding in zip(
-                            chunk_texts, chunk_embeddings, strict=True
-                        )
-                    ]
-                )
-            await session.commit()
-        return None
-
-    def _create_save_document_tool(self) -> BaseTool:
-        """Create save_document tool that persists a new document to the KB."""
-
-        def sync_save_document(
-            title: Annotated[str, "Title for the new document."],
-            content: Annotated[
-                str,
-                "Plain-text or markdown content to save. Do NOT include XML wrappers.",
-            ],
Do NOT include XML wrappers.", - ], - runtime: ToolRuntime[None, FilesystemState], - folder_path: Annotated[ - str, - "Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.", - ] = "", - ) -> Command | str: - if not content.strip(): - return "Error: content cannot be empty." - file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled" - if not file_name.lower().endswith(".xml"): - file_name = f"{file_name}.xml" - folder = folder_path.strip().strip("/") if folder_path else "" - virtual_path = ( - f"/documents/{folder}/{file_name}" - if folder - else f"/documents/{file_name}" - ) - - persist_result = self._run_async_blocking( - self._persist_new_document(file_path=virtual_path, content=content) - ) - if isinstance(persist_result, str): - return persist_result - if isinstance(persist_result, dict) and persist_result.get("id"): - dispatch_custom_event("document_created", persist_result) - return f"Document '{title}' saved to knowledge base (path: {virtual_path})." - - async def async_save_document( - title: Annotated[str, "Title for the new document."], - content: Annotated[ - str, - "Plain-text or markdown content to save. Do NOT include XML wrappers.", - ], - runtime: ToolRuntime[None, FilesystemState], - folder_path: Annotated[ - str, - "Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.", - ] = "", - ) -> Command | str: - if not content.strip(): - return "Error: content cannot be empty." - file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled" - if not file_name.lower().endswith(".xml"): - file_name = f"{file_name}.xml" - folder = folder_path.strip().strip("/") if folder_path else "" - virtual_path = ( - f"/documents/{folder}/{file_name}" - if folder - else f"/documents/{file_name}" - ) - - persist_result = await self._persist_new_document( - file_path=virtual_path, content=content - ) - if isinstance(persist_result, str): - return persist_result - if isinstance(persist_result, dict) and persist_result.get("id"): - dispatch_custom_event("document_created", persist_result) - return f"Document '{title}' saved to knowledge base (path: {virtual_path})." - - return StructuredTool.from_function( - name="save_document", - description=SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION, - func=sync_save_document, - coroutine=async_save_document, - ) - - def _create_execute_code_tool(self) -> BaseTool: - """Create execute_code tool backed by a Daytona sandbox.""" - - def sync_execute_code( - command: Annotated[ - str, "Python code to execute. Use print() to see output." - ], - runtime: ToolRuntime[None, FilesystemState], - timeout: Annotated[ - int | None, - "Optional timeout in seconds.", - ] = None, - ) -> str: - if timeout is not None: - if timeout < 0: - return f"Error: timeout must be non-negative, got {timeout}." - if timeout > self._MAX_EXECUTE_TIMEOUT: - return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." - return self._run_async_blocking( - self._execute_in_sandbox(command, runtime, timeout) - ) - - async def async_execute_code( - command: Annotated[ - str, "Python code to execute. Use print() to see output." - ], - runtime: ToolRuntime[None, FilesystemState], - timeout: Annotated[ - int | None, - "Optional timeout in seconds.", - ] = None, - ) -> str: - if timeout is not None: - if timeout < 0: - return f"Error: timeout must be non-negative, got {timeout}." 
- if timeout > self._MAX_EXECUTE_TIMEOUT: - return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." - return await self._execute_in_sandbox(command, runtime, timeout) - - return StructuredTool.from_function( - name="execute_code", - description=SURFSENSE_EXECUTE_CODE_TOOL_DESCRIPTION, - func=sync_execute_code, - coroutine=async_execute_code, - ) - - @staticmethod - def _wrap_as_python(code: str) -> str: - """Wrap Python code in a shell invocation for the sandbox.""" - sentinel = f"_PYEOF_{secrets.token_hex(8)}" - return f"python3 << '{sentinel}'\n{code}\n{sentinel}" - - async def _execute_in_sandbox( - self, - command: str, - runtime: ToolRuntime[None, FilesystemState], - timeout: int | None, - ) -> str: - """Core logic: get sandbox, sync files, run command, handle retries.""" - assert self._thread_id is not None - command = self._wrap_as_python(command) - - try: - return await self._try_sandbox_execute(command, runtime, timeout) - except (DaytonaError, Exception) as first_err: - logger.warning( - "Sandbox execute failed for thread %s, retrying: %s", - self._thread_id, - first_err, - ) - try: - await delete_sandbox(self._thread_id) - except Exception: - _evict_sandbox_cache(self._thread_id) - try: - return await self._try_sandbox_execute(command, runtime, timeout) - except Exception: - logger.exception( - "Sandbox retry also failed for thread %s", self._thread_id - ) - return "Error: Code execution is temporarily unavailable. Please try again." - - async def _try_sandbox_execute( - self, - command: str, - runtime: ToolRuntime[None, FilesystemState], - timeout: int | None, - ) -> str: - sandbox, _is_new = await get_or_create_sandbox(self._thread_id) - # NOTE: sync_files_to_sandbox is intentionally disabled. - # The virtual FS contains XML-wrapped KB documents whose paths - # would double-nest under SANDBOX_DOCUMENTS_ROOT (e.g. - # /home/daytona/documents/documents/Report.xml) and uploading - # all KB docs on the first execute_code call adds significant - # latency. Re-enable once path mapping is fixed and upload is - # limited to user-created scratch files. - # files = runtime.state.get("files") or {} - # await sync_files_to_sandbox(self._thread_id, files, sandbox, is_new) - result = await sandbox.aexecute(command, timeout=timeout) - output = (result.output or "").strip() - if not output and result.exit_code == 0: - return ( - "[Code executed successfully but produced no output. " - "Use print() to display results, then try again.]" - ) - parts = [result.output] - if result.exit_code is not None: - status = "succeeded" if result.exit_code == 0 else "failed" - parts.append(f"\n[Command {status} with exit code {result.exit_code}]") - if result.truncated: - parts.append("\n[Output was truncated due to size limits]") - return "".join(parts) - - def _create_write_file_tool(self) -> BaseTool: - """Create write_file — ephemeral for /documents/*, persisted otherwise.""" - tool_description = ( - self._custom_tool_descriptions.get("write_file") - or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION - ) - - def sync_write_file( - file_path: Annotated[ - str, - "Absolute path where the file should be created. Must be absolute, not relative.", - ], - content: Annotated[ - str, - "The text content to write to the file. 
This parameter is required.", - ], - runtime: ToolRuntime[None, FilesystemState], - ) -> Command | str: - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_write_target_path(file_path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - res: WriteResult = resolved_backend.write(validated_path, content) - if res.error: - return res.error - verify_error = self._verify_written_content_sync( - backend=resolved_backend, - path=validated_path, - expected_content=content, - ) - if verify_error: - return verify_error - - if self._should_persist_documents() and not self._is_kb_document( - validated_path - ): - persist_result = self._run_async_blocking( - self._persist_new_document( - file_path=validated_path, content=content - ) - ) - if isinstance(persist_result, str): - return persist_result - if isinstance(persist_result, dict) and persist_result.get("id"): - dispatch_custom_event("document_created", persist_result) - - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=f"Updated file {res.path}", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Updated file {res.path}" - - async def async_write_file( - file_path: Annotated[ - str, - "Absolute path where the file should be created. Must be absolute, not relative.", - ], - content: Annotated[ - str, - "The text content to write to the file. This parameter is required.", - ], - runtime: ToolRuntime[None, FilesystemState], - ) -> Command | str: - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_write_target_path(file_path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - res: WriteResult = await resolved_backend.awrite(validated_path, content) - if res.error: - return res.error - verify_error = await self._verify_written_content_async( - backend=resolved_backend, - path=validated_path, - expected_content=content, - ) - if verify_error: - return verify_error - - if self._should_persist_documents() and not self._is_kb_document( - validated_path - ): - persist_result = await self._persist_new_document( - file_path=validated_path, - content=content, - ) - if isinstance(persist_result, str): - return persist_result - if isinstance(persist_result, dict) and persist_result.get("id"): - dispatch_custom_event("document_created", persist_result) - - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=f"Updated file {res.path}", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Updated file {res.path}" - - return StructuredTool.from_function( - name="write_file", - description=tool_description, - func=sync_write_file, - coroutine=async_write_file, - ) - - @staticmethod - def _is_kb_document(path: str) -> bool: - """Return True for paths under /documents/ (KB-sourced, XML-wrapped).""" - return path.startswith("/documents/") - - def _should_persist_documents(self) -> bool: - """Only cloud mode persists file content to Document/Chunk tables.""" - return self._filesystem_mode == FilesystemMode.CLOUD - @staticmethod def _normalize_absolute_path(candidate: str) -> str: normalized = re.sub(r"/+", "/", candidate.strip().replace("\\", "/")) @@ -857,7 +585,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): def _normalize_local_mount_path( self, candidate: str, - runtime: 
ToolRuntime[None, FilesystemState], + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: normalized = self._normalize_absolute_path(candidate) backend = self._get_backend(runtime) @@ -904,276 +632,674 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): return f"/{backend.default_mount()}{normalized}" + def _default_cwd(self) -> str: + return DOCUMENTS_ROOT if self._is_cloud() else "/" + + def _current_cwd(self, runtime: ToolRuntime[None, SurfSenseFilesystemState]) -> str: + cwd = runtime.state.get("cwd") if hasattr(runtime, "state") else None + if isinstance(cwd, str) and cwd.startswith("/"): + return cwd + return self._default_cwd() + def _get_contract_suggested_path( - self, runtime: ToolRuntime[None, FilesystemState] + self, runtime: ToolRuntime[None, SurfSenseFilesystemState] ) -> str: contract = runtime.state.get("file_operation_contract") or {} suggested = contract.get("suggested_path") if isinstance(suggested, str) and suggested.strip(): return self._normalize_absolute_path(suggested) - return "/notes.md" + return self._default_cwd().rstrip("/") + "/notes.md" + + def _resolve_relative( + self, + path: str, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> str: + candidate = path.strip() + if not candidate: + return self._current_cwd(runtime) + if candidate.startswith("/"): + return self._normalize_absolute_path(candidate) + cwd = self._current_cwd(runtime) + joined = posixpath.normpath(posixpath.join(cwd, candidate)) + return self._normalize_absolute_path(joined) def _resolve_write_target_path( self, file_path: str, - runtime: ToolRuntime[None, FilesystemState], + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: candidate = file_path.strip() if not candidate: return self._get_contract_suggested_path(runtime) if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: return self._normalize_local_mount_path(candidate, runtime) - if not candidate.startswith("/"): - return f"/{candidate.lstrip('/')}" - return candidate + return self._resolve_relative(candidate, runtime) def _resolve_move_target_path( self, file_path: str, - runtime: ToolRuntime[None, FilesystemState], + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: candidate = file_path.strip() if not candidate: return "" if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: return self._normalize_local_mount_path(candidate, runtime) - if not candidate.startswith("/"): - return f"/{candidate.lstrip('/')}" - return candidate + return self._resolve_relative(candidate, runtime) def _resolve_list_target_path( self, path: str, - runtime: ToolRuntime[None, FilesystemState], + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: - candidate = path.strip() or "/" + candidate = path.strip() or self._current_cwd(runtime) if candidate == "/": return "/" if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: return self._normalize_local_mount_path(candidate, runtime) - if not candidate.startswith("/"): - return f"/{candidate.lstrip('/')}" - return candidate + return self._resolve_relative(candidate, runtime) - @staticmethod - def _is_error_text(value: str) -> bool: - return value.startswith("Error:") + # ------------------------------------------------------------------ namespace policy - @staticmethod - def _read_for_verification_sync(backend: Any, path: str) -> str: - read_raw = getattr(backend, "read_raw", None) - if callable(read_raw): - return read_raw(path) - return backend.read(path, offset=0, limit=200000) - - @staticmethod - async def 
_read_for_verification_async(backend: Any, path: str) -> str: - aread_raw = getattr(backend, "aread_raw", None) - if callable(aread_raw): - return await aread_raw(path) - return await backend.aread(path, offset=0, limit=200000) - - def _verify_written_content_sync( + def _check_cloud_write_namespace( self, - *, - backend: Any, path: str, - expected_content: str, + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str | None: - actual = self._read_for_verification_sync(backend, path) - if self._is_error_text(actual): - return f"Error: could not verify written file '{path}'." - if actual.rstrip() != expected_content.rstrip(): - return ( - "Error: file write verification failed; expected content was not fully written " - f"to '{path}'." - ) - return None + """Return an error string if cloud writes to ``path`` are not allowed. - async def _verify_written_content_async( - self, - *, - backend: Any, - path: str, - expected_content: str, - ) -> str | None: - actual = await self._read_for_verification_async(backend, path) - if self._is_error_text(actual): - return f"Error: could not verify written file '{path}'." - if actual.rstrip() != expected_content.rstrip(): - return ( - "Error: file write verification failed; expected content was not fully written " - f"to '{path}'." - ) - return None + Order matters: + 1. Reject writes to the anonymous read-only doc. + 2. Allow ``/documents/*``. + 3. Allow ``temp_*`` basename anywhere. + 4. Reject everything else. + """ + if not self._is_cloud(): + return None + anon = runtime.state.get("kb_anon_doc") or {} + if isinstance(anon, dict): + anon_path = str(anon.get("path") or "") + if anon_path and anon_path == path: + return "Error: the anonymous uploaded document is read-only." + if path.startswith(DOCUMENTS_ROOT + "/") or path == DOCUMENTS_ROOT: + return None + if _basename(path).startswith(_TEMP_PREFIX): + return None + return ( + "Error: cloud writes must target /documents/<...> or use a 'temp_' " + f"basename for scratch (got '{path}')." + ) - def _verify_edited_content_sync( - self, - *, - backend: Any, - path: str, - new_string: str, - ) -> tuple[str | None, str | None]: - updated_content = self._read_for_verification_sync(backend, path) - if self._is_error_text(updated_content): - return ( - f"Error: could not verify edited file '{path}'.", - None, - ) - if new_string and new_string not in updated_content: - return ( - "Error: edit verification failed; updated content was not found in " - f"'{path}'.", - None, - ) - return None, updated_content + # ------------------------------------------------------------------ tool: ls - async def _verify_edited_content_async( - self, - *, - backend: Any, - path: str, - new_string: str, - ) -> tuple[str | None, str | None]: - updated_content = await self._read_for_verification_async(backend, path) - if self._is_error_text(updated_content): - return ( - f"Error: could not verify edited file '{path}'.", - None, + def _create_ls_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("ls") + or SURFSENSE_LIST_FILES_TOOL_DESCRIPTION + ) + + def sync_ls( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + path: Annotated[ + str, + "Absolute path to the directory to list. Relative paths resolve against the current cwd.", + ] = "", + offset: Annotated[ + int, + "Number of entries to skip. Use for paginating large folders. Defaults to 0.", + ] = 0, + limit: Annotated[ + int, + "Maximum number of entries to return. 
Defaults to 200.", + ] = 200, + ) -> str: + return self._run_async_blocking( + async_ls(runtime, path=path, offset=offset, limit=limit) ) - if new_string and new_string not in updated_content: - return ( - "Error: edit verification failed; updated content was not found in " - f"'{path}'.", - None, + + async def async_ls( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + path: Annotated[ + str, + "Absolute path to the directory to list. Relative paths resolve against the current cwd.", + ] = "", + offset: Annotated[ + int, + "Number of entries to skip. Use for paginating large folders. Defaults to 0.", + ] = 0, + limit: Annotated[ + int, + "Maximum number of entries to return. Defaults to 200.", + ] = 200, + ) -> str: + target = self._resolve_list_target_path(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + if offset < 0: + offset = 0 + if limit < 1: + limit = 1 + backend = self._get_backend(runtime) + infos = await backend.als_info(validated) + page = paginate_listing(infos, offset=offset, limit=limit) + paths = [ + f"{fi.get('path', '')}/" if fi.get("is_dir") else fi.get("path", "") + for fi in page + ] + total = len(infos) + shown = len(page) + header = ( + f"{validated} ({shown} of {total} entries" + f"{f', offset={offset}' if offset else ''})" ) - return None, updated_content + if not paths: + return f"{header}\n(empty)" + body = "\n".join(paths) + if total > offset + shown: + body += ( + f"\n... {total - offset - shown} more — call ls(" + f"'{validated}', offset={offset + shown}, limit={limit})" + ) + return f"{header}\n{body}" + + return StructuredTool.from_function( + name="ls", + description=tool_description, + func=sync_ls, + coroutine=async_ls, + ) + + # ------------------------------------------------------------------ tool: read_file + + def _create_read_file_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("read_file") + or SURFSENSE_READ_FILE_TOOL_DESCRIPTION + ) + + async def async_read_file( + file_path: Annotated[ + str, + "Absolute path to the file to read. Relative paths resolve against the current cwd.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + offset: Annotated[ + int, + "Line number to start reading from (0-indexed).", + ] = 0, + limit: Annotated[ + int, + "Maximum number of lines to read.", + ] = 100, + ) -> Command | str: + target = self._resolve_relative(file_path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + files = runtime.state.get("files") or {} + if validated in files: + return format_read_response(files[validated], offset, limit) + + backend = self._get_backend(runtime) + if isinstance(backend, KBPostgresBackend): + loaded = await backend._load_file_data(validated) + if loaded is None: + return f"Error: File '{validated}' not found" + file_data, doc_id = loaded + rendered = format_read_response(file_data, offset, limit) + update: dict[str, Any] = { + "files": {validated: file_data}, + "messages": [ + ToolMessage( + content=rendered, + tool_call_id=runtime.tool_call_id, + ) + ], + } + if doc_id is not None: + update["doc_id_by_path"] = {validated: doc_id} + return Command(update=update) + + try: + rendered = await backend.aread(validated, offset=offset, limit=limit) + except Exception as exc: # pragma: no cover - defensive + return f"Error: {exc}" + return rendered + + def sync_read_file( + file_path: Annotated[ + str, + "Absolute path to the file to read. 
Relative paths resolve against the current cwd.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + offset: Annotated[ + int, + "Line number to start reading from (0-indexed).", + ] = 0, + limit: Annotated[ + int, + "Maximum number of lines to read.", + ] = 100, + ) -> Command | str: + return self._run_async_blocking( + async_read_file(file_path, runtime, offset, limit) + ) + + return StructuredTool.from_function( + name="read_file", + description=tool_description, + func=sync_read_file, + coroutine=async_read_file, + ) + + # ------------------------------------------------------------------ tool: write_file + + def _create_write_file_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("write_file") + or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION + ) + + async def async_write_file( + file_path: Annotated[ + str, + "Absolute path where the file should be created. Relative paths resolve against the current cwd.", + ], + content: Annotated[str, "Text content to write to the file."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + target = self._resolve_write_target_path(file_path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + namespace_error = self._check_cloud_write_namespace(validated, runtime) + if namespace_error: + return namespace_error + + backend = self._get_backend(runtime) + res: WriteResult = await backend.awrite(validated, content) + if res.error: + return res.error + + path = res.path or validated + files_update = res.files_update or {path: create_file_data(content)} + update: dict[str, Any] = { + "files": files_update, + "messages": [ + ToolMessage( + content=f"Updated file {path}", + tool_call_id=runtime.tool_call_id, + ) + ], + } + if self._is_cloud(): + update["dirty_paths"] = [path] + return Command(update=update) + + def sync_write_file( + file_path: Annotated[ + str, + "Absolute path where the file should be created. Relative paths resolve against the current cwd.", + ], + content: Annotated[str, "Text content to write to the file."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking( + async_write_file(file_path, content, runtime) + ) + + return StructuredTool.from_function( + name="write_file", + description=tool_description, + func=sync_write_file, + coroutine=async_write_file, + ) + + # ------------------------------------------------------------------ tool: edit_file + + def _create_edit_file_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("edit_file") + or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION + ) + + async def async_edit_file( + file_path: Annotated[ + str, + "Absolute path to the file to edit. Relative paths resolve against the current cwd.", + ], + old_string: Annotated[ + str, + "Exact text to replace. Must be unique unless replace_all is True.", + ], + new_string: Annotated[ + str, + "Replacement text. Must differ from old_string.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + *, + replace_all: Annotated[ + bool, + "If True, replace all occurrences of old_string. 
Defaults to False.", + ] = False, + ) -> Command | str: + target = self._resolve_relative(file_path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + namespace_error = self._check_cloud_write_namespace(validated, runtime) + if namespace_error: + return namespace_error + + backend = self._get_backend(runtime) + files_state = runtime.state.get("files") or {} + doc_id_to_attach: int | None = None + + if ( + self._is_cloud() + and validated not in files_state + and isinstance(backend, KBPostgresBackend) + ): + loaded = await backend._load_file_data(validated) + if loaded is None: + return f"Error: File '{validated}' not found" + _, doc_id_to_attach = loaded + + res: EditResult = await backend.aedit( + validated, old_string, new_string, replace_all=replace_all + ) + if res.error: + return res.error + + path = res.path or validated + files_update = res.files_update or {} + update: dict[str, Any] = { + "files": files_update, + "messages": [ + ToolMessage( + content=( + f"Successfully replaced {res.occurrences} instance(s) " + f"of the string in '{path}'" + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + if self._is_cloud(): + update["dirty_paths"] = [path] + if doc_id_to_attach is not None: + update["doc_id_by_path"] = {path: doc_id_to_attach} + return Command(update=update) + + def sync_edit_file( + file_path: Annotated[ + str, + "Absolute path to the file to edit. Relative paths resolve against the current cwd.", + ], + old_string: Annotated[ + str, + "Exact text to replace. Must be unique unless replace_all is True.", + ], + new_string: Annotated[ + str, + "Replacement text. Must differ from old_string.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + *, + replace_all: Annotated[ + bool, + "If True, replace all occurrences of old_string. Defaults to False.", + ] = False, + ) -> Command | str: + return self._run_async_blocking( + async_edit_file( + file_path, old_string, new_string, runtime, replace_all=replace_all + ) + ) + + return StructuredTool.from_function( + name="edit_file", + description=tool_description, + func=sync_edit_file, + coroutine=async_edit_file, + ) + + # ------------------------------------------------------------------ tool: mkdir + + def _create_mkdir_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("mkdir") + or SURFSENSE_MKDIR_TOOL_DESCRIPTION + ) + + async def async_mkdir( + path: Annotated[str, "Absolute or relative directory path to create."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + if self._is_cloud(): + if not ( + validated.startswith(DOCUMENTS_ROOT + "/") + or validated == DOCUMENTS_ROOT + ): + return ( + "Error: cloud mkdir must target a path under /documents/ " + f"(got '{validated}')." + ) + return Command( + update={ + "staged_dirs": [validated], + "messages": [ + ToolMessage( + content=( + f"Staged directory '{validated}' (will be created " + "at end of turn)." 
+ ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + backend = self._get_backend(runtime) + local_method = getattr(backend, "amkdir", None) or getattr( + backend, "mkdir", None + ) + if callable(local_method): + try: + res = local_method(validated, parents=True, exist_ok=True) + if asyncio.iscoroutine(res): + await res + except TypeError: + res = local_method(validated) + if asyncio.iscoroutine(res): + await res + except Exception as exc: # pragma: no cover + return f"Error: {exc}" + return f"Created directory {validated}" + + def sync_mkdir( + path: Annotated[str, "Absolute or relative directory path to create."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_mkdir(path, runtime)) + + return StructuredTool.from_function( + name="mkdir", + description=tool_description, + func=sync_mkdir, + coroutine=async_mkdir, + ) + + # ------------------------------------------------------------------ tool: cd + + def _create_cd_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("cd") or SURFSENSE_CD_TOOL_DESCRIPTION + ) + + async def async_cd( + path: Annotated[str, "Absolute or relative directory path to switch into."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + backend = self._get_backend(runtime) + try: + infos = await backend.als_info(validated) + except Exception as exc: # pragma: no cover - defensive + return f"Error: {exc}" + staged_dirs = list(runtime.state.get("staged_dirs") or []) + files = runtime.state.get("files") or {} + cwd_exists = ( + bool(infos) + or validated in staged_dirs + or any(p == validated for p in files) + or any( + isinstance(p, str) and p.startswith(validated.rstrip("/") + "/") + for p in files + ) + or validated == "/" + or validated == DOCUMENTS_ROOT + ) + if not cwd_exists: + return f"Error: directory '{validated}' not found." 
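+ # ``cd`` is pure graph state: the validated path is recorded as ``cwd`` + # so later relative tool paths resolve against it; no backend mutation occurs.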
+ return Command( + update={ + "cwd": validated, + "messages": [ + ToolMessage( + content=f"cwd changed to {validated}", + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + def sync_cd( + path: Annotated[str, "Absolute or relative directory path to switch into."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_cd(path, runtime)) + + return StructuredTool.from_function( + name="cd", + description=tool_description, + func=sync_cd, + coroutine=async_cd, + ) + + # ------------------------------------------------------------------ tool: pwd + + def _create_pwd_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("pwd") or SURFSENSE_PWD_TOOL_DESCRIPTION + ) + + def sync_pwd( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> str: + return self._current_cwd(runtime) + + async def async_pwd( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> str: + return self._current_cwd(runtime) + + return StructuredTool.from_function( + name="pwd", + description=tool_description, + func=sync_pwd, + coroutine=async_pwd, + ) + + # ------------------------------------------------------------------ tool: move_file def _create_move_file_tool(self) -> BaseTool: - """Create move_file for desktop local-folder mode.""" tool_description = ( self._custom_tool_descriptions.get("move_file") or SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION ) - def sync_move_file( - source_path: Annotated[ - str, - "Absolute source path to move from.", - ], - destination_path: Annotated[ - str, - "Absolute destination path to move to.", - ], - runtime: ToolRuntime[None, FilesystemState], - *, - overwrite: Annotated[ - bool, - "If True, replace an existing destination file. Defaults to False.", - ] = False, - ) -> Command | str: - if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return ( - "Error: move_file is only available in desktop local-folder mode." - ) - - if not source_path.strip() or not destination_path.strip(): - return "Error: source_path and destination_path are required." - - resolved_backend = self._get_backend(runtime) - source_target = self._resolve_move_target_path(source_path, runtime) - destination_target = self._resolve_move_target_path( - destination_path, runtime - ) - try: - validated_source = validate_path(source_target) - validated_destination = validate_path(destination_target) - except ValueError as exc: - return f"Error: {exc}" - res: WriteResult = resolved_backend.move( - validated_source, - validated_destination, - overwrite=overwrite, - ) - if res.error: - return res.error - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=( - f"Moved '{validated_source}' to " - f"'{res.path or validated_destination}'" - ), - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return ( - f"Moved '{validated_source}' to '{res.path or validated_destination}'" - ) - async def async_move_file( - source_path: Annotated[ - str, - "Absolute source path to move from.", - ], - destination_path: Annotated[ - str, - "Absolute destination path to move to.", - ], - runtime: ToolRuntime[None, FilesystemState], + source_path: Annotated[str, "Absolute or relative source path."], + destination_path: Annotated[str, "Absolute or relative destination path."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], *, overwrite: Annotated[ bool, - "If True, replace an existing destination file. 
Defaults to False.", + "If True, replace existing destination. Cloud mode rejects True. Defaults to False.", ] = False, ) -> Command | str: - if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return ( - "Error: move_file is only available in desktop local-folder mode." - ) - if not source_path.strip() or not destination_path.strip(): return "Error: source_path and destination_path are required." - resolved_backend = self._get_backend(runtime) - source_target = self._resolve_move_target_path(source_path, runtime) - destination_target = self._resolve_move_target_path( - destination_path, runtime - ) + source = self._resolve_move_target_path(source_path, runtime) + dest = self._resolve_move_target_path(destination_path, runtime) try: - validated_source = validate_path(source_target) - validated_destination = validate_path(destination_target) + validated_source = validate_path(source) + validated_dest = validate_path(dest) except ValueError as exc: return f"Error: {exc}" - res: WriteResult = await resolved_backend.amove( - validated_source, - validated_destination, - overwrite=overwrite, + + if self._is_cloud(): + return await self._cloud_move_file( + runtime, + validated_source, + validated_dest, + overwrite=overwrite, + ) + + backend = self._get_backend(runtime) + res: WriteResult = await backend.amove( + validated_source, validated_dest, overwrite=overwrite ) if res.error: return res.error + update: dict[str, Any] = { + "messages": [ + ToolMessage( + content=f"Moved '{validated_source}' to '{res.path or validated_dest}'", + tool_call_id=runtime.tool_call_id, + ) + ], + } if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=( - f"Moved '{validated_source}' to " - f"'{res.path or validated_destination}'" - ), - tool_call_id=runtime.tool_call_id, - ) - ], - } + update["files"] = res.files_update + return Command(update=update) + + def sync_move_file( + source_path: Annotated[str, "Absolute or relative source path."], + destination_path: Annotated[str, "Absolute or relative destination path."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + *, + overwrite: Annotated[ + bool, + "If True, replace existing destination. Cloud mode rejects True. Defaults to False.", + ] = False, + ) -> Command | str: + return self._run_async_blocking( + async_move_file( + source_path, destination_path, runtime, overwrite=overwrite ) - return ( - f"Moved '{validated_source}' to '{res.path or validated_destination}'" ) return StructuredTool.from_function( @@ -1183,95 +1309,112 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): coroutine=async_move_file, ) + async def _cloud_move_file( + self, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + source: str, + dest: str, + *, + overwrite: bool, + ) -> Command | str: + backend = self._get_backend(runtime) + if not isinstance(backend, KBPostgresBackend): + return "Error: cloud move requires KBPostgresBackend." + + if source == dest: + return f"Moved '{source}' to '{dest}' (no-op)" + if overwrite: + return ( + "Error: overwrite=True is not supported in cloud mode. Move/edit " + "the destination doc explicitly first." + ) + if not source.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud move_file source must be under /documents/ (got " + f"'{source}')." + ) + if not dest.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud move_file destination must be under /documents/ (got " + f"'{dest}')." 
+ ) + anon = runtime.state.get("kb_anon_doc") or {} + if isinstance(anon, dict): + anon_path = str(anon.get("path") or "") + if anon_path and (anon_path in (source, dest)): + return "Error: the anonymous uploaded document is read-only." + + files = runtime.state.get("files") or {} + doc_id_by_path = runtime.state.get("doc_id_by_path") or {} + pending_moves = list(runtime.state.get("pending_moves") or []) + + # Dest collision: occupied in state, in pending moves, or in DB. + if dest in files: + return f"Error: destination '{dest}' already exists." + if any(move.get("dest") == dest for move in pending_moves): + return f"Error: destination '{dest}' already exists." + if dest != source: + existing_dest = await backend._load_file_data(dest) + if existing_dest is not None: + return f"Error: destination '{dest}' already exists." + + # Source materialization: lazy load if not in state. + source_file_data = files.get(source) + source_doc_id = doc_id_by_path.get(source) + if source_file_data is None: + loaded = await backend._load_file_data(source) + if loaded is None: + return f"Error: source '{source}' not found." + source_file_data, loaded_doc_id = loaded + if source_doc_id is None: + source_doc_id = loaded_doc_id + + files_update: dict[str, Any] = {source: None, dest: source_file_data} + update: dict[str, Any] = { + "files": files_update, + "pending_moves": [{"source": source, "dest": dest, "overwrite": False}], + "messages": [ + ToolMessage( + content=( + f"Moved '{source}' to '{dest}' (will commit at end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + + doc_id_update: dict[str, int | None] = {source: None} + if source_doc_id is not None: + doc_id_update[dest] = source_doc_id + update["doc_id_by_path"] = doc_id_update + + dirty_paths = list(runtime.state.get("dirty_paths") or []) + if source in dirty_paths: + new_dirty: list[Any] = [_CLEAR] + for entry in dirty_paths: + new_dirty.append(dest if entry == source else entry) + update["dirty_paths"] = new_dirty + return Command(update=update) + + # ------------------------------------------------------------------ tool: list_tree + def _create_list_tree_tool(self) -> BaseTool: - """Create list_tree for desktop local-folder mode.""" tool_description = ( self._custom_tool_descriptions.get("list_tree") or SURFSENSE_LIST_TREE_TOOL_DESCRIPTION ) - def sync_list_tree( - runtime: ToolRuntime[None, FilesystemState], - *, - path: Annotated[ - str, - "Absolute path to list from. Use '/' for mount roots.", - ] = "/", - max_depth: Annotated[ - int, - "Maximum recursion depth to traverse. Defaults to 8.", - ] = 8, - page_size: Annotated[ - int, - "Maximum number of entries to return. Defaults to 500 (max 1000).", - ] = 500, - include_files: Annotated[ - bool, - "Whether file entries should be included.", - ] = True, - include_dirs: Annotated[ - bool, - "Whether directory entries should be included.", - ] = True, - ) -> str: - if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return ( - "Error: list_tree is only available in desktop local-folder mode." - ) - if max_depth < 0: - return "Error: max_depth must be >= 0." - if page_size < 1: - return "Error: page_size must be >= 1." - if not include_files and not include_dirs: - return "Error: include_files and include_dirs cannot both be false." 
- - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_list_target_path(path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - - result = resolved_backend.list_tree( - validated_path, - max_depth=max_depth, - page_size=page_size, - include_files=include_files, - include_dirs=include_dirs, - ) - error = result.get("error") if isinstance(result, dict) else None - if isinstance(error, str) and error: - return error - return json.dumps(result, ensure_ascii=True) - async def async_list_tree( - runtime: ToolRuntime[None, FilesystemState], - *, + runtime: ToolRuntime[None, SurfSenseFilesystemState], path: Annotated[ str, - "Absolute path to list from. Use '/' for mount roots.", - ] = "/", - max_depth: Annotated[ - int, - "Maximum recursion depth to traverse. Defaults to 8.", - ] = 8, - page_size: Annotated[ - int, - "Maximum number of entries to return. Defaults to 500 (max 1000).", - ] = 500, - include_files: Annotated[ - bool, - "Whether file entries should be included.", - ] = True, - include_dirs: Annotated[ - bool, - "Whether directory entries should be included.", - ] = True, + "Absolute path to start from. Defaults to /documents in cloud mode.", + ] = "", + max_depth: Annotated[int, "Recursion depth limit. Default 8."] = 8, + page_size: Annotated[int, "Maximum entries returned. Max 1000."] = 500, + include_files: Annotated[bool, "Include file entries."] = True, + include_dirs: Annotated[bool, "Include directory entries."] = True, ) -> str: - if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return ( - "Error: list_tree is only available in desktop local-folder mode." - ) if max_depth < 0: return "Error: max_depth must be >= 0." if page_size < 1: @@ -1279,25 +1422,58 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if not include_files and not include_dirs: return "Error: include_files and include_dirs cannot both be false." - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_list_target_path(path, runtime) + target = self._resolve_list_target_path(path, runtime) try: - validated_path = validate_path(target_path) + validated = validate_path(target) except ValueError as exc: return f"Error: {exc}" - result = await resolved_backend.alist_tree( - validated_path, - max_depth=max_depth, - page_size=page_size, - include_files=include_files, - include_dirs=include_dirs, - ) - error = result.get("error") if isinstance(result, dict) else None - if isinstance(error, str) and error: - return error + backend = self._get_backend(runtime) + if isinstance(backend, KBPostgresBackend): + result = await backend.alist_tree_listing( + validated, + max_depth=max_depth, + page_size=page_size, + include_files=include_files, + include_dirs=include_dirs, + ) + elif hasattr(backend, "alist_tree"): + result = await backend.alist_tree( + validated, + max_depth=max_depth, + page_size=page_size, + include_files=include_files, + include_dirs=include_dirs, + ) + else: + return "Error: list_tree is not supported by the active backend." + + if isinstance(result, dict) and isinstance(result.get("error"), str): + return result["error"] return json.dumps(result, ensure_ascii=True) + def sync_list_tree( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + path: Annotated[ + str, + "Absolute path to start from. Defaults to /documents in cloud mode.", + ] = "", + max_depth: Annotated[int, "Recursion depth limit. 
Default 8."] = 8, + page_size: Annotated[int, "Maximum entries returned. Max 1000."] = 500, + include_files: Annotated[bool, "Include file entries."] = True, + include_dirs: Annotated[bool, "Include directory entries."] = True, + ) -> str: + return self._run_async_blocking( + async_list_tree( + runtime, + path=path, + max_depth=max_depth, + page_size=page_size, + include_files=include_files, + include_dirs=include_dirs, + ) + ) + return StructuredTool.from_function( name="list_tree", description=tool_description, @@ -1305,162 +1481,103 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): coroutine=async_list_tree, ) - def _create_edit_file_tool(self) -> BaseTool: - """Create edit_file with DB persistence (skipped for KB documents).""" - tool_description = ( - self._custom_tool_descriptions.get("edit_file") - or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION - ) + # ------------------------------------------------------------------ tool: execute_code (sandbox) - def sync_edit_file( - file_path: Annotated[ - str, - "Absolute path to the file to edit. Must be absolute, not relative.", + def _create_execute_code_tool(self) -> BaseTool: + def sync_execute_code( + command: Annotated[ + str, "Python code to execute. Use print() to see output." ], - old_string: Annotated[ - str, - "The exact text to find and replace. Must be unique in the file unless replace_all is True.", - ], - new_string: Annotated[ - str, - "The text to replace old_string with. Must be different from old_string.", - ], - runtime: ToolRuntime[None, FilesystemState], - *, - replace_all: Annotated[ - bool, - "If True, replace all occurrences of old_string. If False (default), old_string must be unique.", - ] = False, - ) -> Command | str: - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_write_target_path(file_path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - res: EditResult = resolved_backend.edit( - validated_path, - old_string, - new_string, - replace_all=replace_all, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + timeout: Annotated[ + int | None, + "Optional timeout in seconds.", + ] = None, + ) -> str: + if timeout is not None: + if timeout < 0: + return f"Error: timeout must be non-negative, got {timeout}." + if timeout > self._MAX_EXECUTE_TIMEOUT: + return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." + return self._run_async_blocking( + self._execute_in_sandbox(command, runtime, timeout) ) - if res.error: - return res.error - verify_error, updated_content = self._verify_edited_content_sync( - backend=resolved_backend, - path=validated_path, - new_string=new_string, - ) - if verify_error: - return verify_error - - if self._should_persist_documents() and not self._is_kb_document( - validated_path - ): - if updated_content is None: - return ( - f"Error: could not reload edited file '{validated_path}' for " - "persistence." 
- ) - persist_result = self._run_async_blocking( - self._persist_edited_document( - file_path=validated_path, - updated_content=updated_content, - ) - ) - if isinstance(persist_result, str): - return persist_result - - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'" - - async def async_edit_file( - file_path: Annotated[ - str, - "Absolute path to the file to edit. Must be absolute, not relative.", + async def async_execute_code( + command: Annotated[ + str, "Python code to execute. Use print() to see output." ], - old_string: Annotated[ - str, - "The exact text to find and replace. Must be unique in the file unless replace_all is True.", - ], - new_string: Annotated[ - str, - "The text to replace old_string with. Must be different from old_string.", - ], - runtime: ToolRuntime[None, FilesystemState], - *, - replace_all: Annotated[ - bool, - "If True, replace all occurrences of old_string. If False (default), old_string must be unique.", - ] = False, - ) -> Command | str: - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_write_target_path(file_path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - res: EditResult = await resolved_backend.aedit( - validated_path, - old_string, - new_string, - replace_all=replace_all, - ) - if res.error: - return res.error - - verify_error, updated_content = await self._verify_edited_content_async( - backend=resolved_backend, - path=validated_path, - new_string=new_string, - ) - if verify_error: - return verify_error - - if self._should_persist_documents() and not self._is_kb_document( - validated_path - ): - if updated_content is None: - return ( - f"Error: could not reload edited file '{validated_path}' for " - "persistence." - ) - persist_error = await self._persist_edited_document( - file_path=validated_path, - updated_content=updated_content, - ) - if persist_error: - return persist_error - - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'" + runtime: ToolRuntime[None, SurfSenseFilesystemState], + timeout: Annotated[ + int | None, + "Optional timeout in seconds.", + ] = None, + ) -> str: + if timeout is not None: + if timeout < 0: + return f"Error: timeout must be non-negative, got {timeout}." + if timeout > self._MAX_EXECUTE_TIMEOUT: + return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." 
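+ # Bounds checked above; delegate to the shared sandbox executor, which + # retries once on a fresh sandbox if the first attempt fails.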
+ return await self._execute_in_sandbox(command, runtime, timeout) return StructuredTool.from_function( - name="edit_file", - description=tool_description, - func=sync_edit_file, - coroutine=async_edit_file, + name="execute_code", + description=SURFSENSE_EXECUTE_CODE_TOOL_DESCRIPTION, + func=sync_execute_code, + coroutine=async_execute_code, ) + + @staticmethod + def _wrap_as_python(code: str) -> str: + sentinel = f"_PYEOF_{secrets.token_hex(8)}" + return f"python3 << '{sentinel}'\n{code}\n{sentinel}" + + async def _execute_in_sandbox( + self, + command: str, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + timeout: int | None, + ) -> str: + assert self._thread_id is not None + command = self._wrap_as_python(command) + try: + return await self._try_sandbox_execute(command, runtime, timeout) + except (DaytonaError, Exception) as first_err: + logger.warning( + "Sandbox execute failed for thread %s, retrying: %s", + self._thread_id, + first_err, + ) + try: + await delete_sandbox(self._thread_id) + except Exception: + _evict_sandbox_cache(self._thread_id) + try: + return await self._try_sandbox_execute(command, runtime, timeout) + except Exception: + logger.exception( + "Sandbox retry also failed for thread %s", self._thread_id + ) + return "Error: Code execution is temporarily unavailable. Please try again." + + async def _try_sandbox_execute( + self, + command: str, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + timeout: int | None, + ) -> str: + sandbox, _is_new = await get_or_create_sandbox(self._thread_id) + result = await sandbox.aexecute(command, timeout=timeout) + output = (result.output or "").strip() + if not output and result.exit_code == 0: + return ( + "[Code executed successfully but produced no output. " + "Use print() to display results, then try again.]" + ) + parts = [result.output] + if result.exit_code is not None: + status = "succeeded" if result.exit_code == 0 else "failed" + parts.append(f"\n[Command {status} with exit code {result.exit_code}]") + if result.truncated: + parts.append("\n[Output was truncated due to size limits]") + return "".join(parts) diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py new file mode 100644 index 000000000..5682977d9 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py @@ -0,0 +1,622 @@ +"""End-of-turn persistence for the cloud-mode SurfSense filesystem. + +This middleware runs ``aafter_agent`` once per turn (cloud only). It commits +all staged folder creations, file moves, and content writes/edits to +Postgres in a single ordered pass: + +1. Materialize ``staged_dirs`` into ``Folder`` rows. +2. Apply ``pending_moves`` in order (chained moves resolved via + ``doc_id_by_path``). +3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move + sequences commit at the final path. +4. Commit content writes / edits for ``/documents/*`` paths, skipping + ``temp_*`` basenames. + +The commit body is exposed as a free function ``commit_staged_filesystem_state`` +so the optional stream-task fallback (``stream_new_chat.py``) can call the +exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect). 
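+ + Illustrative shape of the staged state this pass consumes (example values + only, not taken from a real run): + + state["staged_dirs"] == ["/documents/Work"] + state["pending_moves"] == [{"source": "/documents/a.xml", + "dest": "/documents/Work/a.xml"}] + state["dirty_paths"] == ["/documents/a.xml"]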
+""" + +from __future__ import annotations + +import logging +from datetime import UTC, datetime +from typing import Any + +from fractional_indexing import generate_key_between +from langchain.agents.middleware import AgentMiddleware, AgentState +from langchain_core.callbacks import dispatch_custom_event +from langgraph.runtime import Runtime +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + parse_documents_path, + safe_folder_segment, + virtual_path_to_doc, +) +from app.agents.new_chat.state_reducers import _CLEAR +from app.db import ( + Chunk, + Document, + DocumentType, + Folder, + shielded_async_session, +) +from app.indexing_pipeline.document_chunker import chunk_text +from app.utils.document_converters import ( + embed_texts, + generate_content_hash, + generate_unique_identifier_hash, +) + +logger = logging.getLogger(__name__) + + +_TEMP_PREFIX = "temp_" + + +def _basename(path: str) -> str: + return path.rsplit("/", 1)[-1] + + +# --------------------------------------------------------------------------- +# Folder helpers +# --------------------------------------------------------------------------- + + +async def _ensure_folder_hierarchy( + session: AsyncSession, + *, + search_space_id: int, + created_by_id: str | None, + folder_parts: list[str], +) -> int | None: + """Ensure a chain of folder names exists under the search space. + + Returns the leaf folder id, or ``None`` if ``folder_parts`` is empty + (i.e. a document directly under ``/documents/``). + """ + if not folder_parts: + return None + parent_id: int | None = None + for raw in folder_parts: + name = safe_folder_segment(str(raw)) + query = select(Folder).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + ) + if parent_id is None: + query = query.where(Folder.parent_id.is_(None)) + else: + query = query.where(Folder.parent_id == parent_id) + result = await session.execute(query) + folder = result.scalar_one_or_none() + if folder is None: + sibling_query = ( + select(Folder.position).order_by(Folder.position.desc()).limit(1) + ) + sibling_query = sibling_query.where( + Folder.search_space_id == search_space_id + ) + if parent_id is None: + sibling_query = sibling_query.where(Folder.parent_id.is_(None)) + else: + sibling_query = sibling_query.where(Folder.parent_id == parent_id) + sibling_result = await session.execute(sibling_query) + last_position = sibling_result.scalar_one_or_none() + folder = Folder( + name=name, + position=generate_key_between(last_position, None), + parent_id=parent_id, + search_space_id=search_space_id, + created_by_id=created_by_id, + updated_at=datetime.now(UTC), + ) + session.add(folder) + await session.flush() + parent_id = folder.id + return parent_id + + +# --------------------------------------------------------------------------- +# Document helpers +# --------------------------------------------------------------------------- + + +async def _create_document( + session: AsyncSession, + *, + virtual_path: str, + content: str, + search_space_id: int, + created_by_id: str | None, +) -> Document: + """Create a NOTE Document + Chunks for ``virtual_path``.""" + folder_parts, title = parse_documents_path(virtual_path) + if not title: + raise ValueError(f"invalid /documents path '{virtual_path}'") + folder_id = await 
_ensure_folder_hierarchy( + session, + search_space_id=search_space_id, + created_by_id=created_by_id, + folder_parts=folder_parts, + ) + unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + # Guard against the unique_identifier_hash constraint: another row at the + # same virtual_path (this search space) already owns the hash. Callers are + # expected to upsert via the wrapper, but this defends against bypasses + # and gives a clean ValueError instead of a session-poisoning IntegrityError. + path_collision = await session.execute( + select(Document.id).where( + Document.search_space_id == search_space_id, + Document.unique_identifier_hash == unique_identifier_hash, + ) + ) + if path_collision.scalar_one_or_none() is not None: + raise ValueError( + f"a document already exists at path '{virtual_path}' " + "(unique_identifier_hash collision)" + ) + content_hash = generate_content_hash(content, search_space_id) + content_collision = await session.execute( + select(Document.id).where( + Document.search_space_id == search_space_id, + Document.content_hash == content_hash, + ) + ) + if content_collision.scalar_one_or_none() is not None: + raise ValueError( + f"a document with identical content already exists for path '{virtual_path}'" + ) + doc = Document( + title=title, + document_type=DocumentType.NOTE, + document_metadata={"virtual_path": virtual_path}, + content=content, + content_hash=content_hash, + unique_identifier_hash=unique_identifier_hash, + source_markdown=content, + search_space_id=search_space_id, + folder_id=folder_id, + created_by_id=created_by_id, + updated_at=datetime.now(UTC), + ) + session.add(doc) + await session.flush() + + summary_embedding = embed_texts([content])[0] + doc.embedding = summary_embedding + chunks = chunk_text(content) + if chunks: + chunk_embeddings = embed_texts(chunks) + session.add_all( + [ + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunks, chunk_embeddings, strict=True) + ] + ) + return doc + + +async def _update_document( + session: AsyncSession, + *, + doc_id: int, + content: str, + virtual_path: str, + search_space_id: int, +) -> Document | None: + """Update an existing Document's content + chunks.""" + result = await session.execute( + select(Document).where( + Document.id == doc_id, + Document.search_space_id == search_space_id, + ) + ) + document = result.scalar_one_or_none() + if document is None: + return None + + document.content = content + document.source_markdown = content + document.content_hash = generate_content_hash(content, search_space_id) + document.updated_at = datetime.now(UTC) + metadata = dict(document.document_metadata or {}) + metadata["virtual_path"] = virtual_path + document.document_metadata = metadata + document.unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + + summary_embedding = embed_texts([content])[0] + document.embedding = summary_embedding + + await session.execute(delete(Chunk).where(Chunk.document_id == document.id)) + chunks = chunk_text(content) + if chunks: + chunk_embeddings = embed_texts(chunks) + session.add_all( + [ + Chunk(document_id=document.id, content=text, embedding=embedding) + for text, embedding in zip(chunks, chunk_embeddings, strict=True) + ] + ) + return document + + +# --------------------------------------------------------------------------- +# Move helpers +# 
---------------------------------------------------------------------------
+
+
+async def _apply_move(
+    session: AsyncSession,
+    *,
+    search_space_id: int,
+    created_by_id: str | None,
+    move: dict[str, Any],
+    doc_id_by_path: dict[str, int],
+    doc_id_path_tombstones: dict[str, int | None],
+) -> dict[str, Any] | None:
+    """Apply a single staged move; updates the in-memory mapping for chain resolution."""
+    source = str(move.get("source") or "")
+    dest = str(move.get("dest") or "")
+    if not source or not dest or source == dest:
+        return None
+
+    if not source.startswith(DOCUMENTS_ROOT + "/") or not dest.startswith(
+        DOCUMENTS_ROOT + "/"
+    ):
+        return None
+
+    doc_id: int | None = doc_id_by_path.get(source)
+    document: Document | None = None
+    if doc_id is not None:
+        result = await session.execute(
+            select(Document).where(
+                Document.id == doc_id,
+                Document.search_space_id == search_space_id,
+            )
+        )
+        document = result.scalar_one_or_none()
+    if document is None:
+        document = await virtual_path_to_doc(
+            session,
+            search_space_id=search_space_id,
+            virtual_path=source,
+        )
+    if document is None:
+        logger.info(
+            "kb_persistence: skipping move %s -> %s (source not found)",
+            source,
+            dest,
+        )
+        return None
+
+    folder_parts, new_title = parse_documents_path(dest)
+    if not new_title:
+        return None
+    folder_id = await _ensure_folder_hierarchy(
+        session,
+        search_space_id=search_space_id,
+        created_by_id=created_by_id,
+        folder_parts=folder_parts,
+    )
+
+    document.title = new_title
+    document.folder_id = folder_id
+    metadata = dict(document.document_metadata or {})
+    metadata["virtual_path"] = dest
+    document.document_metadata = metadata
+    document.unique_identifier_hash = generate_unique_identifier_hash(
+        DocumentType.NOTE,
+        dest,
+        search_space_id,
+    )
+    document.updated_at = datetime.now(UTC)
+
+    doc_id_by_path.pop(source, None)
+    doc_id_by_path[dest] = document.id
+    doc_id_path_tombstones[source] = None
+    doc_id_path_tombstones[dest] = document.id
+    return {"id": document.id, "source": source, "dest": dest, "title": new_title}
+
+
+# ---------------------------------------------------------------------------
+# Commit body
+# ---------------------------------------------------------------------------
+
+
+async def commit_staged_filesystem_state(
+    state: dict[str, Any] | AgentState,
+    *,
+    search_space_id: int,
+    created_by_id: str | None,
+    filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
+    dispatch_events: bool = True,
+) -> dict[str, Any] | None:
+    """Commit all staged filesystem changes; return the state delta for reducers.
+
+    Shared between :meth:`KnowledgeBasePersistenceMiddleware.aafter_agent`
+    and the optional stream-task fallback.
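+
+    Illustrative shape of the returned delta after one committed create
+    (a sketch; keys appear only when relevant, and ``_CLEAR`` is the
+    list-reset sentinel imported from ``state_reducers``)::
+
+        {
+            "dirty_paths": [_CLEAR],
+            "staged_dirs": [_CLEAR],
+            "pending_moves": [_CLEAR],
+            "doc_id_by_path": {"/documents/notes.md": 42},
+            "tree_version": 3,
+        }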
+ """ + if filesystem_mode != FilesystemMode.CLOUD: + return None + + state_dict: dict[str, Any] = ( + dict(state) + if isinstance(state, dict) + else dict(getattr(state, "values", {}) or {}) + ) + + files: dict[str, Any] = state_dict.get("files") or {} + staged_dirs: list[str] = list(state_dict.get("staged_dirs") or []) + pending_moves: list[dict[str, Any]] = list(state_dict.get("pending_moves") or []) + dirty_paths: list[str] = list(state_dict.get("dirty_paths") or []) + doc_id_by_path: dict[str, int] = dict(state_dict.get("doc_id_by_path") or {}) + kb_anon_doc = state_dict.get("kb_anon_doc") + + if kb_anon_doc: + temp_paths = [ + p + for p in files + if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) + ] + return { + "dirty_paths": [_CLEAR], + "staged_dirs": [_CLEAR], + "pending_moves": [_CLEAR], + "files": dict.fromkeys(temp_paths), + } + + if not (staged_dirs or pending_moves or dirty_paths): + return None + + committed_creates: list[dict[str, Any]] = [] + committed_updates: list[dict[str, Any]] = [] + discarded: list[str] = [] + applied_moves: list[dict[str, Any]] = [] + doc_id_path_tombstones: dict[str, int | None] = {} + tree_changed = False + + try: + async with shielded_async_session() as session: + for folder_path in staged_dirs: + if not isinstance(folder_path, str): + continue + if not folder_path.startswith(DOCUMENTS_ROOT): + continue + rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/") + folder_parts_full = [p for p in rel.split("/") if p] + if not folder_parts_full: + continue + await _ensure_folder_hierarchy( + session, + search_space_id=search_space_id, + created_by_id=created_by_id, + folder_parts=folder_parts_full, + ) + tree_changed = True + + for move in pending_moves: + applied = await _apply_move( + session, + search_space_id=search_space_id, + created_by_id=created_by_id, + move=move, + doc_id_by_path=doc_id_by_path, + doc_id_path_tombstones=doc_id_path_tombstones, + ) + if applied: + applied_moves.append(applied) + tree_changed = True + + move_alias = { + m["source"]: m["dest"] for m in pending_moves if m.get("source") + } + + def _final_path(path: str) -> str: + seen: set[str] = set() + while path in move_alias and path not in seen: + seen.add(path) + path = move_alias[path] + return path + + kb_dirty_seen: set[str] = set() + kb_dirty: list[str] = [] + for raw in dirty_paths: + if not isinstance(raw, str): + continue + final = _final_path(raw) + if not final.startswith(DOCUMENTS_ROOT + "/"): + continue + if final in kb_dirty_seen: + continue + kb_dirty_seen.add(final) + kb_dirty.append(final) + + for path in kb_dirty: + basename = _basename(path) + if basename.startswith(_TEMP_PREFIX): + discarded.append(path) + continue + file_data = files.get(path) + if not isinstance(file_data, dict): + continue + content = "\n".join(file_data.get("content") or []) + doc_id = doc_id_by_path.get(path) + if doc_id is None: + # The in-memory ``doc_id_by_path`` is per-thread and starts + # empty in every new chat. If the agent writes to a path + # that already exists in the DB (e.g. a previous chat's + # ``notes.md``), we must NOT try to INSERT — it would hit + # ``unique_identifier_hash`` (path-derived). Look up the + # existing doc and update it in place instead. 
+ existing = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=path, + ) + if existing is not None: + doc_id = existing.id + doc_id_by_path[path] = existing.id + if doc_id is not None: + updated = await _update_document( + session, + doc_id=doc_id, + content=content, + virtual_path=path, + search_space_id=search_space_id, + ) + if updated is not None: + committed_updates.append( + { + "id": updated.id, + "title": updated.title, + "documentType": DocumentType.NOTE.value, + "searchSpaceId": search_space_id, + "folderId": updated.folder_id, + "createdById": str(created_by_id) + if created_by_id + else None, + "virtualPath": path, + } + ) + else: + try: + new_doc = await _create_document( + session, + virtual_path=path, + content=content, + search_space_id=search_space_id, + created_by_id=created_by_id, + ) + except ValueError as exc: + logger.warning( + "kb_persistence: skipping %s create: %s", path, exc + ) + continue + doc_id_by_path[path] = new_doc.id + committed_creates.append( + { + "id": new_doc.id, + "title": new_doc.title, + "documentType": DocumentType.NOTE.value, + "searchSpaceId": search_space_id, + "folderId": new_doc.folder_id, + "createdById": str(created_by_id) + if created_by_id + else None, + "virtualPath": path, + } + ) + tree_changed = True + + await session.commit() + except Exception: # pragma: no cover - rollback safety net + logger.exception( + "kb_persistence: commit failed (search_space=%s)", search_space_id + ) + return None + + if dispatch_events: + for payload in committed_creates: + try: + dispatch_custom_event("document_created", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch document_created event" + ) + for payload in committed_updates: + try: + dispatch_custom_event("document_updated", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch document_updated event" + ) + + temp_paths = [ + p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) + ] + + doc_id_update: dict[str, int | None] = {**doc_id_path_tombstones} + for payload in committed_creates: + doc_id_update[str(payload.get("virtualPath") or "")] = int(payload["id"]) + + delta: dict[str, Any] = { + "dirty_paths": [_CLEAR], + "staged_dirs": [_CLEAR], + "pending_moves": [_CLEAR], + } + if temp_paths: + delta["files"] = dict.fromkeys(temp_paths) + if doc_id_update: + delta["doc_id_by_path"] = doc_id_update + if tree_changed: + delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1 + + logger.info( + "kb_persistence: commit (search_space=%s) creates=%d updates=%d " + "moves=%d staged_dirs=%d discarded=%d", + search_space_id, + len(committed_creates), + len(committed_updates), + len(applied_moves), + len(staged_dirs), + len(discarded), + ) + return delta + + +# --------------------------------------------------------------------------- +# Middleware +# --------------------------------------------------------------------------- + + +class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type-arg] + """End-of-turn cloud persistence for the SurfSense filesystem agent.""" + + tools = () + state_schema = SurfSenseFilesystemState + + def __init__( + self, + *, + search_space_id: int, + created_by_id: str | None, + filesystem_mode: FilesystemMode, + ) -> None: + self.search_space_id = search_space_id + self.created_by_id = created_by_id + self.filesystem_mode = filesystem_mode + + async def aafter_agent( # type: ignore[override] + self, + state: 
AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + del runtime + if self.filesystem_mode != FilesystemMode.CLOUD: + return None + return await commit_staged_filesystem_state( + state, + search_space_id=self.search_space_id, + created_by_id=self.created_by_id, + filesystem_mode=self.filesystem_mode, + ) + + +__all__ = [ + "KnowledgeBasePersistenceMiddleware", + "commit_staged_filesystem_state", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py new file mode 100644 index 000000000..ddb2d4af1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py @@ -0,0 +1,963 @@ +"""Postgres-backed virtual filesystem for the SurfSense agent (cloud mode). + +The backend is **strictly conforming** to deepagents' +:class:`BackendProtocol`. It returns ``WriteResult`` / ``EditResult`` / list +shapes exactly as upstream expects (no extra fields). All side-state +plumbing — ``dirty_paths``, ``doc_id_by_path``, ``staged_dirs``, +``pending_moves``, ``files`` cache — is appended by the overridden tool +wrappers in :class:`SurfSenseFilesystemMiddleware` via ``Command.update``. + +The backend never writes to Postgres. End-of-turn persistence is handled by +:class:`KnowledgeBasePersistenceMiddleware`. This module is purely a +read-side and a state-merging helper. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import fnmatch +import logging +import re +from datetime import UTC +from typing import Any + +from deepagents.backends.protocol import ( + BackendProtocol, + EditResult, + FileDownloadResponse, + FileInfo, + FileUploadResponse, + GrepMatch, + WriteResult, +) +from deepagents.backends.utils import ( + create_file_data, + file_data_to_string, + format_read_response, + perform_string_replacement, + update_file_data, +) +from langchain.tools import ToolRuntime +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.document_xml import build_document_xml +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + build_path_index, + doc_to_virtual_path, + virtual_path_to_doc, +) +from app.db import Chunk, Document, shielded_async_session + +logger = logging.getLogger(__name__) + +_TEMP_PREFIX = "temp_" +_GREP_MAX_TOTAL_MATCHES = 50 +_GREP_MAX_PER_DOC = 5 + + +def _basename(path: str) -> str: + return path.rsplit("/", 1)[-1] + + +def _is_under(child: str, parent: str) -> bool: + """Return True iff ``child`` is at-or-under ``parent`` (directory semantics).""" + if parent == "/": + return child.startswith("/") + return child == parent or child.startswith(parent.rstrip("/") + "/") + + +def paginate_listing( + infos: list[FileInfo], + *, + offset: int = 0, + limit: int | None = None, +) -> list[FileInfo]: + """Paginate a listing produced by :meth:`KBPostgresBackend.als_info`.""" + if offset < 0: + offset = 0 + end: int | None + end = None if limit is None or limit < 0 else offset + limit + return list(infos[offset:end]) + + +class KBPostgresBackend(BackendProtocol): + """Lazy, read-only Postgres view for ``/documents/*`` virtual paths. + + The backend exposes a virtual ``/documents/`` namespace mirroring the + ``Folder``/``Document`` graph. Reads materialize XML on first access and + cache it via the overriding tool wrappers (NOT here). 
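+
+    A rough usage sketch (illustrative: the runtime object and ids are
+    assumed, and caching of the materialized ``FileData`` into
+    ``state["files"]`` is done by the tool wrappers, not here)::
+
+        backend = KBPostgresBackend(search_space_id=7, runtime=tool_runtime)
+        listing = await backend.als_info("/documents")
+        text = await backend.aread("/documents/notes.md")
+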
+    Writes never touch the DB — they return ``files_update`` deltas that
+    the wrappers turn into Command updates, and the persistence middleware
+    commits them at end of turn.
+    """
+
+    _IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp"})
+
+    def __init__(self, search_space_id: int, runtime: ToolRuntime) -> None:
+        self.search_space_id = search_space_id
+        self.runtime = runtime
+
+    @property
+    def state(self) -> dict[str, Any]:
+        return getattr(self.runtime, "state", {}) or {}
+
+    # ------------------------------------------------------------------ helpers
+
+    def _state_files(self) -> dict[str, Any]:
+        return dict(self.state.get("files") or {})
+
+    def _staged_dirs(self) -> list[str]:
+        return list(self.state.get("staged_dirs") or [])
+
+    def _pending_moves(self) -> list[dict[str, Any]]:
+        return list(self.state.get("pending_moves") or [])
+
+    def _kb_anon_doc(self) -> dict[str, Any] | None:
+        anon = self.state.get("kb_anon_doc")
+        return anon if isinstance(anon, dict) else None
+
+    def _matched_chunk_ids(self, doc_id: int) -> set[int]:
+        mapping = self.state.get("kb_matched_chunk_ids") or {}
+        try:
+            return set(mapping.get(doc_id, []) or [])
+        except TypeError:
+            return set()
+
+    @staticmethod
+    def _file_data_size(file_data: dict[str, Any]) -> int:
+        try:
+            return len("\n".join(file_data.get("content") or []))
+        except Exception:
+            return 0
+
+    def _normalize_listing_path(self, path: str) -> str:
+        if not path:
+            return DOCUMENTS_ROOT
+        if path == "/":
+            return path
+        return path.rstrip("/")
+
+    def _moved_view_paths(
+        self,
+        existing: dict[str, dict[str, Any]],
+    ) -> tuple[set[str], dict[str, str]]:
+        """Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
+
+        Removed paths should disappear from listings; ``alias[source] = dest``
+        means a virtual entry should appear at ``dest`` even if no DB row is
+        yet there.
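+
+        Illustrative example: one staged move of ``/documents/a.xml`` to
+        ``/documents/b.xml`` yields
+        ``({"/documents/a.xml"}, {"/documents/a.xml": "/documents/b.xml"})``,
+        and ``/documents/a.xml`` is popped from ``existing``.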
+ """ + removed: set[str] = set() + alias: dict[str, str] = {} + for move in self._pending_moves(): + src = move.get("source") + dst = move.get("dest") + if not src or not dst: + continue + removed.add(src) + alias[src] = dst + existing.pop(src, None) + return removed, alias + + # ------------------------------------------------------------------ ls/read + + async def als_info(self, path: str) -> list[FileInfo]: # type: ignore[override] + normalized = self._normalize_listing_path(path) + infos: list[FileInfo] = [] + seen: set[str] = set() + + anon = self._kb_anon_doc() + if anon: + anon_path = str(anon.get("path") or "") + if ( + anon_path + and _is_under(anon_path, normalized) + and anon_path != normalized + and anon_path not in seen + ): + infos.append( + FileInfo( + path=anon_path, + is_dir=False, + size=len(str(anon.get("content") or "")), + modified_at="", + ) + ) + seen.add(anon_path) + + files = self._state_files() + moved_removed, moved_alias = self._moved_view_paths(files) + + if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": + try: + async with shielded_async_session() as session: + db_infos, subdir_paths = await self._list_db_directory( + session, normalized + ) + except Exception as exc: # pragma: no cover - defensive + logger.warning("KBPostgresBackend.als_info DB error: %s", exc) + db_infos, subdir_paths = [], set() + + for info in db_infos: + p = info.get("path", "") + if not p or p in seen or p in moved_removed: + continue + infos.append(info) + seen.add(p) + + for src, dst in moved_alias.items(): + if src not in seen: + if not _is_under(dst, normalized): + continue + rel = ( + dst[len(normalized) :].lstrip("/") + if normalized != "/" + else dst.lstrip("/") + ) + if "/" in rel: + subdir_paths.add( + (normalized.rstrip("/") + "/" + rel.split("/", 1)[0]) + if normalized != "/" + else "/" + rel.split("/", 1)[0] + ) + continue + if dst in seen: + continue + fd = files.get(dst) + size = self._file_data_size(fd) if isinstance(fd, dict) else 0 + infos.append( + FileInfo( + path=dst, + is_dir=False, + size=int(size), + modified_at=fd.get("modified_at", "") + if isinstance(fd, dict) + else "", + ) + ) + seen.add(dst) + + for staged in self._staged_dirs(): + if not staged or not staged.startswith(DOCUMENTS_ROOT): + continue + if staged == normalized: + continue + if not _is_under(staged, normalized): + continue + rel = ( + staged[len(normalized) :].lstrip("/") + if normalized != "/" + else staged.lstrip("/") + ) + if not rel: + continue + first = rel.split("/", 1)[0] + immediate = ( + normalized.rstrip("/") + "/" + first + if normalized != "/" + else "/" + first + ) + subdir_paths.add(immediate) + + for sub in sorted(subdir_paths): + if sub in seen: + continue + infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at="")) + seen.add(sub) + + for path_key, fd in files.items(): + if not isinstance(path_key, str) or path_key in seen: + continue + if not _is_under(path_key, normalized) or path_key == normalized: + continue + if normalized == "/": + rel = path_key.lstrip("/") + else: + rel = path_key[len(normalized) :].lstrip("/") + if not rel: + continue + if "/" in rel: + first = rel.split("/", 1)[0] + immediate = ( + normalized.rstrip("/") + "/" + first + if normalized != "/" + else "/" + first + ) + if immediate not in seen: + infos.append( + FileInfo(path=immediate, is_dir=True, size=0, modified_at="") + ) + seen.add(immediate) + continue + include = path_key.startswith(DOCUMENTS_ROOT) or _basename( + path_key + ).startswith(_TEMP_PREFIX) + if not include: + 
continue + size = self._file_data_size(fd) if isinstance(fd, dict) else 0 + infos.append( + FileInfo( + path=path_key, + is_dir=False, + size=int(size), + modified_at=fd.get("modified_at", "") + if isinstance(fd, dict) + else "", + ) + ) + seen.add(path_key) + + infos.sort(key=lambda fi: (not fi.get("is_dir", False), fi.get("path", ""))) + return infos + + def ls_info(self, path: str) -> list[FileInfo]: # type: ignore[override] + return asyncio.run(self.als_info(path)) + + async def _list_db_directory( + self, + session: AsyncSession, + normalized_path: str, + ) -> tuple[list[FileInfo], set[str]]: + """List immediate Folders + Documents at ``normalized_path``. + + Returns ``(file_infos, subdirectory_paths)``. ``normalized_path`` may + be ``/`` (synthesizes ``/documents``) or a path under ``/documents``. + """ + if normalized_path == "/": + return ( + [], + {DOCUMENTS_ROOT}, + ) + + if not normalized_path.startswith(DOCUMENTS_ROOT): + return [], set() + + index = await build_path_index(session, self.search_space_id) + target_folder_id: int | None = None + if normalized_path != DOCUMENTS_ROOT: + target_path = normalized_path + matches = [ + fid for fid, fpath in index.folder_paths.items() if fpath == target_path + ] + if not matches: + return [], set() + target_folder_id = matches[0] + + result = await session.execute( + select(Document.id, Document.title, Document.folder_id, Document.updated_at) + .where(Document.search_space_id == self.search_space_id) + .where( + Document.folder_id == target_folder_id + if target_folder_id is not None + else Document.folder_id.is_(None) + ) + ) + rows = result.all() + + file_infos: list[FileInfo] = [] + for row in rows: + path = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + modified = "" + if row.updated_at is not None: + with contextlib.suppress(Exception): + modified = row.updated_at.astimezone(UTC).isoformat() + file_infos.append( + FileInfo( + path=path, + is_dir=False, + size=0, + modified_at=modified, + ) + ) + + subdirs: set[str] = set() + for _fid, fpath in index.folder_paths.items(): + if fpath == normalized_path: + continue + base = normalized_path.rstrip("/") + if not fpath.startswith(base + "/"): + continue + rel = fpath[len(base) + 1 :] + if "/" in rel: + continue + subdirs.add(base + "/" + rel) + return file_infos, subdirs + + async def aread( # type: ignore[override] + self, + file_path: str, + offset: int = 0, + limit: int = 2000, + ) -> str: + files = self._state_files() + file_data = files.get(file_path) + if file_data is not None: + return format_read_response(file_data, offset, limit) + + loaded = await self._load_file_data(file_path) + if loaded is None: + return f"Error: File '{file_path}' not found" + file_data, _ = loaded + return format_read_response(file_data, offset, limit) + + def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str: # type: ignore[override] + return asyncio.run(self.aread(file_path, offset, limit)) + + async def _load_file_data( + self, + path: str, + ) -> tuple[dict[str, Any], int | None] | None: + """Lazy-load a virtual KB document into a deepagents ``FileData``. + + Returns ``(file_data, doc_id)`` or ``None`` if the path doesn't map + to any known document. ``doc_id`` is ``None`` for the synthetic + anonymous document so the caller doesn't track it as a DB-backed file. 
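+
+        Illustrative return value for a DB-backed document (any extra
+        ``FileData`` keys omitted; ``123`` stands in for the document id)::
+
+            ({"content": ["...first xml line...", "..."]}, 123)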
+ """ + anon = self._kb_anon_doc() + if anon and str(anon.get("path") or "") == path: + doc_payload = { + "document_id": -1, + "chunks": list(anon.get("chunks") or []), + "matched_chunk_ids": [], + "document": { + "id": -1, + "title": anon.get("title") or "uploaded_document", + "document_type": "FILE", + "metadata": {"source": "anonymous_upload"}, + }, + "source": "FILE", + } + xml = build_document_xml(doc_payload, matched_chunk_ids=set()) + file_data = create_file_data(xml) + return file_data, None + + if not path.startswith(DOCUMENTS_ROOT): + return None + + async with shielded_async_session() as session: + document = await virtual_path_to_doc( + session, + search_space_id=self.search_space_id, + virtual_path=path, + ) + if document is None: + return None + chunk_rows = await session.execute( + select(Chunk.id, Chunk.content) + .where(Chunk.document_id == document.id) + .order_by(Chunk.id) + ) + chunks = [ + {"chunk_id": row.id, "content": row.content} for row in chunk_rows.all() + ] + + doc_payload = { + "document_id": document.id, + "chunks": chunks, + "matched_chunk_ids": list(self._matched_chunk_ids(document.id)), + "document": { + "id": document.id, + "title": document.title, + "document_type": ( + document.document_type.value + if getattr(document, "document_type", None) is not None + else "UNKNOWN" + ), + "metadata": dict(document.document_metadata or {}), + }, + "source": ( + document.document_type.value + if getattr(document, "document_type", None) is not None + else "UNKNOWN" + ), + } + xml = build_document_xml( + doc_payload, + matched_chunk_ids=self._matched_chunk_ids(document.id), + ) + file_data = create_file_data(xml) + return file_data, document.id + + # ------------------------------------------------------------------ writes + + async def awrite(self, file_path: str, content: str) -> WriteResult: # type: ignore[override] + files = self._state_files() + if file_path in files: + return WriteResult( + error=( + f"Cannot write to {file_path} because it already exists. " + "Read and then make an edit, or write to a new path." 
+ ) + ) + new_file_data = create_file_data(content) + return WriteResult(path=file_path, files_update={file_path: new_file_data}) + + def write(self, file_path: str, content: str) -> WriteResult: # type: ignore[override] + return asyncio.run(self.awrite(file_path, content)) + + async def aedit( # type: ignore[override] + self, + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + ) -> EditResult: + files = self._state_files() + file_data = files.get(file_path) + if file_data is None: + loaded = await self._load_file_data(file_path) + if loaded is None: + return EditResult(error=f"Error: File '{file_path}' not found") + file_data, _ = loaded + + content = file_data_to_string(file_data) + result = perform_string_replacement( + content, old_string, new_string, replace_all + ) + if isinstance(result, str): + return EditResult(error=result) + + new_content, occurrences = result + new_file_data = update_file_data(file_data, new_content) + return EditResult( + path=file_path, + files_update={file_path: new_file_data}, + occurrences=int(occurrences), + ) + + def edit( # type: ignore[override] + self, + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + ) -> EditResult: + return asyncio.run(self.aedit(file_path, old_string, new_string, replace_all)) + + # ------------------------------------------------------------------ glob/grep + + async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: # type: ignore[override] + normalized = self._normalize_listing_path(path) + results: list[FileInfo] = [] + seen: set[str] = set() + + files = self._state_files() + moved_removed, _ = self._moved_view_paths(files) + regex = re.compile(fnmatch.translate(pattern)) + for path_key, fd in files.items(): + if path_key in moved_removed: + continue + if not _is_under(path_key, normalized): + continue + rel = ( + path_key[len(normalized) :].lstrip("/") + if normalized != "/" + else path_key.lstrip("/") + ) + if not regex.match(rel) and not regex.match(path_key): + continue + if path_key in seen: + continue + size = self._file_data_size(fd) if isinstance(fd, dict) else 0 + results.append( + FileInfo( + path=path_key, + is_dir=False, + size=int(size), + modified_at=fd.get("modified_at", "") + if isinstance(fd, dict) + else "", + ) + ) + seen.add(path_key) + + if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": + try: + async with shielded_async_session() as session: + index = await build_path_index(session, self.search_space_id) + rows = await session.execute( + select(Document.id, Document.title, Document.folder_id).where( + Document.search_space_id == self.search_space_id + ) + ) + for row in rows.all(): + candidate = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + if candidate in seen or candidate in moved_removed: + continue + if not _is_under(candidate, normalized): + continue + rel = ( + candidate[len(normalized) :].lstrip("/") + if normalized != "/" + else candidate.lstrip("/") + ) + if not regex.match(rel) and not regex.match(candidate): + continue + results.append( + FileInfo( + path=candidate, is_dir=False, size=0, modified_at="" + ) + ) + seen.add(candidate) + except Exception as exc: # pragma: no cover - defensive + logger.warning("KBPostgresBackend.aglob_info DB error: %s", exc) + + results.sort(key=lambda fi: fi.get("path", "")) + return results + + def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: # type: 
ignore[override] + return asyncio.run(self.aglob_info(pattern, path)) + + async def agrep_raw( # type: ignore[override] + self, + pattern: str, + path: str | None = None, + glob: str | None = None, + ) -> list[GrepMatch] | str: + if not pattern: + return "Error: pattern cannot be empty" + + normalized = self._normalize_listing_path(path or "/") + matches: list[GrepMatch] = [] + + files = self._state_files() + moved_removed, _ = self._moved_view_paths(files) + glob_re = re.compile(fnmatch.translate(glob)) if glob else None + for path_key, fd in files.items(): + if path_key in moved_removed: + continue + if not _is_under(path_key, normalized): + continue + if glob_re is not None and not glob_re.match(_basename(path_key)): + continue + if not isinstance(fd, dict): + continue + for line_no, line in enumerate(fd.get("content") or [], 1): + if pattern in line: + matches.append( + GrepMatch(path=path_key, line=int(line_no), text=str(line)) + ) + if len(matches) >= _GREP_MAX_TOTAL_MATCHES: + return matches + + if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": + try: + async with shielded_async_session() as session: + index = await build_path_index(session, self.search_space_id) + sub = ( + select(Chunk.document_id, Chunk.id, Chunk.content) + .join(Document, Document.id == Chunk.document_id) + .where(Document.search_space_id == self.search_space_id) + .where(Chunk.content.ilike(f"%{pattern}%")) + .order_by(Chunk.document_id, Chunk.id) + ) + chunk_rows = await session.execute(sub) + per_doc: dict[int, int] = {} + doc_id_to_path: dict[int, str] = {} + needed_doc_ids: set[int] = set() + chunk_buffer: list[tuple[int, int, str]] = [] + for row in chunk_rows.all(): + per_doc.setdefault(row.document_id, 0) + if per_doc[row.document_id] >= _GREP_MAX_PER_DOC: + continue + per_doc[row.document_id] += 1 + chunk_buffer.append((row.document_id, row.id, row.content)) + needed_doc_ids.add(row.document_id) + if sum(per_doc.values()) >= _GREP_MAX_TOTAL_MATCHES - len( + matches + ): + break + if needed_doc_ids: + doc_rows = await session.execute( + select( + Document.id, Document.title, Document.folder_id + ).where(Document.id.in_(list(needed_doc_ids))) + ) + for row in doc_rows.all(): + doc_id_to_path[row.id] = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + for doc_id, chunk_id, content in chunk_buffer: + candidate = doc_id_to_path.get(doc_id) + if not candidate or candidate in moved_removed: + continue + if not _is_under(candidate, normalized): + continue + if glob_re is not None and not glob_re.match( + _basename(candidate) + ): + continue + snippet = " ".join(str(content).split())[:240] + matches.append( + GrepMatch( + path=candidate, + line=0, + text=( + f": " + f"{snippet}" + ), + ) + ) + if len(matches) >= _GREP_MAX_TOTAL_MATCHES: + break + except Exception as exc: # pragma: no cover - defensive + logger.warning("KBPostgresBackend.agrep_raw DB error: %s", exc) + + return matches + + def grep_raw( # type: ignore[override] + self, + pattern: str, + path: str | None = None, + glob: str | None = None, + ) -> list[GrepMatch] | str: + return asyncio.run(self.agrep_raw(pattern, path, glob)) + + # ------------------------------------------------------------------ list_tree (helper) + + async def alist_tree_listing( + self, + path: str = DOCUMENTS_ROOT, + *, + max_depth: int | None = 8, + page_size: int = 500, + include_files: bool = True, + include_dirs: bool = True, + ) -> dict[str, Any]: + """Recursive tree listing for cloud 
mode. + + Mirrors the shape returned by :class:`MultiRootLocalFolderBackend.list_tree`: + ``{"entries": [{path, is_dir, size, modified_at, depth}, ...], "truncated": bool}``. + """ + normalized = self._normalize_listing_path(path or DOCUMENTS_ROOT) + if not normalized.startswith(DOCUMENTS_ROOT) and normalized != "/": + return {"error": "Error: path must be under /documents/"} + + entries: list[dict[str, Any]] = [] + truncated = False + + try: + async with shielded_async_session() as session: + index = await build_path_index(session, self.search_space_id) + doc_rows_raw = await session.execute( + select( + Document.id, + Document.title, + Document.folder_id, + Document.updated_at, + ).where(Document.search_space_id == self.search_space_id) + ) + doc_rows = list(doc_rows_raw.all()) + except Exception as exc: # pragma: no cover + logger.warning("KBPostgresBackend.alist_tree_listing DB error: %s", exc) + return {"entries": [], "truncated": False} + + files = self._state_files() + moved_removed, _ = self._moved_view_paths(files) + anon = self._kb_anon_doc() + anon_path = str(anon.get("path") or "") if anon else "" + + def _depth_of(p: str) -> int: + if p == DOCUMENTS_ROOT: + return 0 + rel_root = ( + p[len(DOCUMENTS_ROOT) :].lstrip("/") + if normalized.startswith(DOCUMENTS_ROOT) + else p.lstrip("/") + ) + return len([part for part in rel_root.split("/") if part]) + + def _add_entry(entry: dict[str, Any]) -> bool: + nonlocal truncated + if len(entries) >= page_size: + truncated = True + return False + entries.append(entry) + return True + + if include_dirs: + for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]): + if not _is_under(fpath, normalized): + continue + depth = _depth_of(fpath) + if max_depth is not None and depth > max_depth: + continue + if not _add_entry( + { + "path": fpath, + "is_dir": True, + "size": 0, + "modified_at": "", + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + for staged in self._staged_dirs(): + if not _is_under(staged, normalized): + continue + depth = _depth_of(staged) + if max_depth is not None and depth > max_depth: + continue + if any(e["path"] == staged for e in entries): + continue + if not _add_entry( + { + "path": staged, + "is_dir": True, + "size": 0, + "modified_at": "", + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + + if include_files: + for row in sorted(doc_rows, key=lambda r: str(r.title or "")): + candidate = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + if candidate in moved_removed: + continue + if not _is_under(candidate, normalized): + continue + depth = _depth_of(candidate) + if max_depth is not None and depth > max_depth: + continue + modified = "" + if row.updated_at is not None: + with contextlib.suppress(Exception): + modified = row.updated_at.astimezone(UTC).isoformat() + if not _add_entry( + { + "path": candidate, + "is_dir": False, + "size": 0, + "modified_at": modified, + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + + if anon_path and _is_under(anon_path, normalized): + depth = _depth_of(anon_path) + if (max_depth is None or depth <= max_depth) and not _add_entry( + { + "path": anon_path, + "is_dir": False, + "size": len(str(anon.get("content") or "")), + "modified_at": "", + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + + for path_key, fd in files.items(): + if not isinstance(path_key, str): + continue + if not 
_is_under(path_key, normalized): + continue + if any(e["path"] == path_key for e in entries): + continue + if not ( + path_key.startswith(DOCUMENTS_ROOT) + or _basename(path_key).startswith(_TEMP_PREFIX) + ): + continue + depth = _depth_of(path_key) + if max_depth is not None and depth > max_depth: + continue + size = self._file_data_size(fd) if isinstance(fd, dict) else 0 + if not _add_entry( + { + "path": path_key, + "is_dir": False, + "size": int(size), + "modified_at": fd.get("modified_at", "") + if isinstance(fd, dict) + else "", + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + + return {"entries": entries, "truncated": truncated} + + # ------------------------------------------------------------------ uploads (unsupported) + + def upload_files( # type: ignore[override] + self, files: list[tuple[str, bytes]] + ) -> list[FileUploadResponse]: + msg = "KBPostgresBackend does not support upload_files." + raise NotImplementedError(msg) + + def download_files( # type: ignore[override] + self, paths: list[str] + ) -> list[FileDownloadResponse]: + responses: list[FileDownloadResponse] = [] + files = self._state_files() + for path in paths: + fd = files.get(path) + if fd is None: + responses.append( + FileDownloadResponse( + path=path, content=None, error="file_not_found" + ) + ) + continue + content_str = file_data_to_string(fd) + responses.append( + FileDownloadResponse( + path=path, + content=content_str.encode("utf-8"), + error=None, + ) + ) + return responses + + +# --- module-level small helpers --------------------------------------------- + + +async def list_tree_listing( + backend: KBPostgresBackend, + path: str, + *, + max_depth: int | None = 8, + page_size: int = 500, + include_files: bool = True, + include_dirs: bool = True, +) -> dict[str, Any]: + """Async helper used by the overridden ``list_tree`` tool wrapper.""" + return await backend.alist_tree_listing( + path, + max_depth=max_depth, + page_size=page_size, + include_files=include_files, + include_dirs=include_dirs, + ) + + +__all__ = [ + "KBPostgresBackend", + "list_tree_listing", + "paginate_listing", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index 6df317aaa..edd8c7af1 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -1,10 +1,24 @@ -"""Knowledge-base pre-search middleware for the SurfSense new chat agent. +"""Hybrid-search priority middleware for the SurfSense new chat agent. -This middleware runs before the main agent loop and seeds a virtual filesystem -(`files` state) with relevant documents retrieved via hybrid search. On each -turn the filesystem is *expanded* — new results merge with documents loaded -during prior turns — and a synthetic ``ls`` result is injected into the message -history so the LLM is immediately aware of the current filesystem structure. +This middleware runs ``before_agent`` on every turn and writes: + +* ``state["kb_priority"]`` — the top-K most relevant documents for the + current user message, used to render a ```` system + message immediately before the user turn. +* ``state["kb_matched_chunk_ids"]`` — internal hand-off mapping + (``Document.id`` → matched chunk IDs) consumed by + :class:`KBPostgresBackend._load_file_data` when the agent first reads each + document, so the XML wrapper can flag matched sections in + ````. 
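+
+Each ``kb_priority`` entry is a plain dict consumed by the renderer; an
+illustrative (not normative) example::
+
+    {
+        "path": "/documents/notes.md",
+        "score": 0.87,
+        "document_id": 42,
+        "title": "notes",
+        "mentioned": False,
+    }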
+ +The previous "scoped filesystem" behaviour (synthetic ``ls`` + state +``files`` seeding) is intentionally removed: documents are now lazy-loaded +from Postgres on demand, with the full workspace tree rendered separately +by :class:`KnowledgeTreeMiddleware`. + +In anonymous mode the middleware skips hybrid search entirely and emits a +single-entry priority list pointing at the Redis-loaded document +(``state["kb_anon_doc"]``). """ from __future__ import annotations @@ -13,27 +27,30 @@ import asyncio import json import logging import re -import uuid from collections.abc import Sequence from datetime import UTC, datetime from typing import Any from langchain.agents.middleware import AgentMiddleware, AgentState from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langgraph.runtime import Runtime from litellm import token_counter from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.new_chat.path_resolver import ( + PathIndex, + build_path_index, + doc_to_virtual_path, +) from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range from app.db import ( NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, - Folder, shielded_async_session, ) from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever @@ -70,7 +87,6 @@ class KBSearchPlan(BaseModel): def _extract_text_from_message(message: BaseMessage) -> str: - """Extract plain text from a message content.""" content = getattr(message, "content", "") if isinstance(content, str): return content @@ -85,19 +101,6 @@ def _extract_text_from_message(message: BaseMessage) -> str: return str(content) -def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str: - """Convert arbitrary text into a filesystem-safe filename.""" - name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip() - name = re.sub(r"\s+", " ", name) - if not name: - name = fallback - if len(name) > 180: - name = name[:180].rstrip() - if not name.lower().endswith(".xml"): - name = f"{name}.xml" - return name - - def _render_recent_conversation( messages: Sequence[BaseMessage], *, @@ -107,10 +110,9 @@ def _render_recent_conversation( ) -> str: """Render recent dialogue for internal planning under a token budget. - Prefers the latest messages and uses the project's existing model-aware - token budgeting hooks when available on the LLM (`_count_tokens`, - `_get_max_input_tokens`). Falls back to the prior fixed-message heuristic - if token counting is unavailable. + Filters to ``HumanMessage`` and ``AIMessage`` (without tool_calls) so that + injected ``SystemMessage`` artefacts (priority list, workspace tree, + file-write contract) don't pollute the planner prompt. """ rendered: list[tuple[str, str]] = [] for message in messages: @@ -133,8 +135,6 @@ def _render_recent_conversation( if not rendered: return "" - # Exclude the latest user message from "recent conversation" because it is - # already passed separately as "Latest user message" in the planner prompt. 
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip(): rendered = rendered[:-1] @@ -216,8 +216,6 @@ def _render_recent_conversation( selected_lines = candidate_lines continue - # If the full message does not fit, keep as much of this most-recent - # older message as possible via binary search. lo, hi = 1, len(text) best_line: str | None = None while lo <= hi: @@ -249,7 +247,6 @@ def _build_kb_planner_prompt( recent_conversation: str, user_text: str, ) -> str: - """Build a compact internal prompt for KB query rewriting and date scoping.""" today = datetime.now(UTC).date().isoformat() return ( "You optimize internal knowledge-base search inputs for document retrieval.\n" @@ -275,12 +272,10 @@ def _build_kb_planner_prompt( def _extract_json_payload(text: str) -> str: - """Extract a JSON object from a raw LLM response.""" stripped = text.strip() fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL) if fenced: return fenced.group(1) - start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: @@ -289,7 +284,6 @@ def _extract_json_payload(text: str) -> str: def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan: - """Parse and validate the planner's JSON response.""" payload = json.loads(_extract_json_payload(response_text)) return KBSearchPlan.model_validate(payload) @@ -298,212 +292,19 @@ def _normalize_optional_date_range( start_date: str | None, end_date: str | None, ) -> tuple[datetime | None, datetime | None]: - """Normalize optional planner dates into a UTC datetime range.""" parsed_start = parse_date_or_datetime(start_date) if start_date else None parsed_end = parse_date_or_datetime(end_date) if end_date else None if parsed_start is None and parsed_end is None: return None, None - resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end) - return resolved_start, resolved_end - - -def _build_document_xml( - document: dict[str, Any], - matched_chunk_ids: set[int] | None = None, -) -> str: - """Build citation-friendly XML with a ```` for smart seeking. - - The ```` at the top of each document lists every chunk with its - line range inside ```` and flags chunks that directly - matched the search query (``matched="true"``). This lets the LLM jump - straight to the most relevant section via ``read_file(offset=…, limit=…)`` - instead of reading sequentially from the start. - """ - matched = matched_chunk_ids or set() - - doc_meta = document.get("document") or {} - metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {} - document_id = doc_meta.get("id", document.get("document_id", "unknown")) - document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN")) - title = doc_meta.get("title") or metadata.get("title") or "Untitled Document" - url = ( - metadata.get("url") or metadata.get("source") or metadata.get("page_url") or "" - ) - metadata_json = json.dumps(metadata, ensure_ascii=False) - - # --- 1. Metadata header (fixed structure) --- - metadata_lines: list[str] = [ - "", - "", - f" {document_id}", - f" {document_type}", - f" <![CDATA[{title}]]>", - f" ", - f" ", - "", - "", - ] - - # --- 2. 
Pre-build chunk XML strings to compute line counts --- - chunks = document.get("chunks") or [] - chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string) - if isinstance(chunks, list): - for chunk in chunks: - if not isinstance(chunk, dict): - continue - chunk_id = chunk.get("chunk_id") or chunk.get("id") - chunk_content = str(chunk.get("content", "")).strip() - if not chunk_content: - continue - if chunk_id is None: - xml = f" " - else: - xml = f" " - chunk_entries.append((chunk_id, xml)) - - # --- 3. Compute line numbers for every chunk --- - # Layout (1-indexed lines for read_file): - # metadata_lines -> len(metadata_lines) lines - # -> 1 line - # index entries -> len(chunk_entries) lines - # -> 1 line - # (empty line) -> 1 line - # -> 1 line - # chunk xml lines… - # -> 1 line - # -> 1 line - index_overhead = ( - 1 + len(chunk_entries) + 1 + 1 + 1 - ) # tags + empty + - first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed - - current_line = first_chunk_line - index_entry_lines: list[str] = [] - for cid, xml_str in chunk_entries: - num_lines = xml_str.count("\n") + 1 - end_line = current_line + num_lines - 1 - matched_attr = ' matched="true"' if cid is not None and cid in matched else "" - if cid is not None: - index_entry_lines.append( - f' ' - ) - else: - index_entry_lines.append( - f' ' - ) - current_line = end_line + 1 - - # --- 4. Assemble final XML --- - lines = metadata_lines.copy() - lines.append("") - lines.extend(index_entry_lines) - lines.append("") - lines.append("") - lines.append("") - for _, xml_str in chunk_entries: - lines.append(xml_str) - lines.extend(["", ""]) - return "\n".join(lines) - - -async def _get_folder_paths( - session: AsyncSession, search_space_id: int -) -> dict[int, str]: - """Return a map of folder_id -> virtual folder path under /documents.""" - result = await session.execute( - select(Folder.id, Folder.name, Folder.parent_id).where( - Folder.search_space_id == search_space_id - ) - ) - rows = result.all() - by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows} - - cache: dict[int, str] = {} - - def resolve_path(folder_id: int) -> str: - if folder_id in cache: - return cache[folder_id] - parts: list[str] = [] - cursor: int | None = folder_id - visited: set[int] = set() - while cursor is not None and cursor in by_id and cursor not in visited: - visited.add(cursor) - entry = by_id[cursor] - parts.append( - _safe_filename(str(entry["name"]), fallback="folder").removesuffix( - ".xml" - ) - ) - cursor = entry["parent_id"] - parts.reverse() - path = "/documents/" + "/".join(parts) if parts else "/documents" - cache[folder_id] = path - return path - - for folder_id in by_id: - resolve_path(folder_id) - return cache - - -def _build_synthetic_ls( - existing_files: dict[str, Any] | None, - new_files: dict[str, Any], - *, - mentioned_paths: set[str] | None = None, -) -> tuple[AIMessage, ToolMessage]: - """Build a synthetic ls("/documents") tool-call + result for the LLM context. - - Mentioned files are listed first. A separate header tells the LLM which - files the user explicitly selected; the path list itself stays clean so - paths can be passed directly to ``read_file`` without stripping tags. 
- """ - _mentioned = mentioned_paths or set() - merged: dict[str, Any] = {**(existing_files or {}), **new_files} - doc_paths = [ - p for p, v in merged.items() if p.startswith("/documents/") and v is not None - ] - - new_set = set(new_files) - mentioned_list = [p for p in doc_paths if p in _mentioned] - new_non_mentioned = [p for p in doc_paths if p in new_set and p not in _mentioned] - old_paths = [p for p in doc_paths if p not in new_set] - ordered = mentioned_list + new_non_mentioned + old_paths - - parts: list[str] = [] - if mentioned_list: - parts.append( - "USER-MENTIONED documents (read these thoroughly before answering):" - ) - for p in mentioned_list: - parts.append(f" {p}") - parts.append("") - parts.append(str(ordered) if ordered else "No documents found.") - - tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}" - ai_msg = AIMessage( - content="", - tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}], - ) - tool_msg = ToolMessage( - content="\n".join(parts), - tool_call_id=tool_call_id, - ) - return ai_msg, tool_msg + return resolve_date_range(parsed_start, parsed_end) def _resolve_search_types( available_connectors: list[str] | None, available_document_types: list[str] | None, ) -> list[str] | None: - """Build a flat list of document-type strings for the chunk retriever. - - Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that - old documents indexed under Composio names are still found. - - Returns ``None`` when no filtering is desired (search all types). - """ types: set[str] = set() if available_document_types: types.update(available_document_types) @@ -531,13 +332,8 @@ async def browse_recent_documents( start_date: datetime | None = None, end_date: datetime | None = None, ) -> list[dict[str, Any]]: - """Return documents ordered by recency (newest first), no relevance ranking. - - Used when the user's intent is temporal ("latest file", "most recent upload") - and hybrid search would produce poor results because the query has no - meaningful topical signal. - """ - from sqlalchemy import func, select + """Return documents ordered by recency (newest first), no relevance ranking.""" + from sqlalchemy import func from app.db import DocumentType @@ -581,7 +377,6 @@ async def browse_recent_documents( return [] doc_ids = [d.id for d in documents] - numbered = ( select( Chunk.id.label("chunk_id"), @@ -632,6 +427,7 @@ async def browse_recent_documents( else None ), "metadata": metadata, + "folder_id": getattr(doc, "folder_id", None), }, "source": ( doc.document_type.value @@ -640,12 +436,6 @@ async def browse_recent_documents( ), } ) - - logger.info( - "browse_recent_documents: %d docs returned for space=%d", - len(results), - search_space_id, - ) return results @@ -659,17 +449,11 @@ async def search_knowledge_base( start_date: datetime | None = None, end_date: datetime | None = None, ) -> list[dict[str, Any]]: - """Run a single unified hybrid search against the knowledge base. - - Uses one ``ChucksHybridSearchRetriever`` call across all document types - instead of fanning out per-connector. This reduces the number of DB - queries from ~10 to 2 (one RRF query + one chunk fetch). 
- """ + """Run a single unified hybrid search against the knowledge base.""" if not query: return [] [embedding] = embed_texts([query]) - doc_types = _resolve_search_types(available_connectors, available_document_types) retriever_top_k = min(top_k * 3, 30) @@ -693,14 +477,7 @@ async def fetch_mentioned_documents( document_ids: list[int], search_space_id: int, ) -> list[dict[str, Any]]: - """Fetch explicitly mentioned documents with *all* their chunks. - - Returns the same dict structure as ``search_knowledge_base`` so results - can be merged directly into ``build_scoped_filesystem``. Unlike search - results, every chunk is included (no top-K limiting) and none are marked - as ``matched`` since the entire document is relevant by virtue of the - user's explicit mention. - """ + """Fetch explicitly mentioned documents.""" if not document_ids: return [] @@ -750,6 +527,7 @@ async def fetch_mentioned_documents( else None ), "metadata": metadata, + "folder_id": getattr(doc, "folder_id", None), }, "source": ( doc.document_type.value @@ -762,96 +540,36 @@ async def fetch_mentioned_documents( return results -async def build_scoped_filesystem( - *, - documents: Sequence[dict[str, Any]], - search_space_id: int, -) -> tuple[dict[str, dict[str, str]], dict[int, str]]: - """Build a StateBackend-compatible files dict from search results. - - Returns ``(files, doc_id_to_path)`` so callers can reliably map a - document id back to its filesystem path without guessing by title. - Paths are collision-proof: when two documents resolve to the same - path the doc-id is appended to disambiguate. - """ - async with shielded_async_session() as session: - folder_paths = await _get_folder_paths(session, search_space_id) - doc_ids = [ - (doc.get("document") or {}).get("id") - for doc in documents - if isinstance(doc, dict) - ] - doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)] - folder_by_doc_id: dict[int, int | None] = {} - if doc_ids: - doc_rows = await session.execute( - select(Document.id, Document.folder_id).where( - Document.search_space_id == search_space_id, - Document.id.in_(doc_ids), - ) - ) - folder_by_doc_id = { - row.id: row.folder_id for row in doc_rows.all() if row.id is not None - } - - files: dict[str, dict[str, str]] = {} - doc_id_to_path: dict[int, str] = {} - for document in documents: - doc_meta = document.get("document") or {} - title = str(doc_meta.get("title") or "untitled") - doc_id = doc_meta.get("id") - folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None - base_folder = folder_paths.get(folder_id, "/documents") - file_name = _safe_filename(title) - path = f"{base_folder}/{file_name}" - if path in files: - stem = file_name.removesuffix(".xml") - path = f"{base_folder}/{stem} ({doc_id}).xml" - matched_ids = set(document.get("matched_chunk_ids") or []) - xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids) - files[path] = { - "content": xml_content.split("\n"), - "encoding": "utf-8", - "created_at": "", - "modified_at": "", - } - if isinstance(doc_id, int): - doc_id_to_path[doc_id] = path - return files, doc_id_to_path +def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage: + """Render the priority list as a single ```` system message.""" + if not priority: + body = "(no priority documents for this turn)" + else: + lines: list[str] = [] + for entry in priority: + score = entry.get("score") + mentioned = entry.get("mentioned") + score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a" + mark 
= " [USER-MENTIONED]" if mentioned else "" + lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}") + body = "\n".join(lines) + return SystemMessage( + content=( + "\n" + "These documents are most relevant to the latest user message; " + "read them first. Matched sections are flagged inside each " + "document's .\n" + f"{body}\n" + "" + ) + ) -def _build_anon_scoped_filesystem( - documents: Sequence[dict[str, Any]], -) -> dict[str, dict[str, str]]: - """Build a scoped filesystem for anonymous documents without DB queries. - - Anonymous uploads have no folders, so all files go under /documents. - """ - files: dict[str, dict[str, str]] = {} - for document in documents: - doc_meta = document.get("document") or {} - title = str(doc_meta.get("title") or "untitled") - file_name = _safe_filename(title) - path = f"/documents/{file_name}" - if path in files: - doc_id = doc_meta.get("id", "dup") - stem = file_name.removesuffix(".xml") - path = f"/documents/{stem} ({doc_id}).xml" - matched_ids = set(document.get("matched_chunk_ids") or []) - xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids) - files[path] = { - "content": xml_content.split("\n"), - "encoding": "utf-8", - "created_at": "", - "modified_at": "", - } - return files - - -class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] - """Pre-agent middleware that always searches the KB and seeds a scoped filesystem.""" +class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Compute hybrid-search priority hints for the current turn.""" tools = () + state_schema = SurfSenseFilesystemState def __init__( self, @@ -863,7 +581,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] available_document_types: list[str] | None = None, top_k: int = 10, mentioned_document_ids: list[int] | None = None, - anon_session_id: str | None = None, ) -> None: self.llm = llm self.search_space_id = search_space_id @@ -872,7 +589,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] self.available_document_types = available_document_types self.top_k = top_k self.mentioned_document_ids = mentioned_document_ids or [] - self.anon_session_id = anon_session_id async def _plan_search_inputs( self, @@ -880,10 +596,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] messages: Sequence[BaseMessage], user_text: str, ) -> tuple[str, datetime | None, datetime | None, bool]: - """Rewrite the KB query and infer optional date filters with the LLM. - - Returns (optimized_query, start_date, end_date, is_recency_query). 
- """ if self.llm is None: return user_text, None, None, False @@ -914,7 +626,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] ) is_recency = plan.is_recency_query _perf_log.info( - "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r " + "[kb_priority] planner in %.3fs query=%r optimized=%r " "start=%s end=%s recency=%s", loop.time() - t0, user_text[:80], @@ -946,106 +658,68 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] pass return asyncio.run(self.abefore_agent(state, runtime)) - async def _load_anon_document(self) -> dict[str, Any] | None: - """Load the anonymous user's uploaded document from Redis.""" - if not self.anon_session_id: - return None - try: - import redis.asyncio as aioredis - - from app.config import config - - redis_client = aioredis.from_url( - config.REDIS_APP_URL, decode_responses=True - ) - try: - redis_key = f"anon:doc:{self.anon_session_id}" - data = await redis_client.get(redis_key) - if not data: - return None - doc = json.loads(data) - return { - "document_id": -1, - "content": doc.get("content", ""), - "score": 1.0, - "chunks": [ - { - "chunk_id": -1, - "content": doc.get("content", ""), - } - ], - "matched_chunk_ids": [-1], - "document": { - "id": -1, - "title": doc.get("filename", "uploaded_document"), - "document_type": "FILE", - "metadata": {"source": "anonymous_upload"}, - }, - "source": "FILE", - "_user_mentioned": True, - } - finally: - await redis_client.aclose() - except Exception as exc: - logger.warning("Failed to load anonymous document from Redis: %s", exc) - return None - async def abefore_agent( # type: ignore[override] self, state: AgentState, runtime: Runtime[Any], ) -> dict[str, Any] | None: del runtime + if self.filesystem_mode != FilesystemMode.CLOUD: + return None + messages = state.get("messages") or [] if not messages: return None - if self.filesystem_mode != FilesystemMode.CLOUD: - # Local-folder mode should not seed cloud KB documents into filesystem. 
- return None - last_human = None + last_human: HumanMessage | None = None for msg in reversed(messages): if isinstance(msg, HumanMessage): last_human = msg break if last_human is None: return None - user_text = _extract_text_from_message(last_human).strip() if not user_text: return None - t0 = _perf_log and asyncio.get_event_loop().time() - existing_files = state.get("files") + anon_doc = state.get("kb_anon_doc") + if anon_doc: + return self._anon_priority(state, anon_doc) - # --- Anonymous session: load Redis doc and skip DB queries --- - if self.anon_session_id: - merged: list[dict[str, Any]] = [] - anon_doc = await self._load_anon_document() - if anon_doc: - merged.append(anon_doc) + return await self._authenticated_priority(state, messages, user_text) - if merged: - new_files = _build_anon_scoped_filesystem(merged) - mentioned_paths = set(new_files.keys()) - else: - new_files = {} - mentioned_paths = set() + def _anon_priority( + self, + state: AgentState, + anon_doc: dict[str, Any], + ) -> dict[str, Any]: + path = str(anon_doc.get("path") or "") + title = str(anon_doc.get("title") or "uploaded_document") + priority = [ + { + "path": path, + "score": 1.0, + "document_id": None, + "title": title, + "mentioned": True, + } + ] + new_messages = list(state.get("messages") or []) + insert_at = max(len(new_messages) - 1, 0) + new_messages.insert(insert_at, _render_priority_message(priority)) + return { + "kb_priority": priority, + "kb_matched_chunk_ids": {}, + "messages": new_messages, + } - ai_msg, tool_msg = _build_synthetic_ls( - existing_files, - new_files, - mentioned_paths=mentioned_paths, - ) - if t0 is not None: - _perf_log.info( - "[kb_fs_middleware] anon completed in %.3fs new_files=%d", - asyncio.get_event_loop().time() - t0, - len(new_files), - ) - return {"files": new_files, "messages": [ai_msg, tool_msg]} - - # --- Authenticated session: full KB search --- + async def _authenticated_priority( + self, + state: AgentState, + messages: Sequence[BaseMessage], + user_text: str, + ) -> dict[str, Any]: + t0 = asyncio.get_event_loop().time() ( planned_query, start_date, @@ -1056,7 +730,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] user_text=user_text, ) - # --- 1. Fetch mentioned documents (user-selected, all chunks) --- mentioned_results: list[dict[str, Any]] = [] if self.mentioned_document_ids: mentioned_results = await fetch_mentioned_documents( @@ -1065,7 +738,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] ) self.mentioned_document_ids = [] - # --- 2. Run KB search (recency browse or hybrid) --- if is_recency: doc_types = _resolve_search_types( self.available_connectors, self.available_document_types @@ -1088,48 +760,108 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] end_date=end_date, ) - # --- 3. Merge: mentioned first, then search (dedup by doc id) --- seen_doc_ids: set[int] = set() - merged_auth: list[dict[str, Any]] = [] + merged: list[dict[str, Any]] = [] for doc in mentioned_results: doc_id = (doc.get("document") or {}).get("id") - if doc_id is not None: + if isinstance(doc_id, int): seen_doc_ids.add(doc_id) - merged_auth.append(doc) + merged.append(doc) for doc in search_results: doc_id = (doc.get("document") or {}).get("id") - if doc_id is not None and doc_id in seen_doc_ids: + if isinstance(doc_id, int) and doc_id in seen_doc_ids: continue - merged_auth.append(doc) + merged.append(doc) - # --- 4. 
Build scoped filesystem --- - new_files, doc_id_to_path = await build_scoped_filesystem( - documents=merged_auth, - search_space_id=self.search_space_id, + priority, matched_chunk_ids = await self._materialize_priority(merged) + + new_messages = list(messages) + insert_at = max(len(new_messages) - 1, 0) + new_messages.insert(insert_at, _render_priority_message(priority)) + + _perf_log.info( + "[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d", + asyncio.get_event_loop().time() - t0, + user_text[:80], + len(priority), + len(mentioned_results), ) - mentioned_doc_ids = { - (d.get("document") or {}).get("id") for d in mentioned_results - } - mentioned_paths = { - doc_id_to_path[did] for did in mentioned_doc_ids if did in doc_id_to_path + return { + "kb_priority": priority, + "kb_matched_chunk_ids": matched_chunk_ids, + "messages": new_messages, } - ai_msg, tool_msg = _build_synthetic_ls( - existing_files, - new_files, - mentioned_paths=mentioned_paths, - ) + async def _materialize_priority( + self, merged: list[dict[str, Any]] + ) -> tuple[list[dict[str, Any]], dict[int, list[int]]]: + """Resolve canonical paths and matched chunk ids for the priority list.""" + priority: list[dict[str, Any]] = [] + matched_chunk_ids: dict[int, list[int]] = {} - if t0 is not None: - _perf_log.info( - "[kb_fs_middleware] completed in %.3fs query=%r optimized=%r " - "mentioned=%d new_files=%d total=%d", - asyncio.get_event_loop().time() - t0, - user_text[:80], - planned_query[:120], - len(mentioned_results), - len(new_files), - len(new_files) + len(existing_files or {}), + if not merged: + return priority, matched_chunk_ids + + async with shielded_async_session() as session: + index: PathIndex = await build_path_index(session, self.search_space_id) + doc_ids = [ + (doc.get("document") or {}).get("id") + for doc in merged + if isinstance(doc, dict) + ] + doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)] + folder_by_doc_id: dict[int, int | None] = {} + if doc_ids: + folder_rows = await session.execute( + select(Document.id, Document.folder_id).where( + Document.search_space_id == self.search_space_id, + Document.id.in_(doc_ids), + ) + ) + folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()} + + for doc in merged: + doc_meta = doc.get("document") or {} + doc_id = doc_meta.get("id") + title = doc_meta.get("title") or "untitled" + folder_id = ( + folder_by_doc_id.get(doc_id) + if isinstance(doc_id, int) + else doc_meta.get("folder_id") ) - return {"files": new_files, "messages": [ai_msg, tool_msg]} + path = doc_to_virtual_path( + doc_id=doc_id if isinstance(doc_id, int) else None, + title=str(title), + folder_id=folder_id if isinstance(folder_id, int) else None, + index=index, + ) + priority.append( + { + "path": path, + "score": float(doc.get("score") or 0.0), + "document_id": doc_id if isinstance(doc_id, int) else None, + "title": str(title), + "mentioned": bool(doc.get("_user_mentioned")), + } + ) + if isinstance(doc_id, int): + chunk_ids = doc.get("matched_chunk_ids") or [] + if chunk_ids: + matched_chunk_ids[doc_id] = [ + int(cid) for cid in chunk_ids if isinstance(cid, (int, str)) + ] + return priority, matched_chunk_ids + + +# Backwards-compatible alias for any external imports. 
+KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware
+
+
+__all__ = [
+    "KnowledgeBaseSearchMiddleware",
+    "KnowledgePriorityMiddleware",
+    "browse_recent_documents",
+    "fetch_mentioned_documents",
+    "search_knowledge_base",
+]
diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py
new file mode 100644
index 000000000..467d19747
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py
@@ -0,0 +1,272 @@
+"""Workspace-tree middleware for the SurfSense agent.
+
+Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
+turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
+injects the result as a ``<workspace_tree>`` system message immediately
+before the latest human turn.
+
+The render is bounded by two truncation layers:
+
+1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
+   replaced with a "use ls" hint.
+2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
+   token-count profile when available). If the entry-truncated tree still
+   exceeds the token cap we fall back to a root-only summary.
+
+Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
+
+This middleware also performs a one-time initialization of ``state['cwd']``
+to ``"/documents"`` so subsequent middlewares and tools always see a valid
+cwd in cloud mode.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any
+
+from langchain.agents.middleware import AgentMiddleware, AgentState
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import SystemMessage
+from langgraph.runtime import Runtime
+from sqlalchemy import select
+
+from app.agents.new_chat.filesystem_selection import FilesystemMode
+from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
+from app.agents.new_chat.path_resolver import (
+    DOCUMENTS_ROOT,
+    PathIndex,
+    build_path_index,
+    doc_to_virtual_path,
+)
+from app.db import Document, shielded_async_session
+
+try:
+    from litellm import token_counter
+except Exception:  # pragma: no cover - optional dep
+    token_counter = None  # type: ignore[assignment]
+
+logger = logging.getLogger(__name__)
+
+
+MAX_TREE_ENTRIES = 500
+MAX_TREE_TOKENS = 4000
+
+
+def _approx_tokens(text: str) -> int:
+    """Cheap fallback token estimate (1 token ~= 4 chars)."""
+    return max(1, (len(text) + 3) // 4)
+
+
+def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
+    if llm is None:
+        return _approx_tokens(text)
+    count_fn = getattr(llm, "_count_tokens", None)
+    if callable(count_fn):
+        try:
+            return int(count_fn([{"role": "user", "content": text}]))
+        except Exception:
+            pass
+    profile = getattr(llm, "profile", None)
+    model_names: list[str] = []
+    if isinstance(profile, dict):
+        tcms = profile.get("token_count_models")
+        if isinstance(tcms, list):
+            model_names.extend(name for name in tcms if isinstance(name, str) and name)
+        tcm = profile.get("token_count_model")
+        if isinstance(tcm, str) and tcm and tcm not in model_names:
+            model_names.append(tcm)
+    model_name = model_names[0] if model_names else getattr(llm, "model", None)
+    if not isinstance(model_name, str) or not model_name or token_counter is None:
+        return _approx_tokens(text)
+    try:
+        return int(
+            token_counter(
+                messages=[{"role": "user", "content": text}],
+                model=model_name,
+            )
+        )
+    except Exception:
+        return _approx_tokens(text)
+
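
# For orientation, the fallback estimator's behaviour at a few sizes (values
# follow directly from the definition above; the ~4-chars-per-token ratio is
# only a heuristic, not a real tokenizer):
#
#   _approx_tokens("")            -> 1      # floor of one token
#   _approx_tokens("abcd")        -> 1
#   _approx_tokens("abcde")       -> 2      # rounds up
#   _approx_tokens("x" * 16000)   -> 4000   # exactly MAX_TREE_TOKENS
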
+class KnowledgeTreeMiddleware(AgentMiddleware):  # type: ignore[type-arg]
+    """Inject the workspace folder/document tree into the agent's context."""
+
+    tools = ()
+    state_schema = SurfSenseFilesystemState
+
+    def __init__(
+        self,
+        *,
+        search_space_id: int,
+        filesystem_mode: FilesystemMode,
+        llm: BaseChatModel | None = None,
+        max_entries: int = MAX_TREE_ENTRIES,
+        max_tokens: int = MAX_TREE_TOKENS,
+    ) -> None:
+        self.search_space_id = search_space_id
+        self.filesystem_mode = filesystem_mode
+        self.llm = llm
+        self.max_entries = max_entries
+        self.max_tokens = max_tokens
+        self._cache: dict[tuple[int, int, bool], str] = {}
+
+    async def abefore_agent(  # type: ignore[override]
+        self,
+        state: AgentState,
+        runtime: Runtime[Any],
+    ) -> dict[str, Any] | None:
+        del runtime
+        if self.filesystem_mode != FilesystemMode.CLOUD:
+            return None
+
+        update: dict[str, Any] = {}
+        if not state.get("cwd"):
+            update["cwd"] = DOCUMENTS_ROOT
+
+        anon_doc = state.get("kb_anon_doc")
+        if anon_doc:
+            tree_msg = self._render_anon_tree(anon_doc)
+        else:
+            tree_msg = await self._render_kb_tree(state)
+
+        messages = list(state.get("messages") or [])
+        insert_at = max(len(messages) - 1, 0)
+        messages.insert(insert_at, SystemMessage(content=tree_msg))
+        update["messages"] = messages
+        return update
+
+    def before_agent(  # type: ignore[override]
+        self,
+        state: AgentState,
+        runtime: Runtime[Any],
+    ) -> dict[str, Any] | None:
+        try:
+            loop = asyncio.get_running_loop()
+            if loop.is_running():
+                return None
+        except RuntimeError:
+            pass
+        return asyncio.run(self.abefore_agent(state, runtime))
+
+    # ------------------------------------------------------------------ render
+
+    def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
+        path = str(anon_doc.get("path") or "")
+        title = str(anon_doc.get("title") or "uploaded_document")
+        return (
+            "<workspace_tree>\n"
+            "Anonymous session — only one read-only document is available.\n"
+            f"{DOCUMENTS_ROOT}/\n"
+            f"  {path} — {title}\n"
+            "</workspace_tree>"
+        )
+
+    async def _render_kb_tree(self, state: AgentState) -> str:
+        version = int(state.get("tree_version") or 0)
+        cache_key = (self.search_space_id, version, False)
+        cached = self._cache.get(cache_key)
+        if cached is not None:
+            return cached
+
+        try:
+            async with shielded_async_session() as session:
+                index = await build_path_index(session, self.search_space_id)
+                doc_rows = await session.execute(
+                    select(Document.id, Document.title, Document.folder_id).where(
+                        Document.search_space_id == self.search_space_id
+                    )
+                )
+                docs = list(doc_rows.all())
+        except Exception as exc:  # pragma: no cover - defensive
+            logger.warning("knowledge_tree: DB error %s", exc)
+            return "<workspace_tree>\n(unavailable)\n</workspace_tree>"
+
+        rendered = self._format_tree(index, docs)
+        self._cache[cache_key] = rendered
+        return rendered
+
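
    # For illustration, a small hypothetical workspace (one folder, two
    # documents) renders as the following injected message:
    #
    #   <workspace_tree>
    #   /documents/
    #     Meeting Minutes.xml
    #     research/
    #       Paper Notes.xml
    #   </workspace_tree>
    #
    # _format_tree below produces this listing and applies both truncation
    # layers before the result is cached.
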
+    def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
+        folder_paths = sorted(set(index.folder_paths.values()))
+        doc_paths = sorted(
+            doc_to_virtual_path(
+                doc_id=row.id,
+                title=str(row.title or "untitled"),
+                folder_id=row.folder_id,
+                index=index,
+            )
+            for row in docs
+        )
+        all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
+
+        lines: list[str] = []
+        for path in all_paths:
+            depth = (
+                0
+                if path == DOCUMENTS_ROOT
+                else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
+            )
+            indent = "  " * depth
+            is_dir = path == DOCUMENTS_ROOT or path in folder_paths
+            display = (
+                path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
+            )
+            if is_dir:
+                lines.append(f"{indent}{display}/")
+            else:
+                lines.append(f"{indent}{display}")
+            if len(lines) >= self.max_entries:
+                remaining = len(all_paths) - len(lines)
+                if remaining > 0:
+                    lines.append(
+                        f"... {remaining} more entries — use "
+                        "ls('/documents/', offset, limit) to expand"
+                    )
+                break
+
+        body = "\n".join(lines)
+        rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"
+
+        token_count = _count_tokens(rendered, llm=self.llm)
+        if token_count <= self.max_tokens:
+            return rendered
+
+        return self._format_root_summary(folder_paths, doc_paths)
+
+    def _format_root_summary(
+        self, folder_paths: list[str], doc_paths: list[str]
+    ) -> str:
+        top_level: dict[str, int] = {}
+        loose_docs = 0
+        for path in doc_paths:
+            rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
+            if "/" in rel:
+                top = rel.split("/", 1)[0]
+                top_level[top] = top_level.get(top, 0) + 1
+            else:
+                loose_docs += 1
+        for path in folder_paths:
+            rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
+            if not rel:
+                continue
+            top = rel.split("/", 1)[0]
+            top_level.setdefault(top, 0)
+
+        lines = [DOCUMENTS_ROOT + "/"]
+        for name in sorted(top_level):
+            count = top_level[name]
+            lines.append(f"  {name}/ ({count} document{'s' if count != 1 else ''})")
+        if loose_docs:
+            lines.append(
+                f"  ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
+            )
+        lines.append(
+            "Tree is large; use list_tree('/documents/') to drill in "
+            "or ls('/documents/', offset, limit) for paginated listings."
+        )
+        return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
+
+
+__all__ = ["KnowledgeTreeMiddleware"]
diff --git a/surfsense_backend/app/agents/new_chat/path_resolver.py b/surfsense_backend/app/agents/new_chat/path_resolver.py
new file mode 100644
index 000000000..861f48ee7
--- /dev/null
+++ b/surfsense_backend/app/agents/new_chat/path_resolver.py
@@ -0,0 +1,351 @@
+"""Canonical virtual-path resolver for SurfSense knowledge-base documents.
+
+This module is the single source of truth for mapping ``Document`` rows to
+virtual paths under ``/documents/`` and back. It is used by:
+
+* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree)
+* :class:`KnowledgePriorityMiddleware` (computing priority paths)
+* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations)
+* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates)
+
+Centralising the logic ensures that title-collision suffixes, folder paths,
+and ``unique_identifier_hash`` lookups never drift between renders and
+commits.
+""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import Document, DocumentType, Folder +from app.utils.document_converters import generate_unique_identifier_hash + +DOCUMENTS_ROOT = "/documents" +"""Root virtual folder for all KB documents.""" + +_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+") +_WHITESPACE_RUN = re.compile(r"\s+") + + +def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str: + """Convert arbitrary text into a filesystem-safe ``.xml`` filename.""" + name = _INVALID_FILENAME_CHARS.sub("_", value).strip() + name = _WHITESPACE_RUN.sub(" ", name) + if not name: + name = fallback + if len(name) > 180: + name = name[:180].rstrip() + if not name.lower().endswith(".xml"): + name = f"{name}.xml" + return name + + +def safe_folder_segment(value: str, *, fallback: str = "folder") -> str: + """Sanitize a single folder name into a path-safe segment.""" + name = _INVALID_FILENAME_CHARS.sub("_", value).strip() + name = _WHITESPACE_RUN.sub(" ", name) + if not name: + return fallback + if len(name) > 180: + name = name[:180].rstrip() + return name + + +def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str: + if doc_id is None: + return filename + if not filename.lower().endswith(".xml"): + return f"{filename} ({doc_id}).xml" + stem = filename[:-4] + return f"{stem} ({doc_id}).xml" + + +_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE) + + +def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]: + """Strip a trailing ``" ().xml"`` suffix; return ``(stem, doc_id)``. + + If no suffix is present, returns ``(stem_without_xml_extension, None)``. + """ + match = _SUFFIX_PATTERN.search(filename) + if match: + doc_id = int(match.group(1)) + stem = filename[: match.start()] + return stem, doc_id + if filename.lower().endswith(".xml"): + return filename[:-4], None + return filename, None + + +@dataclass +class PathIndex: + """In-memory occupancy snapshot used by :func:`doc_to_virtual_path`. + + Built once per call site so collision handling is deterministic and so + we don't perform N folder lookups per render. 
+ """ + + folder_paths: dict[int, str] = field(default_factory=dict) + """``Folder.id`` -> absolute virtual folder path under ``/documents``.""" + + occupants: dict[str, int] = field(default_factory=dict) + """virtual path -> ``Document.id`` already occupying that path (this render).""" + + +async def _build_folder_paths( + session: AsyncSession, + search_space_id: int, +) -> dict[int, str]: + """Compute ``Folder.id`` -> absolute virtual path under ``/documents``.""" + result = await session.execute( + select(Folder.id, Folder.name, Folder.parent_id).where( + Folder.search_space_id == search_space_id + ) + ) + rows = result.all() + by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows} + cache: dict[int, str] = {} + + def resolve(folder_id: int) -> str: + if folder_id in cache: + return cache[folder_id] + parts: list[str] = [] + cursor: int | None = folder_id + visited: set[int] = set() + while cursor is not None and cursor in by_id and cursor not in visited: + visited.add(cursor) + entry = by_id[cursor] + parts.append(safe_folder_segment(str(entry["name"]))) + cursor = entry["parent_id"] + parts.reverse() + path = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT + cache[folder_id] = path + return path + + for folder_id in by_id: + resolve(folder_id) + return cache + + +async def build_path_index( + session: AsyncSession, + search_space_id: int, + *, + populate_occupants: bool = True, +) -> PathIndex: + """Build a :class:`PathIndex` for a search space. + + ``populate_occupants`` controls whether the occupancy map is pre-seeded + from existing ``Document`` rows. Most callers want this so that + :func:`doc_to_virtual_path` can detect collisions across the whole space; + the persistence middleware sets this to ``False`` when it is iterating to + decide where to place fresh documents. + """ + folder_paths = await _build_folder_paths(session, search_space_id) + occupants: dict[str, int] = {} + if populate_occupants: + rows = await session.execute( + select(Document.id, Document.title, Document.folder_id).where( + Document.search_space_id == search_space_id, + ) + ) + for row in rows.all(): + base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT) + filename = safe_filename(str(row.title or "untitled")) + path = f"{base}/{filename}" + if path in occupants and occupants[path] != row.id: + path = f"{base}/{_suffix_with_doc_id(filename, row.id)}" + occupants[path] = row.id + return PathIndex(folder_paths=folder_paths, occupants=occupants) + + +def doc_to_virtual_path( + *, + doc_id: int | None, + title: str, + folder_id: int | None, + index: PathIndex, +) -> str: + """Return the canonical virtual path for a document. + + Mutates ``index.occupants`` so subsequent calls see this assignment and + deterministically pick a different suffix for the next colliding doc. + """ + base = index.folder_paths.get(folder_id, DOCUMENTS_ROOT) + filename = safe_filename(str(title or "untitled")) + path = f"{base}/{filename}" + occupant = index.occupants.get(path) + if occupant is not None and occupant != doc_id: + path = f"{base}/{_suffix_with_doc_id(filename, doc_id)}" + if doc_id is not None: + index.occupants[path] = doc_id + return path + + +async def virtual_path_to_doc( + session: AsyncSession, + *, + search_space_id: int, + virtual_path: str, +) -> Document | None: + """Resolve a virtual path back to a ``Document`` row. + + Resolution order: + 1. 
``Document.unique_identifier_hash`` lookup (fast path for paths created + by SurfSense itself — every NOTE write goes through this hash). + 2. If the basename carries a ``" ().xml"`` disambiguation suffix, + try a direct id lookup constrained to the search space. + 3. Title-from-basename + folder-resolution lookup as a last resort. + """ + if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT): + return None + + unique_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.unique_identifier_hash == unique_hash, + ) + ) + document = result.scalar_one_or_none() + if document is not None: + return document + + rel = virtual_path[len(DOCUMENTS_ROOT) :].lstrip("/") + if not rel: + return None + parts = [p for p in rel.split("/") if p] + if not parts: + return None + basename = parts[-1] + folder_parts = parts[:-1] + + stem, suffix_doc_id = parse_doc_id_suffix(basename) + if suffix_doc_id is not None: + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.id == suffix_doc_id, + ) + ) + document = result.scalar_one_or_none() + if document is not None: + return document + + folder_id = await _resolve_folder_id( + session, search_space_id=search_space_id, folder_parts=folder_parts + ) + title_candidates: list[str] = [] + raw_title = stem + title_candidates.append(raw_title) + if raw_title.endswith(".xml"): + title_candidates.append(raw_title[:-4]) + + for candidate in dict.fromkeys(title_candidates): + if not candidate: + continue + query = select(Document).where( + Document.search_space_id == search_space_id, + Document.title == candidate, + ) + if folder_id is None: + query = query.where(Document.folder_id.is_(None)) + else: + query = query.where(Document.folder_id == folder_id) + result = await session.execute(query) + document = result.scalars().first() + if document is not None: + return document + + # Fallback: title-as-string lookup misses when the real DB title contains + # characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``, + # etc.) — common for connector-imported docs (Google Calendar/Drive etc.). + # The workspace tree shows the lossy filename, so the agent passes that + # filename back here. Scan all documents in the resolved folder and match + # by ``safe_filename(title)`` to recover the original document. 
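    # A concrete instance of the lossy round-trip this fallback recovers
    # (hypothetical title): safe_filename("Calendar: Happy birthday!") returns
    # "Calendar_ Happy birthday!.xml", so the literal-title lookups above miss
    # the row and the folder scan below matches on the re-encoded basename.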
+ folder_scan = select(Document).where( + Document.search_space_id == search_space_id, + ) + if folder_id is None: + folder_scan = folder_scan.where(Document.folder_id.is_(None)) + else: + folder_scan = folder_scan.where(Document.folder_id == folder_id) + result = await session.execute(folder_scan) + for candidate_doc in result.scalars().all(): + encoded = safe_filename(str(candidate_doc.title or "untitled")) + if encoded == basename: + return candidate_doc + return None + + +async def _resolve_folder_id( + session: AsyncSession, + *, + search_space_id: int, + folder_parts: list[str], +) -> int | None: + """Look up the leaf folder id for a chain of folder names; return ``None`` if missing.""" + if not folder_parts: + return None + parent_id: int | None = None + for raw in folder_parts: + name = safe_folder_segment(raw) + query = select(Folder.id).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + ) + if parent_id is None: + query = query.where(Folder.parent_id.is_(None)) + else: + query = query.where(Folder.parent_id == parent_id) + result = await session.execute(query) + row = result.first() + if row is None: + return None + parent_id = row[0] + return parent_id + + +def parse_documents_path(virtual_path: str) -> tuple[list[str], str]: + """Parse a ``/documents/...`` path into ``(folder_parts, document_title)``. + + The title has any ``.xml`` extension and trailing ``" ()"`` + disambiguation suffix stripped. + """ + if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT): + return [], "" + rel = virtual_path[len(DOCUMENTS_ROOT) :].strip("/") + if not rel: + return [], "" + parts = [p for p in rel.split("/") if p] + if not parts: + return [], "" + folder_parts = parts[:-1] + basename = parts[-1] + stem, _ = parse_doc_id_suffix(basename) + title = stem + if title.endswith(".xml"): + title = title[:-4] + return folder_parts, title + + +__all__ = [ + "DOCUMENTS_ROOT", + "PathIndex", + "build_path_index", + "doc_to_virtual_path", + "parse_doc_id_suffix", + "parse_documents_path", + "safe_filename", + "safe_folder_segment", + "virtual_path_to_doc", +] diff --git a/surfsense_backend/app/agents/new_chat/state_reducers.py b/surfsense_backend/app/agents/new_chat/state_reducers.py new file mode 100644 index 000000000..ce32406e6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/state_reducers.py @@ -0,0 +1,201 @@ +"""Reducers and sentinels for SurfSense filesystem state. + +These reducers back the extra state fields used by the cloud-mode filesystem +agent (`cwd`, `staged_dirs`, `pending_moves`, `dirty_paths`, `doc_id_by_path`, +`kb_priority`, `kb_matched_chunk_ids`, `kb_anon_doc`, `tree_version`). + +Tools mutate these fields ONLY via `Command(update={...})` returns; the +reducers are responsible for merging successive updates atomically and for +honouring an explicit reset sentinel (`_CLEAR`) so that a single update can +both reset and reseed a list (used by `move_file` / `aafter_agent`). + +The sentinel is intentionally a plain string constant rather than a custom +object so that LangGraph's checkpointer (which serializes raw `Command.update` +deltas via ``ormsgpack`` BEFORE reducers are applied) can round-trip writes +that contain it. The token uses a NUL-bracketed form that cannot collide with +any real virtual path, document title, or dict key produced by the agent. 
+""" + +from __future__ import annotations + +from typing import Any, Final, TypeVar + +_CLEAR: Final[str] = "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00" +"""Reset sentinel; pass it inside a list/dict update to request a reset. + +For list reducers: ``[_CLEAR, *items]`` resets the field then appends ``items``. +For dict reducers: ``{_CLEAR: True, **items}`` resets the field then merges ``items``. + +Because the value is a plain string with embedded NUL bytes, it is natively +serializable by ``ormsgpack`` (used by LangGraph's PostgreSQL checkpointer) +yet still distinct from any real path / key produced by application code. +""" + + +T = TypeVar("T") + + +def _replace_reducer[T](left: T | None, right: T | None) -> T | None: + """Replace `left` outright with `right`. ``None`` on the right is honored as a reset.""" + return right + + +def _is_clear(value: Any) -> bool: + return isinstance(value, str) and value == _CLEAR + + +def _add_unique_reducer( + left: list[Any] | None, + right: list[Any] | None, +) -> list[Any]: + """Append items from ``right`` to ``left`` while preserving uniqueness. + + Semantics: + - If ``right`` is ``None`` or empty, return ``left`` unchanged. + - If ``right`` contains the ``_CLEAR`` sentinel anywhere, the result is + reseeded with only the items that appear AFTER the LAST occurrence of + ``_CLEAR`` (deduplicated, preserving first-seen order). This gives a + single-update "reset and reseed" capability. + - Otherwise, items from ``right`` are appended to ``left`` (order preserved + from first seen) while skipping values that are already present. + """ + if right is None: + return list(left or []) + if not right: + return list(left or []) + + last_clear = -1 + for index, item in enumerate(right): + if _is_clear(item): + last_clear = index + + if last_clear >= 0: + seed: list[Any] = [] + seen: set[Any] = set() + for item in right[last_clear + 1 :]: + if _is_clear(item): + continue + try: + if item in seen: + continue + seen.add(item) + except TypeError: + if item in seed: + continue + seed.append(item) + return seed + + base = list(left or []) + try: + seen: set[Any] = set(base) + except TypeError: + seen = set() + for item in right: + if _is_clear(item): + continue + try: + if item in seen: + continue + seen.add(item) + except TypeError: + if item in base: + continue + base.append(item) + return base + + +def _list_append_reducer( + left: list[Any] | None, + right: list[Any] | None, +) -> list[Any]: + """Append items from ``right`` to ``left`` preserving order and duplicates. + + Honours the ``_CLEAR`` sentinel exactly like :func:`_add_unique_reducer`, + but does NOT deduplicate. Used for queues whose ordering and duplicate + occurrences matter (e.g. ``pending_moves``). + """ + if right is None: + return list(left or []) + if not right: + return list(left or []) + + last_clear = -1 + for index, item in enumerate(right): + if _is_clear(item): + last_clear = index + + if last_clear >= 0: + return [item for item in right[last_clear + 1 :] if not _is_clear(item)] + + base = list(left or []) + base.extend(item for item in right if not _is_clear(item)) + return base + + +def _dict_merge_with_tombstones_reducer( + left: dict[Any, Any] | None, + right: dict[Any, Any] | None, +) -> dict[Any, Any]: + """Merge ``right`` into ``left`` with two extra capabilities: + + * Keys whose value is ``None`` are removed from the merged result + (tombstone semantics, matching the deepagents file-data reducer). 
+ * The special key ``_CLEAR`` (with any truthy value) resets ``left`` to + ``{}`` before merging the remaining keys from ``right``. This makes it + possible to atomically clear and reseed the dictionary in a single + update. + """ + if right is None: + return dict(left or {}) + + if _CLEAR in right or any(_is_clear(k) for k in right): + result: dict[Any, Any] = {} + for key, value in right.items(): + if _is_clear(key): + continue + if value is None: + result.pop(key, None) + continue + result[key] = value + return result + + if left is None: + return {key: value for key, value in right.items() if value is not None} + + result = dict(left) + for key, value in right.items(): + if value is None: + result.pop(key, None) + else: + result[key] = value + return result + + +def _initial_filesystem_state() -> dict[str, Any]: + """Default empty values for SurfSense filesystem state fields. + + Consumers should always treat these fields as ``state.get(key) or + DEFAULT`` so that fresh threads (without checkpointed state) work + correctly. + """ + return { + "cwd": "/documents", + "staged_dirs": [], + "pending_moves": [], + "doc_id_by_path": {}, + "dirty_paths": [], + "kb_priority": [], + "kb_matched_chunk_ids": {}, + "kb_anon_doc": None, + "tree_version": 0, + } + + +__all__ = [ + "_CLEAR", + "_add_unique_reducer", + "_dict_merge_with_tombstones_reducer", + "_initial_filesystem_state", + "_list_append_reducer", + "_replace_reducer", +] diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index e77132182..0c9426892 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -332,7 +332,7 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """ * When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL * When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices) * When a URL was mentioned earlier in the conversation and the user asks for its actual content - * When preloaded `/documents/` data is insufficient and the user wants more + * When `/documents/` knowledge-base data is insufficient and the user wants more - Trigger scenarios: * "Read this article and summarize it" * "What does this page say about X?" diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py index d94d55b1a..3803fa39c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -20,7 +20,12 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.config import config -from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace +from app.db import ( + ImageGeneration, + ImageGenerationConfig, + SearchSpace, + shielded_async_session, +) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, ImageGenRouterService, @@ -70,8 +75,13 @@ def create_generate_image_tool( Args: search_space_id: The search space ID (for config resolution) - db_session: Async database session + db_session: Reserved for compatibility with the tool registry. + The streaming task's ``AsyncSession`` is shared by every tool; + because AsyncSession is not concurrency-safe, parallel tool calls + would interleave flushes (e.g. podcast + image in the same step) + and poison the transaction. This tool opens its own session. 
""" + del db_session # use a fresh per-call session, see below @tool async def generate_image( @@ -93,110 +103,119 @@ def create_generate_image_tool( A dictionary containing the generated image(s) for display in the chat. """ try: - # Resolve the image generation config from the search space preference - result = await db_session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return {"error": "Search space not found"} - - config_id = ( - search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID - ) - - # Build generation kwargs - # NOTE: size, quality, and style are intentionally NOT passed. - # Different models support different values for these params - # (e.g. DALL-E 3 wants "hd"/"standard" for quality while - # gpt-image-1 wants "high"/"medium"/"low"; size options also - # differ). Letting the model use its own defaults avoids errors. - gen_kwargs: dict[str, Any] = {} - if n is not None and n > 1: - gen_kwargs["n"] = n - - # Call litellm based on config type - if is_image_gen_auto_mode(config_id): - if not ImageGenRouterService.is_initialized(): - return { - "error": "No image generation models configured. " - "Please add an image model in Settings > Image Models." - } - response = await ImageGenRouterService.aimage_generation( - prompt=prompt, model="auto", **gen_kwargs + # Use a per-call session so concurrent tool calls don't share an + # AsyncSession (which is not concurrency-safe). The streaming + # task's session is shared across every tool; without isolation, + # autoflushes from a concurrent writer poison this tool too. + async with shielded_async_session() as session: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) ) - elif config_id < 0: - cfg = _get_global_image_gen_config(config_id) - if not cfg: - return {"error": f"Image generation config {config_id} not found"} + search_space = result.scalars().first() + if not search_space: + return {"error": "Search space not found"} - model_string = _build_model_string( - cfg.get("provider", ""), - cfg["model_name"], - cfg.get("custom_provider"), + config_id = ( + search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID ) - gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) - response = await aimage_generation( - prompt=prompt, model=model_string, **gen_kwargs - ) - else: - # Positive ID = user-created ImageGenerationConfig - cfg_result = await db_session.execute( - select(ImageGenerationConfig).filter( - ImageGenerationConfig.id == config_id + # Build generation kwargs + # NOTE: size, quality, and style are intentionally NOT passed. + # Different models support different values for these params + # (e.g. DALL-E 3 wants "hd"/"standard" for quality while + # gpt-image-1 wants "high"/"medium"/"low"; size options also + # differ). Letting the model use its own defaults avoids errors. + gen_kwargs: dict[str, Any] = {} + if n is not None and n > 1: + gen_kwargs["n"] = n + + # Call litellm based on config type + if is_image_gen_auto_mode(config_id): + if not ImageGenRouterService.is_initialized(): + return { + "error": "No image generation models configured. " + "Please add an image model in Settings > Image Models." 
+ } + response = await ImageGenRouterService.aimage_generation( + prompt=prompt, model="auto", **gen_kwargs ) - ) - db_cfg = cfg_result.scalars().first() - if not db_cfg: - return {"error": f"Image generation config {config_id} not found"} + elif config_id < 0: + cfg = _get_global_image_gen_config(config_id) + if not cfg: + return { + "error": f"Image generation config {config_id} not found" + } - model_string = _build_model_string( - db_cfg.provider.value, - db_cfg.model_name, - db_cfg.custom_provider, - ) - gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base - if db_cfg.api_version: - gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) + model_string = _build_model_string( + cfg.get("provider", ""), + cfg["model_name"], + cfg.get("custom_provider"), + ) + gen_kwargs["api_key"] = cfg.get("api_key") + if cfg.get("api_base"): + gen_kwargs["api_base"] = cfg["api_base"] + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) - response = await aimage_generation( - prompt=prompt, model=model_string, **gen_kwargs + response = await aimage_generation( + prompt=prompt, model=model_string, **gen_kwargs + ) + else: + # Positive ID = user-created ImageGenerationConfig + cfg_result = await session.execute( + select(ImageGenerationConfig).filter( + ImageGenerationConfig.id == config_id + ) + ) + db_cfg = cfg_result.scalars().first() + if not db_cfg: + return { + "error": f"Image generation config {config_id} not found" + } + + model_string = _build_model_string( + db_cfg.provider.value, + db_cfg.model_name, + db_cfg.custom_provider, + ) + gen_kwargs["api_key"] = db_cfg.api_key + if db_cfg.api_base: + gen_kwargs["api_base"] = db_cfg.api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) + + response = await aimage_generation( + prompt=prompt, model=model_string, **gen_kwargs + ) + + # Parse the response and store in DB + response_dict = ( + response.model_dump() + if hasattr(response, "model_dump") + else dict(response) ) - # Parse the response and store in DB - response_dict = ( - response.model_dump() - if hasattr(response, "model_dump") - else dict(response) - ) + # Generate a random access token for this image + access_token = generate_image_token() - # Generate a random access token for this image - access_token = generate_image_token() - - # Save to image_generations table for history - db_image_gen = ImageGeneration( - prompt=prompt, - model=getattr(response, "_hidden_params", {}).get("model"), - n=n, - image_generation_config_id=config_id, - response_data=response_dict, - search_space_id=search_space_id, - access_token=access_token, - ) - db_session.add(db_image_gen) - await db_session.commit() - await db_session.refresh(db_image_gen) + # Save to image_generations table for history + db_image_gen = ImageGeneration( + prompt=prompt, + model=getattr(response, "_hidden_params", {}).get("model"), + n=n, + image_generation_config_id=config_id, + response_data=response_dict, + search_space_id=search_space_id, + access_token=access_token, + ) + session.add(db_image_gen) + await session.commit() + await session.refresh(db_image_gen) + db_image_gen_id = db_image_gen.id # Extract image URLs from response images = response_dict.get("data", []) @@ -217,7 +236,7 @@ def create_generate_image_tool( 
backend_url = config.BACKEND_URL or "http://localhost:8000" image_url = ( f"{backend_url}/api/v1/image-generations/" - f"{db_image_gen.id}/image?token={access_token}" + f"{db_image_gen_id}/image?token={access_token}" ) else: return {"error": "No displayable image data in the response"} diff --git a/surfsense_backend/app/agents/new_chat/tools/podcast.py b/surfsense_backend/app/agents/new_chat/tools/podcast.py index 248a4f450..2c9b7fa0c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/podcast.py +++ b/surfsense_backend/app/agents/new_chat/tools/podcast.py @@ -11,7 +11,7 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.db import Podcast, PodcastStatus +from app.db import Podcast, PodcastStatus, shielded_async_session def create_generate_podcast_tool( @@ -27,12 +27,16 @@ def create_generate_podcast_tool( Args: search_space_id: The user's search space ID - db_session: Database session for creating the podcast record + db_session: Reserved for future read-side use; the row is written via a + fresh, tool-local session so parallel tool calls (e.g. podcast + + video presentation in the same agent step) don't share an + ``AsyncSession`` (which is not concurrency-safe). thread_id: The chat thread ID for associating the podcast Returns: A configured tool function for generating podcasts """ + del db_session # writes use a fresh tool-local session, see below @tool async def generate_podcast( @@ -64,32 +68,40 @@ def create_generate_podcast_tool( - message: Status message (or "error" field if status is failed) """ try: - podcast = Podcast( - title=podcast_title, - status=PodcastStatus.PENDING, - search_space_id=search_space_id, - thread_id=thread_id, - ) - db_session.add(podcast) - await db_session.commit() - await db_session.refresh(podcast) + # Open a fresh session per call. The streaming task's session is + # shared between every tool, and ``AsyncSession`` is NOT safe for + # concurrent use: when the LLM emits parallel tool calls, two + # concurrent ``add()`` / ``commit()`` paths interleave and the + # second one hits "Session.add() during flush" → the transaction + # is poisoned for both tools. + async with shielded_async_session() as session: + podcast = Podcast( + title=podcast_title, + status=PodcastStatus.PENDING, + search_space_id=search_space_id, + thread_id=thread_id, + ) + session.add(podcast) + await session.commit() + await session.refresh(podcast) + podcast_id = podcast.id from app.tasks.celery_tasks.podcast_tasks import ( generate_content_podcast_task, ) task = generate_content_podcast_task.delay( - podcast_id=podcast.id, + podcast_id=podcast_id, source_content=source_content, search_space_id=search_space_id, user_prompt=user_prompt, ) - print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}") + print(f"[generate_podcast] Created podcast {podcast_id}, task: {task.id}") return { "status": PodcastStatus.PENDING.value, - "podcast_id": podcast.id, + "podcast_id": podcast_id, "title": podcast_title, "message": "Podcast generation started. 
This may take a few minutes.", } diff --git a/surfsense_backend/app/agents/new_chat/tools/video_presentation.py b/surfsense_backend/app/agents/new_chat/tools/video_presentation.py index a90e08ac3..7bf9a1c3b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/video_presentation.py +++ b/surfsense_backend/app/agents/new_chat/tools/video_presentation.py @@ -11,7 +11,7 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.db import VideoPresentation, VideoPresentationStatus +from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session def create_generate_video_presentation_tool( @@ -23,8 +23,11 @@ def create_generate_video_presentation_tool( Factory function to create the generate_video_presentation tool with injected dependencies. Pre-creates video presentation record with pending status so the ID is available - immediately for frontend polling. + immediately for frontend polling. The row is written via a fresh, tool-local + session so parallel tool calls (e.g. video + podcast in the same agent step) + don't share an ``AsyncSession`` (which is not concurrency-safe). """ + del db_session # writes use a fresh tool-local session, see below @tool async def generate_video_presentation( @@ -42,34 +45,40 @@ def create_generate_video_presentation_tool( user_prompt: Optional style/tone instructions. """ try: - video_pres = VideoPresentation( - title=video_title, - status=VideoPresentationStatus.PENDING, - search_space_id=search_space_id, - thread_id=thread_id, - ) - db_session.add(video_pres) - await db_session.commit() - await db_session.refresh(video_pres) + # See podcast.py for the rationale: parallel tool calls share the + # streaming session, and AsyncSession is not concurrency-safe — + # interleaved flushes produce "Session.add() during flush" and + # poison the transaction for every concurrent tool. + async with shielded_async_session() as session: + video_pres = VideoPresentation( + title=video_title, + status=VideoPresentationStatus.PENDING, + search_space_id=search_space_id, + thread_id=thread_id, + ) + session.add(video_pres) + await session.commit() + await session.refresh(video_pres) + video_pres_id = video_pres.id from app.tasks.celery_tasks.video_presentation_tasks import ( generate_video_presentation_task, ) task = generate_video_presentation_task.delay( - video_presentation_id=video_pres.id, + video_presentation_id=video_pres_id, source_content=source_content, search_space_id=search_space_id, user_prompt=user_prompt, ) print( - f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}" + f"[generate_video_presentation] Created video presentation {video_pres_id}, task: {task.id}" ) return { "status": VideoPresentationStatus.PENDING.value, - "video_presentation_id": video_pres.id, + "video_presentation_id": video_pres_id, "title": video_title, "message": "Video presentation generation started. 
This may take a few minutes.", } diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 7239c57a5..9fc0325e5 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -30,7 +30,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer -from app.agents.new_chat.filesystem_selection import FilesystemSelection +from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, create_chat_litellm_from_agent_config, @@ -42,6 +42,9 @@ from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) +from app.agents.new_chat.middleware.kb_persistence import ( + commit_staged_filesystem_state, +) from app.db import ( ChatVisibility, NewChatMessage, @@ -258,6 +261,10 @@ async def _stream_agent_events( initial_step_id: str | None = None, initial_step_title: str = "", initial_step_items: list[str] | None = None, + *, + fallback_commit_search_space_id: int | None = None, + fallback_commit_created_by_id: str | None = None, + fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, ) -> AsyncGenerator[str, None]: """Shared async generator that streams and formats astream_events from the agent. @@ -1280,6 +1287,40 @@ async def _stream_agent_events( state = await agent.aget_state(config) state_values = getattr(state, "values", {}) or {} + + # Safety net: if astream_events was cancelled before + # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work + # (dirty_paths / staged_dirs / pending_moves) will still be in the + # checkpointed state. Run the SAME shared commit helper here so the + # turn's writes don't get lost on client disconnect, then push the + # delta back into the graph using `as_node=...` so reducers fire as if + # the after_agent hook produced it. 
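        # Sketch of the trigger (hypothetical values): an interrupted turn
        # might leave dirty_paths=["/documents/notes/Plan.xml"] with empty
        # staged_dirs and pending_moves in the checkpoint; any one non-empty
        # field, in cloud mode with a search space id, runs the commit below.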
+ if ( + fallback_commit_filesystem_mode == FilesystemMode.CLOUD + and fallback_commit_search_space_id is not None + and ( + (state_values.get("dirty_paths") or []) + or (state_values.get("staged_dirs") or []) + or (state_values.get("pending_moves") or []) + ) + ): + try: + delta = await commit_staged_filesystem_state( + state_values, + search_space_id=fallback_commit_search_space_id, + created_by_id=fallback_commit_created_by_id, + filesystem_mode=fallback_commit_filesystem_mode, + dispatch_events=False, + ) + if delta: + await agent.aupdate_state( + config, + delta, + as_node="KnowledgeBasePersistenceMiddleware.after_agent", + ) + except Exception as exc: + _perf_log.warning("[stream_new_chat] safety-net commit failed: %s", exc) + contract_state = state_values.get("file_operation_contract") or {} contract_turn_id = contract_state.get("turn_id") current_turn_id = config.get("configurable", {}).get("turn_id", "") @@ -1814,6 +1855,13 @@ async def stream_new_chat( initial_step_id=initial_step_id, initial_step_title=initial_title, initial_step_items=initial_items, + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), ): if not _first_event_logged: _perf_log.info( @@ -2251,6 +2299,13 @@ async def stream_resume_chat( streaming_service=streaming_service, result=stream_result, step_prefix="thinking-resume", + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), ): if not _first_event_logged: _perf_log.info( diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py b/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py new file mode 100644 index 000000000..ddb20330d --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py @@ -0,0 +1,198 @@ +"""Tests for canonical virtual-path resolver helpers.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + PathIndex, + doc_to_virtual_path, + parse_doc_id_suffix, + parse_documents_path, + safe_filename, + safe_folder_segment, + virtual_path_to_doc, +) + +pytestmark = pytest.mark.unit + + +class TestSafeFilename: + def test_appends_xml_extension(self): + assert safe_filename("notes").endswith(".xml") + + def test_strips_invalid_chars(self): + assert "/" not in safe_filename("a/b\\c.xml") + + def test_falls_back_when_empty(self): + assert safe_filename("").endswith(".xml") + assert safe_filename("///") == "untitled.xml" or safe_filename("///").endswith( + ".xml" + ) + + +class TestSafeFolderSegment: + def test_strips_path_separators(self): + assert "/" not in safe_folder_segment("a/b") + + def test_falls_back(self): + assert safe_folder_segment("") == "folder" + + +class TestParseDocIdSuffix: + def test_parses_suffix(self): + stem, doc_id = parse_doc_id_suffix("My Doc (42).xml") + assert stem == "My Doc" + assert doc_id == 42 + + def test_no_suffix_returns_none(self): + stem, doc_id = parse_doc_id_suffix("My Doc.xml") + assert stem == "My Doc" + assert doc_id is None + + def test_no_xml_extension(self): + stem, doc_id = parse_doc_id_suffix("plain") + assert stem == "plain" + assert doc_id is None + + +class 
TestDocToVirtualPath: + def test_root_when_no_folder(self): + index = PathIndex() + path = doc_to_virtual_path(doc_id=1, title="Hello", folder_id=None, index=index) + assert path == f"{DOCUMENTS_ROOT}/Hello.xml" + assert index.occupants[path] == 1 + + def test_collision_picks_doc_id_suffix(self): + index = PathIndex(occupants={f"{DOCUMENTS_ROOT}/Hello.xml": 7}) + path = doc_to_virtual_path(doc_id=8, title="Hello", folder_id=None, index=index) + assert path == f"{DOCUMENTS_ROOT}/Hello (8).xml" + assert index.occupants[path] == 8 + + def test_uses_folder_path_when_known(self): + index = PathIndex(folder_paths={5: f"{DOCUMENTS_ROOT}/notes"}) + path = doc_to_virtual_path(doc_id=2, title="A", folder_id=5, index=index) + assert path == f"{DOCUMENTS_ROOT}/notes/A.xml" + + +class TestParseDocumentsPath: + def test_extracts_folder_parts_and_title(self): + parts, title = parse_documents_path(f"{DOCUMENTS_ROOT}/foo/bar/baz.xml") + assert parts == ["foo", "bar"] + assert title == "baz" + + def test_strips_doc_id_suffix(self): + parts, title = parse_documents_path(f"{DOCUMENTS_ROOT}/foo/My Doc (12).xml") + assert parts == ["foo"] + assert title == "My Doc" + + def test_non_documents_returns_empty(self): + assert parse_documents_path("/other/x.xml") == ([], "") + + +def _result_from_scalars(rows: list): + """Build a fake SQLAlchemy ``Result`` whose ``.scalars().all()`` and + ``.scalars().first()`` yield ``rows``.""" + scalars = MagicMock() + scalars.all.return_value = list(rows) + scalars.first.return_value = rows[0] if rows else None + result = MagicMock() + result.scalars.return_value = scalars + result.scalar_one_or_none.return_value = None + result.first.return_value = None + return result + + +def _result_from_one(value): + result = MagicMock() + result.scalar_one_or_none.return_value = value + return result + + +class TestVirtualPathToDoc: + """Lookup must round-trip through ``safe_filename``'s lossy encoding. + + The workspace tree displays ``safe_filename(title)`` as the basename, so + when the agent passes that basename back to a tool (move/edit/read) the + resolver must find the original document even though characters like + ``:`` were replaced with ``_``. + """ + + @pytest.mark.asyncio + async def test_falls_back_to_safe_filename_match_when_title_lossy(self): + # A Google Calendar-style title that contains a colon — safe_filename + # rewrites the colon to ``_``, so the literal title-equality lookup + # would miss this row. + original_title = "Calendar: Happy birthday!" + encoded_basename = safe_filename(original_title) + assert encoded_basename == "Calendar_ Happy birthday!.xml" + + target_doc = SimpleNamespace(id=42, title=original_title, folder_id=None) + + session = MagicMock() + # Each ``await session.execute(...)`` returns a fresh canned result. 
+ # Order matches the resolver's lookup steps: + # 1) unique_identifier_hash → no match + # 2) literal title match → no match (lossy encoding) + # 3) folder scan → returns the row whose title encodes to basename + session.execute = AsyncMock( + side_effect=[ + _result_from_one(None), + _result_from_scalars([]), + _result_from_scalars([target_doc]), + ] + ) + + document = await virtual_path_to_doc( + session, + search_space_id=5, + virtual_path=f"{DOCUMENTS_ROOT}/{encoded_basename}", + ) + assert document is target_doc + + @pytest.mark.asyncio + async def test_returns_none_when_no_doc_matches_safe_filename(self): + session = MagicMock() + session.execute = AsyncMock( + side_effect=[ + _result_from_one(None), + _result_from_scalars([]), + _result_from_scalars( + [SimpleNamespace(id=1, title="Something else", folder_id=None)] + ), + ] + ) + + document = await virtual_path_to_doc( + session, + search_space_id=5, + virtual_path=f"{DOCUMENTS_ROOT}/Calendar_ Happy birthday!.xml", + ) + assert document is None + + @pytest.mark.asyncio + async def test_literal_title_match_short_circuits_fallback(self): + # When the literal title query hits, the folder-scan fallback must + # NOT run (saves a query and avoids picking the wrong doc when two + # rows share a lossy encoding). + target_doc = SimpleNamespace(id=7, title="Plain Note", folder_id=None) + + session = MagicMock() + session.execute = AsyncMock( + side_effect=[ + _result_from_one(None), + _result_from_scalars([target_doc]), + ] + ) + + document = await virtual_path_to_doc( + session, + search_space_id=5, + virtual_path=f"{DOCUMENTS_ROOT}/Plain Note.xml", + ) + assert document is target_doc + assert session.execute.await_count == 2 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py new file mode 100644 index 000000000..3caeb9a34 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py @@ -0,0 +1,107 @@ +"""Tests for SurfSense filesystem state reducers.""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.state_reducers import ( + _CLEAR, + _add_unique_reducer, + _dict_merge_with_tombstones_reducer, + _initial_filesystem_state, + _list_append_reducer, + _replace_reducer, +) + +pytestmark = pytest.mark.unit + + +class TestReplaceReducer: + def test_right_wins_outright(self): + assert _replace_reducer("a", "b") == "b" + + def test_none_right_returns_none(self): + assert _replace_reducer("a", None) is None + + def test_none_left_returns_right(self): + assert _replace_reducer(None, "b") == "b" + + +class TestAddUniqueReducer: + def test_appends_unique_items(self): + assert _add_unique_reducer(["a"], ["b", "c"]) == ["a", "b", "c"] + + def test_dedupes_against_left(self): + assert _add_unique_reducer(["a", "b"], ["b", "c"]) == ["a", "b", "c"] + + def test_dedupes_within_right(self): + assert _add_unique_reducer([], ["a", "a", "b"]) == ["a", "b"] + + def test_clear_anywhere_resets_and_reseeds_with_after_items(self): + # _CLEAR semantics: only items AFTER the LAST _CLEAR are kept. 
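+ # Here the last _CLEAR discards the left list and the earlier "a",
+ # leaving only ["b", "c"].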
+ result = _add_unique_reducer(["x", "y"], ["a", _CLEAR, "b", "c"]) + assert result == ["b", "c"] + + def test_multiple_clears_use_last(self): + result = _add_unique_reducer(["x"], [_CLEAR, "a", _CLEAR, "b"]) + assert result == ["b"] + + def test_clear_only_resets_to_empty(self): + assert _add_unique_reducer(["x", "y"], [_CLEAR]) == [] + + def test_empty_right_keeps_left(self): + assert _add_unique_reducer(["a"], []) == ["a"] + assert _add_unique_reducer(["a"], None) == ["a"] + + +class TestListAppendReducer: + def test_preserves_order_and_duplicates(self): + result = _list_append_reducer([{"a": 1}], [{"b": 2}, {"a": 1}]) + assert result == [{"a": 1}, {"b": 2}, {"a": 1}] + + def test_clear_resets_keeping_after_items(self): + result = _list_append_reducer([{"a": 1}], [{"old": 1}, _CLEAR, {"new": 2}]) + assert result == [{"new": 2}] + + +class TestDictMergeWithTombstones: + def test_merges_keys(self): + assert _dict_merge_with_tombstones_reducer({"a": 1}, {"b": 2}) == { + "a": 1, + "b": 2, + } + + def test_none_value_deletes_key(self): + result = _dict_merge_with_tombstones_reducer({"a": 1, "b": 2}, {"a": None}) + assert result == {"b": 2} + + def test_clear_resets_then_merges(self): + result = _dict_merge_with_tombstones_reducer( + {"a": 1, "b": 2}, {_CLEAR: True, "c": 3} + ) + assert result == {"c": 3} + + def test_clear_keeps_only_post_clear_non_none(self): + result = _dict_merge_with_tombstones_reducer( + {"a": 1}, {_CLEAR: True, "b": 2, "c": None} + ) + assert result == {"b": 2} + + def test_none_left_handled(self): + assert _dict_merge_with_tombstones_reducer(None, {"a": 1, "b": None}) == { + "a": 1 + } + + +class TestInitialFilesystemState: + def test_default_shape(self): + state = _initial_filesystem_state() + assert state["cwd"] == "/documents" + assert state["staged_dirs"] == [] + assert state["pending_moves"] == [] + assert state["doc_id_by_path"] == {} + assert state["dirty_paths"] == [] + assert state["kb_priority"] == [] + assert state["kb_matched_chunk_ids"] == {} + assert state["kb_anon_doc"] is None + assert state["tree_version"] == 0 diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py index 98996d6bc..c71b5efde 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py @@ -36,11 +36,18 @@ def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: P def test_backend_resolver_uses_cloud_mode_by_default(): resolver = build_backend_resolver(FilesystemSelection()) backend = resolver(_RuntimeStub()) - # StateBackend class name check keeps this test decoupled - # from internal deepagents runtime class identity. + # When no search_space_id is provided we fall back to StateBackend so + # sub-agents / tests without DB access still work. 
assert backend.__class__.__name__ == "StateBackend" +def test_backend_resolver_uses_kb_postgres_in_cloud_with_search_space(): + resolver = build_backend_resolver(FilesystemSelection(), search_space_id=42) + backend = resolver(_RuntimeStub()) + assert backend.__class__.__name__ == "KBPostgresBackend" + assert backend.search_space_id == 42 + + def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path): root_one = tmp_path / "resume" root_two = tmp_path / "notes" diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py new file mode 100644 index 000000000..c2e304399 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py @@ -0,0 +1,204 @@ +"""Unit tests for the SurfSense filesystem middleware new behaviors. + +Covers: +* cloud cwd defaults to ``/documents`` and relative paths resolve under it +* cloud writes outside ``/documents/`` are rejected unless basename starts + with ``temp_`` +* cloud writes/edits to the anonymous document are rejected (read-only) +* helper methods on the middleware (``_resolve_relative``, + ``_check_cloud_write_namespace``, ``_default_cwd``) + +These tests use ``__new__`` to bypass the heavy ``__init__`` and exercise +the helper methods directly so the test surface stays narrow and fast. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.filesystem import ( + SurfSenseFilesystemMiddleware, + _build_filesystem_system_prompt, + _build_tool_descriptions, +) + +pytestmark = pytest.mark.unit + + +def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD): + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + middleware._filesystem_mode = mode + return middleware + + +def _runtime(state: dict | None = None) -> SimpleNamespace: + return SimpleNamespace(state=state or {}) + + +class TestCloudCwdDefaults: + def test_default_cwd_in_cloud_is_documents_root(self): + m = _make_middleware() + assert m._default_cwd() == "/documents" + + def test_default_cwd_in_desktop_is_root(self): + m = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER) + assert m._default_cwd() == "/" + + def test_current_cwd_uses_state_when_set(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/notes"}) + assert m._current_cwd(runtime) == "/documents/notes" + + def test_current_cwd_falls_back_to_default(self): + m = _make_middleware() + runtime = _runtime({}) + assert m._current_cwd(runtime) == "/documents" + + def test_current_cwd_ignores_invalid(self): + m = _make_middleware() + runtime = _runtime({"cwd": "not-absolute"}) + assert m._current_cwd(runtime) == "/documents" + + +class TestRelativePathResolution: + def test_relative_path_resolves_against_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/projects"}) + assert ( + m._resolve_relative("notes.md", runtime) == "/documents/projects/notes.md" + ) + + def test_relative_path_with_dotdot(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/a/b"}) + assert m._resolve_relative("../c.md", runtime) == "/documents/a/c.md" + + def test_absolute_path_is_kept(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents"}) + assert m._resolve_relative("/other/x.md", runtime) == "/other/x.md" + + def test_empty_path_returns_cwd(self): + m 
= _make_middleware() + runtime = _runtime({"cwd": "/documents/projects"}) + assert m._resolve_relative("", runtime) == "/documents/projects" + + +class TestCloudWriteNamespacePolicy: + def test_documents_path_allowed(self): + m = _make_middleware() + runtime = _runtime() + assert m._check_cloud_write_namespace("/documents/foo.md", runtime) is None + + def test_documents_root_allowed(self): + m = _make_middleware() + runtime = _runtime() + assert m._check_cloud_write_namespace("/documents", runtime) is None + + def test_temp_basename_anywhere_allowed(self): + m = _make_middleware() + runtime = _runtime() + assert m._check_cloud_write_namespace("/temp_scratch.md", runtime) is None + assert m._check_cloud_write_namespace("/foo/temp_x.md", runtime) is None + assert m._check_cloud_write_namespace("/documents/temp_x.md", runtime) is None + + def test_other_paths_rejected(self): + m = _make_middleware() + runtime = _runtime() + err = m._check_cloud_write_namespace("/foo/bar.md", runtime) + assert err is not None + assert "must target /documents" in err + + def test_anon_doc_path_is_read_only(self): + m = _make_middleware() + runtime = _runtime( + { + "kb_anon_doc": { + "path": "/documents/uploaded.xml", + "title": "uploaded", + "content": "", + "chunks": [], + } + } + ) + err = m._check_cloud_write_namespace("/documents/uploaded.xml", runtime) + assert err is not None + assert "read-only" in err + + def test_desktop_mode_skips_namespace_policy(self): + m = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER) + runtime = _runtime() + assert m._check_cloud_write_namespace("/random/path.md", runtime) is None + + +class TestModeSpecificPrompts: + """The prompt and tool descriptions must only describe the active mode. + + Cross-mode noise wastes tokens and confuses the model with rules it + cannot use this session. 
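+
+ E.g. the cloud prompt must not mention "Local Folder Mode", and the
+ desktop prompt must not mention the cloud-only persistence rules.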
+ """ + + def test_cloud_prompt_omits_desktop_section(self): + prompt = _build_filesystem_system_prompt( + FilesystemMode.CLOUD, sandbox_available=False + ) + assert "Local Folder Mode" not in prompt + assert "mount-prefixed" not in prompt + assert "Persistence Rules" in prompt + assert "/documents" in prompt + assert "temp_" in prompt + + def test_desktop_prompt_omits_cloud_persistence_rules(self): + prompt = _build_filesystem_system_prompt( + FilesystemMode.DESKTOP_LOCAL_FOLDER, sandbox_available=False + ) + assert "Persistence Rules" not in prompt + assert "Workspace Tree" not in prompt + assert "" not in prompt + assert "Local Folder Mode" in prompt + assert "mount-prefixed" in prompt + + def test_cloud_tool_descs_omit_desktop_phrases(self): + descs = _build_tool_descriptions(FilesystemMode.CLOUD) + for name in ( + "write_file", + "edit_file", + "move_file", + "mkdir", + "list_tree", + "grep", + ): + text = descs[name] + assert "Desktop" not in text, f"{name} leaks desktop hints" + assert "Cloud mode:" not in text, f"{name} qualifies a cloud-only desc" + + def test_desktop_tool_descs_omit_cloud_phrases(self): + descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER) + for name in ( + "write_file", + "edit_file", + "move_file", + "mkdir", + "list_tree", + "grep", + ): + text = descs[name] + assert "Cloud" not in text, f"{name} leaks cloud hints" + assert "/documents/" not in text, f"{name} mentions cloud namespace" + assert "temp_" not in text, f"{name} mentions cloud temp_ semantics" + + def test_sandbox_addendum_appended_when_available(self): + prompt = _build_filesystem_system_prompt( + FilesystemMode.CLOUD, sandbox_available=True + ) + assert "execute_code" in prompt + assert "Code Execution" in prompt + + def test_sandbox_addendum_absent_when_unavailable(self): + prompt = _build_filesystem_system_prompt( + FilesystemMode.CLOUD, sandbox_available=False + ) + assert "execute_code" not in prompt diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py index cca15e789..81cf590d3 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py @@ -11,25 +11,6 @@ from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( pytestmark = pytest.mark.unit -class _BackendWithRawRead: - def __init__(self, content: str) -> None: - self._content = content - - def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str: - del file_path, offset, limit - return " 1\tline1\n 2\tline2" - - async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str: - return self.read(file_path, offset, limit) - - def read_raw(self, file_path: str) -> str: - del file_path - return self._content - - async def aread_raw(self, file_path: str) -> str: - return self.read_raw(file_path) - - class _RuntimeNoSuggestedPath: state = {"file_operation_contract": {}} @@ -39,40 +20,19 @@ class _RuntimeWithSuggestedPath: self.state = {"file_operation_contract": {"suggested_path": suggested_path}} -def test_verify_written_content_prefers_raw_sync() -> None: - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - expected = "line1\nline2" - backend = _BackendWithRawRead(expected) - - verify_error = middleware._verify_written_content_sync( - backend=backend, - path="/note.md", - expected_content=expected, - ) - - assert verify_error is None 
- - -def test_contract_suggested_path_falls_back_to_notes_md() -> None: +def test_contract_suggested_path_falls_back_to_documents_notes_md() -> None: middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) middleware._filesystem_mode = FilesystemMode.CLOUD suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type] - assert suggested == "/notes.md" + # Cloud default cwd is /documents so the fallback lands in the KB. + assert suggested == "/documents/notes.md" -@pytest.mark.asyncio -async def test_verify_written_content_prefers_raw_async() -> None: +def test_contract_suggested_path_falls_back_to_root_notes_md_in_desktop() -> None: middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - expected = "line1\nline2" - backend = _BackendWithRawRead(expected) - - verify_error = await middleware._verify_written_content_async( - backend=backend, - path="/note.md", - expected_content=expected, - ) - - assert verify_error is None + middleware._filesystem_mode = FilesystemMode.DESKTOP_LOCAL_FOLDER + suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type] + assert suggested == "/notes.md" def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None: diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index 1aaf5d127..2ca470680 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -5,10 +5,10 @@ import json import pytest from langchain_core.messages import AIMessage, HumanMessage +from app.agents.new_chat.document_xml import build_document_xml as _build_document_xml from app.agents.new_chat.middleware.knowledge_search import ( KBSearchPlan, KnowledgeBaseSearchMiddleware, - _build_document_xml, _normalize_optional_date_range, _parse_kb_search_plan_response, _render_recent_conversation, @@ -248,17 +248,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: captured.update(kwargs) return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) llm = FakeLLM( json.dumps( @@ -298,17 +291,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: captured.update(kwargs) return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) middleware = KnowledgeBaseSearchMiddleware( llm=FakeLLM("not json"), @@ -334,17 +320,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: captured.update(kwargs) return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) middleware = KnowledgeBaseSearchMiddleware( llm=FakeLLM( @@ -386,9 +365,6 @@ class 
TestKnowledgeBaseSearchMiddlewarePlanner: search_called = True return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", fake_browse_recent_documents, @@ -397,10 +373,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) llm = FakeLLM( json.dumps( @@ -440,9 +412,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: search_captured.update(kwargs) return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", fake_browse_recent_documents, @@ -451,10 +420,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) llm = FakeLLM( json.dumps(