feat: updated file management for main agent

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-28 04:32:52 -07:00
parent 8d50f90060
commit 05ca4c0b9f
27 changed files with 5054 additions and 1803 deletions

3
.gitignore vendored
View file

@ -7,4 +7,5 @@ node_modules/
.pnpm-store
.DS_Store
deepagents/
debug.log
debug.log
opencode/

View file

@ -28,13 +28,76 @@ from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, ToolMessage
from app.agents.new_chat.document_xml import build_document_xml
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
from app.agents.new_chat.middleware.knowledge_search import (
build_scoped_filesystem,
search_knowledge_base,
)
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
build_path_index,
doc_to_virtual_path,
)
from app.db import shielded_async_session
from app.services.new_streaming_service import VercelStreamingService
# Prefer the helper shipped with deepagents for building ``state['files']``
# entries. If that import fails for any reason, define a minimal local
# stand-in so this module can still be imported.
# NOTE(review): the upstream helper may populate more keys than ``content``;
# confirm the fallback shape is sufficient for every consumer.
try:
    from deepagents.backends.utils import create_file_data
except Exception:  # pragma: no cover - defensive
    def create_file_data(content: str) -> dict[str, Any]:
        """Fallback: wrap ``content`` as a dict holding its newline-split lines."""
        return {"content": content.split("\n")}
async def _build_autocomplete_filesystem(
*,
documents: Any,
search_space_id: int,
) -> tuple[dict[str, Any], dict[int, str]]:
"""Build a ``state['files']``-shaped dict from KB search results.
This is the autocomplete-specific replacement for the previous
``build_scoped_filesystem`` helper. It uses the canonical path resolver
so paths line up with the rest of the system, including collision
suffixes for duplicate titles.
"""
files: dict[str, Any] = {}
doc_id_to_path: dict[int, str] = {}
if not documents:
return files, doc_id_to_path
async with shielded_async_session() as session:
index = await build_path_index(session, search_space_id)
for document in documents:
if not isinstance(document, dict):
continue
meta = document.get("document") or {}
doc_id = meta.get("id")
if not isinstance(doc_id, int):
continue
title = str(meta.get("title") or "untitled")
folder_id = meta.get("folder_id")
path = doc_to_virtual_path(
doc_id=doc_id, title=title, folder_id=folder_id, index=index
)
chunk_ids = document.get("matched_chunk_ids") or []
try:
matched_set = {int(c) for c in chunk_ids}
except (TypeError, ValueError):
matched_set = set()
xml = build_document_xml(document, matched_chunk_ids=matched_set)
files[path] = create_file_data(xml)
doc_id_to_path[doc_id] = path
if not files:
# Ensure the synthetic /documents folder is visible even when empty.
files.setdefault(f"{DOCUMENTS_ROOT}/.placeholder", create_file_data(""))
return files, doc_id_to_path
logger = logging.getLogger(__name__)
KB_TOP_K = 10
@ -174,7 +237,7 @@ async def precompute_kb_filesystem(
if not search_results:
return _KBResult()
new_files, _ = await build_scoped_filesystem(
new_files, _ = await _build_autocomplete_filesystem(
documents=search_results,
search_space_id=search_space_id,
)
@ -215,13 +278,12 @@ async def precompute_kb_filesystem(
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
"""Filesystem middleware for autocomplete — read-only exploration only.
Strips ``save_document`` (permanent KB persistence) and passes
``search_space_id=None`` so ``write_file`` / ``edit_file`` stay ephemeral.
Passes ``search_space_id=None`` so the new persistence pipeline is
bypassed; the autocomplete flow only reads, never commits to Postgres.
"""
def __init__(self) -> None:
super().__init__(search_space_id=None, created_by_id=None)
self.tools = [t for t in self.tools if t.name != "save_document"]
# ---------------------------------------------------------------------------

View file

@ -34,12 +34,15 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.filesystem_backends import build_backend_resolver
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.middleware import (
AnonymousDocumentMiddleware,
DedupHITLToolCallsMiddleware,
FileIntentMiddleware,
KnowledgeBaseSearchMiddleware,
KnowledgeBasePersistenceMiddleware,
KnowledgePriorityMiddleware,
KnowledgeTreeMiddleware,
MemoryInjectionMiddleware,
SurfSenseFilesystemMiddleware,
)
@ -246,7 +249,12 @@ async def create_surfsense_deep_agent(
"""
_t_agent_total = time.perf_counter()
filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(filesystem_selection)
backend_resolver = build_backend_resolver(
filesystem_selection,
search_space_id=search_space_id
if filesystem_selection.mode == FilesystemMode.CLOUD
else None,
)
# Discover available connectors and document types for this search space
available_connectors: list[str] | None = None
@ -299,7 +307,9 @@ async def create_surfsense_deep_agent(
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
modified_disabled_tools.extend(get_connector_gated_tools(available_connectors))
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
# Remove direct KB search tool; KnowledgePriorityMiddleware now runs hybrid
# search per turn and surfaces hits as a <priority_documents> hint plus
# `<chunk_index matched="true">` markers inside lazy-loaded XML.
if "search_knowledge_base" not in modified_disabled_tools:
modified_disabled_tools.append("search_knowledge_base")
@ -365,6 +375,11 @@ async def create_surfsense_deep_agent(
)
# General-purpose subagent middleware
# Subagent omits AnonymousDocumentMiddleware, KnowledgeTreeMiddleware,
# KnowledgePriorityMiddleware, and KnowledgeBasePersistenceMiddleware - it
# inherits state and tools from the parent, but should not (a) re-load
# anon docs / re-render the tree / re-run hybrid search, or (b) commit at
# its own completion (only the top-level agent's aafter_agent commits).
gp_middleware = [
TodoListMiddleware(),
_memory_middleware,
@ -389,19 +404,35 @@ async def create_surfsense_deep_agent(
}
# Main agent middleware
# Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ...
# before_agent hooks run in declared order; later injections sit closer to
# the latest human turn. Tree (large + cacheable) is injected earliest so
# provider-side prefix caching has more material to hit; FileIntent (most
# actionable per-turn contract) is injected closest to the user message.
deepagent_middleware = [
TodoListMiddleware(),
_memory_middleware,
FileIntentMiddleware(llm=llm),
KnowledgeBaseSearchMiddleware(
AnonymousDocumentMiddleware(
anon_session_id=anon_session_id,
)
if filesystem_selection.mode == FilesystemMode.CLOUD
else None,
KnowledgeTreeMiddleware(
search_space_id=search_space_id,
filesystem_mode=filesystem_selection.mode,
llm=llm,
)
if filesystem_selection.mode == FilesystemMode.CLOUD
else None,
KnowledgePriorityMiddleware(
llm=llm,
search_space_id=search_space_id,
filesystem_mode=filesystem_selection.mode,
available_connectors=available_connectors,
available_document_types=available_document_types,
mentioned_document_ids=mentioned_document_ids,
anon_session_id=anon_session_id,
),
FileIntentMiddleware(llm=llm),
SurfSenseFilesystemMiddleware(
backend=backend_resolver,
filesystem_mode=filesystem_selection.mode,
@ -409,12 +440,20 @@ async def create_surfsense_deep_agent(
created_by_id=user_id,
thread_id=thread_id,
),
KnowledgeBasePersistenceMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_selection.mode,
)
if filesystem_selection.mode == FilesystemMode.CLOUD
else None,
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
create_safe_summarization_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(agent_tools=tools),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT

View file

@ -0,0 +1,103 @@
"""Shared XML builder for KB documents.
Produces the citation-friendly XML used by every read of a knowledge-base
document (lazy-loaded by :class:`KBPostgresBackend` and synthetic anonymous
files). The XML carries a ``<chunk_index>`` near the top so the LLM can jump
directly to matched-chunk line ranges via ``read_file(offset=, limit=)``.
Extracted from the original ``knowledge_search.py`` so the backend, the
priority middleware, and any future renderer share a single implementation.
"""
from __future__ import annotations
import json
from typing import Any
def build_document_xml(
    document: dict[str, Any],
    matched_chunk_ids: set[int] | None = None,
) -> str:
    """Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.

    The output has three sections: ``<document_metadata>``, a
    ``<chunk_index>`` whose entries give the 1-based line range of every
    chunk in the returned text, and ``<document_content>`` holding the
    chunks themselves.

    Args:
        document: Dict shape produced by hybrid search / lazy-load helpers.
            Expected keys: ``document`` (with ``id``, ``title``,
            ``document_type``, ``metadata``) and ``chunks``
            (list of ``{chunk_id, content}``). A ``document`` or
            ``metadata`` value that is not a dict is treated as empty
            instead of raising.
        matched_chunk_ids: Optional set of chunk IDs to flag as
            ``matched="true"`` in the chunk index.

    Returns:
        The XML document as a single newline-joined string.
    """
    matched = matched_chunk_ids or set()

    # Normalise once so every later ``.get`` is safe on malformed hits.
    raw_meta = document.get("document")
    doc_meta: dict[str, Any] = raw_meta if isinstance(raw_meta, dict) else {}
    raw_metadata = doc_meta.get("metadata")
    metadata: dict[str, Any] = raw_metadata if isinstance(raw_metadata, dict) else {}

    document_id = doc_meta.get("id", document.get("document_id", "unknown"))
    document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
    title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
    url = (
        metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
    )
    metadata_json = json.dumps(metadata, ensure_ascii=False)

    metadata_lines: list[str] = [
        "<document>",
        "<document_metadata>",
        f" <document_id>{document_id}</document_id>",
        f" <document_type>{document_type}</document_type>",
        f" <title><![CDATA[{title}]]></title>",
        f" <url><![CDATA[{url}]]></url>",
        f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
        "</document_metadata>",
        "",
    ]

    chunks = document.get("chunks") or []
    chunk_entries: list[tuple[int | None, str]] = []
    if isinstance(chunks, list):
        for chunk in chunks:
            if not isinstance(chunk, dict):
                continue
            chunk_id = chunk.get("chunk_id") or chunk.get("id")
            chunk_content = str(chunk.get("content", "")).strip()
            if not chunk_content:
                continue
            if chunk_id is None:
                xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
            else:
                xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
            chunk_entries.append((chunk_id, xml))

    # Count physical lines, not list elements: a title/url/type containing a
    # newline occupies several lines in the joined output, and every chunk
    # range below must be shifted accordingly.
    header_line_count = sum(part.count("\n") + 1 for part in metadata_lines)
    # <chunk_index> + one line per entry + </chunk_index> + blank line
    # + <document_content>.
    index_overhead = 1 + len(chunk_entries) + 1 + 1 + 1
    first_chunk_line = header_line_count + index_overhead + 1

    current_line = first_chunk_line
    index_entry_lines: list[str] = []
    for cid, xml_str in chunk_entries:
        # A chunk may span several lines; its range covers all of them.
        num_lines = xml_str.count("\n") + 1
        end_line = current_line + num_lines - 1
        matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
        if cid is not None:
            index_entry_lines.append(
                f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
            )
        else:
            index_entry_lines.append(
                f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
            )
        current_line = end_line + 1

    lines = metadata_lines.copy()
    lines.append("<chunk_index>")
    lines.extend(index_entry_lines)
    lines.append("</chunk_index>")
    lines.append("")
    lines.append("<document_content>")
    lines.extend(xml_str for _, xml_str in chunk_entries)
    lines.extend(["</document_content>", "</document>"])
    return "\n".join(lines)
__all__ = ["build_document_xml"]

View file

@ -5,10 +5,12 @@ from __future__ import annotations
from collections.abc import Callable
from functools import lru_cache
from deepagents.backends.protocol import BackendProtocol
from deepagents.backends.state import StateBackend
from langgraph.prebuilt.tool_node import ToolRuntime
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
MultiRootLocalFolderBackend,
)
@ -23,8 +25,20 @@ def _cached_multi_root_backend(
def build_backend_resolver(
selection: FilesystemSelection,
) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]:
"""Create deepagents backend resolver for the selected filesystem mode."""
*,
search_space_id: int | None = None,
) -> Callable[[ToolRuntime], BackendProtocol]:
"""Create deepagents backend resolver for the selected filesystem mode.
In cloud mode the resolver returns a fresh :class:`KBPostgresBackend`
bound to the current ``runtime`` so the backend can read staging state
(``staged_dirs``, ``pending_moves``, ``files`` cache, ``kb_anon_doc``,
``kb_matched_chunk_ids``) for each tool call. When no ``search_space_id``
is provided, the resolver falls back to :class:`StateBackend` (used by
sub-agents and tests that don't need DB-backed reads).
Desktop-local mode unchanged.
"""
if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:
@ -36,7 +50,14 @@ def build_backend_resolver(
return _resolve_local
def _resolve_cloud(runtime: ToolRuntime) -> StateBackend:
if search_space_id is not None:
def _resolve_kb(runtime: ToolRuntime) -> BackendProtocol:
return KBPostgresBackend(search_space_id, runtime)
return _resolve_kb
def _resolve_state(runtime: ToolRuntime) -> StateBackend:
return StateBackend(runtime)
return _resolve_cloud
return _resolve_state

View file

@ -0,0 +1,113 @@
"""LangGraph state schema additions used by the SurfSense filesystem agent.
This schema extends deepagents' upstream :class:`FilesystemState` with the
extra fields needed to implement Postgres-backed virtual filesystem semantics:
* ``cwd`` current working directory (per-thread checkpointed).
* ``staged_dirs`` pending mkdir requests (cloud only).
* ``pending_moves`` pending move_file requests (cloud only).
* ``doc_id_by_path`` virtual_path -> Document.id, populated by lazy reads.
* ``dirty_paths`` paths whose state file content differs from DB.
* ``kb_priority`` top-K priority hints rendered into a system message.
* ``kb_matched_chunk_ids`` internal hand-off for matched-chunk highlighting.
* ``kb_anon_doc`` Redis-loaded anonymous document (if any).
* ``tree_version`` bumped by persistence; invalidates the tree render cache.
Tools mutate these fields ONLY via ``Command(update=...)`` returns; the
reducers in :mod:`app.agents.new_chat.state_reducers` handle merging.
"""
from __future__ import annotations
from typing import Annotated, Any, NotRequired
from deepagents.middleware.filesystem import FilesystemState
from typing_extensions import TypedDict
from app.agents.new_chat.state_reducers import (
_add_unique_reducer,
_dict_merge_with_tombstones_reducer,
_list_append_reducer,
_replace_reducer,
)
class PendingMove(TypedDict):
    """A staged move_file operation pending end-of-turn commit.

    The class uses TypedDict's default totality, so all three keys are
    required.
    """

    # Virtual path the file is being moved from.
    source: str
    # Virtual path the file is being moved to.
    dest: str
    # NOTE(review): presumably allows replacing an existing file at ``dest``;
    # no consumer of this flag is visible here — confirm.
    overwrite: bool
class KbPriorityEntry(TypedDict, total=False):
    """One entry of the ``kb_priority`` hint list; every key is optional."""

    # Virtual path of the document.
    path: str
    # NOTE(review): presumably the search relevance score — confirm scale
    # and ordering against the producer.
    score: float
    # ``Document.id`` when known; may be ``None``.
    document_id: int | None
    # Document title.
    title: str
    # NOTE(review): presumably True when the user explicitly mentioned the
    # document — confirm against the producer.
    mentioned: bool
class KbAnonDoc(TypedDict, total=False):
    """In-memory anonymous-session document loaded from Redis.

    Declared with ``total=False``, so every key is optional at the type level.
    """

    # Virtual path under which the document is exposed.
    path: str
    # Title of the document.
    title: str
    # Full text content of the document.
    content: str
    # Chunk dicts for the document.
    # NOTE(review): presumably ``{"chunk_id", "content"}`` items — confirm
    # against the loader.
    chunks: list[dict[str, Any]]
class SurfSenseFilesystemState(FilesystemState):
    """Filesystem state used by the SurfSense agent (cloud + desktop).

    Extends deepagents' :class:`FilesystemState` (which provides ``files``)
    with cloud-mode staging fields and search-priority hints. All extra fields
    are reducer-backed so that ``Command(update=...)`` payloads merge cleanly
    across agent steps and across checkpoints.

    Every field below is declared ``NotRequired`` and carries its reducer in
    the ``Annotated`` metadata, so a state dict may omit any of them.
    """

    cwd: NotRequired[Annotated[str, _replace_reducer]]
    """Current working directory.
    Defaults to ``"/documents"`` in cloud mode and ``"/"`` (or first mount) in
    desktop mode. Initialized once per thread by ``KnowledgeTreeMiddleware``.
    """

    staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
    """mkdir paths staged for end-of-turn folder creation (cloud only)."""

    pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
    """move_file ops staged for end-of-turn commit (cloud only)."""

    doc_id_by_path: NotRequired[
        Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
    ]
    """virtual_path -> ``Document.id`` for lazily loaded files.
    Populated on first read of a KB document. Used by edit_file/move_file/
    aafter_agent to map paths back to a real DB row. ``None`` values delete
    the key (tombstones).
    """

    dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
    """Paths whose ``state["files"]`` content has been modified this turn."""

    kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
    """Top-K priority hints rendered as a system message before the user turn."""

    kb_matched_chunk_ids: NotRequired[Annotated[dict[int, list[int]], _replace_reducer]]
    """Internal: ``Document.id`` -> list of matched chunk IDs from hybrid search."""

    kb_anon_doc: NotRequired[Annotated[KbAnonDoc | None, _replace_reducer]]
    """Anonymous-session document loaded from Redis (read-only, no DB row)."""

    tree_version: NotRequired[Annotated[int, _replace_reducer]]
    """Monotonically increasing counter; bumped when commits change the KB tree."""
__all__ = [
"KbAnonDoc",
"KbPriorityEntry",
"PendingMove",
"SurfSenseFilesystemState",
]

View file

@ -1,5 +1,8 @@
"""Middleware components for the SurfSense new chat agent."""
from app.agents.new_chat.middleware.anonymous_document import (
AnonymousDocumentMiddleware,
)
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
@ -9,17 +12,30 @@ from app.agents.new_chat.middleware.file_intent import (
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.kb_persistence import (
KnowledgeBasePersistenceMiddleware,
commit_staged_filesystem_state,
)
from app.agents.new_chat.middleware.knowledge_search import (
KnowledgeBaseSearchMiddleware,
KnowledgePriorityMiddleware,
)
from app.agents.new_chat.middleware.knowledge_tree import (
KnowledgeTreeMiddleware,
)
from app.agents.new_chat.middleware.memory_injection import (
MemoryInjectionMiddleware,
)
__all__ = [
"AnonymousDocumentMiddleware",
"DedupHITLToolCallsMiddleware",
"FileIntentMiddleware",
"KnowledgeBasePersistenceMiddleware",
"KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware",
"KnowledgeTreeMiddleware",
"MemoryInjectionMiddleware",
"SurfSenseFilesystemMiddleware",
"commit_staged_filesystem_state",
]

View file

@ -0,0 +1,91 @@
"""Lightweight middleware that loads the anonymous-session document into state.
Anonymous chats receive a single uploaded document via Redis (no DB row,
read-only). This middleware loads it once on the first turn into
``state['kb_anon_doc']`` so:
* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents``
view without touching the DB.
* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a
degenerate priority list.
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``)
recognises the synthetic path.
The middleware is a no-op when ``anon_session_id`` is not provided or when
the document is already cached in state.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, safe_filename
logger = logging.getLogger(__name__)
class AnonymousDocumentMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Load the anonymous user's uploaded document from Redis into state.

    The document is fetched at most once per thread: when
    ``state['kb_anon_doc']`` is already populated, or when no
    ``anon_session_id`` was supplied, the hook returns without touching Redis.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(self, *, anon_session_id: str | None) -> None:
        """Remember the anonymous session id; a falsy id disables the hook."""
        self.anon_session_id = anon_session_id

    async def abefore_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Return a ``kb_anon_doc`` state update, or ``None`` when nothing to add.

        Args:
            state: Current agent state; only ``kb_anon_doc`` is inspected.
            runtime: Unused; accepted to satisfy the middleware hook signature.
        """
        del runtime
        if not self.anon_session_id:
            return None
        if state.get("kb_anon_doc"):
            return None
        anon_doc = await self._load_anon_document()
        if anon_doc is None:
            return None
        return {"kb_anon_doc": anon_doc}

    async def _load_anon_document(self) -> dict[str, Any] | None:
        """Read ``anon:doc:<session_id>`` from Redis.

        Returns:
            A dict with ``path``, ``title``, ``content`` and ``chunks`` keys,
            or ``None`` when the key is missing, Redis/JSON handling fails,
            or the stored payload is not a JSON object. Loading is
            best-effort: failures are logged as warnings, never raised.
        """
        try:
            import redis.asyncio as aioredis  # local import to keep cold paths cheap

            from app.config import config

            redis_client = aioredis.from_url(
                config.REDIS_APP_URL, decode_responses=True
            )
            try:
                redis_key = f"anon:doc:{self.anon_session_id}"
                data = await redis_client.get(redis_key)
                if not data:
                    return None
                payload = json.loads(data)
            finally:
                # Always release the connection, including on early return.
                await redis_client.aclose()
        except Exception as exc:
            logger.warning("Failed to load anonymous document from Redis: %s", exc)
            return None

        # Valid JSON is not necessarily an object. Without this check a list
        # or scalar payload would raise AttributeError on ``.get`` below and
        # escape the best-effort handling above.
        if not isinstance(payload, dict):
            logger.warning(
                "Ignoring anonymous document payload of unexpected type %s",
                type(payload).__name__,
            )
            return None

        title = str(payload.get("filename") or "uploaded_document")
        content = str(payload.get("content") or "")
        path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}"
        return {
            "path": path,
            "title": title,
            "content": content,
            # The whole document is exposed as one synthetic chunk (id -1).
            "chunks": [{"chunk_id": -1, "content": content}] if content else [],
        }
__all__ = ["AnonymousDocumentMiddleware"]

View file

@ -21,7 +21,7 @@ from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langgraph.runtime import Runtime
from pydantic import BaseModel, Field, ValidationError
@ -217,8 +217,19 @@ def _build_recent_conversation(
messages: list[BaseMessage], *, max_messages: int = 6
) -> str:
rows: list[str] = []
for msg in messages[-max_messages:]:
role = "user" if isinstance(msg, HumanMessage) else "assistant"
filtered: list[tuple[str, BaseMessage]] = []
for msg in messages:
role: str | None = None
if isinstance(msg, HumanMessage):
role = "user"
elif isinstance(msg, AIMessage):
if getattr(msg, "tool_calls", None):
continue
role = "assistant"
else:
continue
filtered.append((role, msg))
for role, msg in filtered[-max_messages:]:
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
if text:
rows.append(f"{role}: {text[:280]}")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,622 @@
"""End-of-turn persistence for the cloud-mode SurfSense filesystem.
This middleware runs ``aafter_agent`` once per turn (cloud only). It commits
all staged folder creations, file moves, and content writes/edits to
Postgres in a single ordered pass:
1. Materialize ``staged_dirs`` into ``Folder`` rows.
2. Apply ``pending_moves`` in order (chained moves resolved via
``doc_id_by_path``).
3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move
sequences commit at the final path.
4. Commit content writes / edits for ``/documents/*`` paths, skipping
``temp_*`` basenames.
The commit body is exposed as a free function ``commit_staged_filesystem_state``
so the optional stream-task fallback (``stream_new_chat.py``) can call the
exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Any
from fractional_indexing import generate_key_between
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.callbacks import dispatch_custom_event
from langgraph.runtime import Runtime
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
parse_documents_path,
safe_folder_segment,
virtual_path_to_doc,
)
from app.agents.new_chat.state_reducers import _CLEAR
from app.db import (
Chunk,
Document,
DocumentType,
Folder,
shielded_async_session,
)
from app.indexing_pipeline.document_chunker import chunk_text
from app.utils.document_converters import (
embed_texts,
generate_content_hash,
generate_unique_identifier_hash,
)
logger = logging.getLogger(__name__)
_TEMP_PREFIX = "temp_"
def _basename(path: str) -> str:
return path.rsplit("/", 1)[-1]
# ---------------------------------------------------------------------------
# Folder helpers
# ---------------------------------------------------------------------------
async def _ensure_folder_hierarchy(
    session: AsyncSession,
    *,
    search_space_id: int,
    created_by_id: str | None,
    folder_parts: list[str],
) -> int | None:
    """Walk ``folder_parts`` from the root, creating any folder that is missing.

    Each segment is sanitised with ``safe_folder_segment`` and looked up by
    name under the folder resolved for the previous segment. A missing folder
    is added with a position key generated after the highest sibling position
    and flushed so its id can parent the next segment.

    Returns:
        The id of the last folder in the chain, or ``None`` when
        ``folder_parts`` is empty (a document directly under ``/documents/``).
    """
    if not folder_parts:
        return None

    current_parent: int | None = None
    for segment in folder_parts:
        folder_name = safe_folder_segment(str(segment))

        # Root-level folders have a NULL parent; nested ones match by id.
        parent_clause = (
            Folder.parent_id.is_(None)
            if current_parent is None
            else Folder.parent_id == current_parent
        )

        lookup = select(Folder).where(
            Folder.search_space_id == search_space_id,
            Folder.name == folder_name,
            parent_clause,
        )
        folder = (await session.execute(lookup)).scalar_one_or_none()

        if folder is None:
            # Find the greatest sibling position so the new folder sorts last.
            last_sibling = (
                select(Folder.position)
                .order_by(Folder.position.desc())
                .limit(1)
                .where(Folder.search_space_id == search_space_id)
                .where(parent_clause)
            )
            last_position = (await session.execute(last_sibling)).scalar_one_or_none()
            folder = Folder(
                name=folder_name,
                position=generate_key_between(last_position, None),
                parent_id=current_parent,
                search_space_id=search_space_id,
                created_by_id=created_by_id,
                updated_at=datetime.now(UTC),
            )
            session.add(folder)
            # Flush so ``folder.id`` is assigned before it is used as a parent.
            await session.flush()

        current_parent = folder.id
    return current_parent
# ---------------------------------------------------------------------------
# Document helpers
# ---------------------------------------------------------------------------
async def _create_document(
    session: AsyncSession,
    *,
    virtual_path: str,
    content: str,
    search_space_id: int,
    created_by_id: str | None,
) -> Document:
    """Create a NOTE Document + Chunks for ``virtual_path``.

    Args:
        session: Open async session; rows are added and flushed here. No
            commit is issued inside this function.
        virtual_path: ``/documents/...`` path; parsed into folder parts and
            a title.
        content: Document body, stored as both ``content`` and
            ``source_markdown``.
        search_space_id: Owning search space.
        created_by_id: Creator recorded on the document and any new folders.

    Returns:
        The newly added (flushed) ``Document``.

    Raises:
        ValueError: If the path yields no title, if a document with the same
            ``unique_identifier_hash`` already exists in the search space, or
            if a document with the same ``content_hash`` already exists.
    """
    folder_parts, title = parse_documents_path(virtual_path)
    if not title:
        raise ValueError(f"invalid /documents path '{virtual_path}'")
    # Folders are ensured (and flushed) before the collision checks below, so
    # a ValueError raised later leaves any new folder rows pending in the
    # session.
    folder_id = await _ensure_folder_hierarchy(
        session,
        search_space_id=search_space_id,
        created_by_id=created_by_id,
        folder_parts=folder_parts,
    )
    unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    # Guard against the unique_identifier_hash constraint: another row at the
    # same virtual_path (this search space) already owns the hash. Callers are
    # expected to upsert via the wrapper, but this defends against bypasses
    # and gives a clean ValueError instead of a session-poisoning IntegrityError.
    path_collision = await session.execute(
        select(Document.id).where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_identifier_hash,
        )
    )
    if path_collision.scalar_one_or_none() is not None:
        raise ValueError(
            f"a document already exists at path '{virtual_path}' "
            "(unique_identifier_hash collision)"
        )
    # Same defensive pattern for identical content within the search space.
    content_hash = generate_content_hash(content, search_space_id)
    content_collision = await session.execute(
        select(Document.id).where(
            Document.search_space_id == search_space_id,
            Document.content_hash == content_hash,
        )
    )
    if content_collision.scalar_one_or_none() is not None:
        raise ValueError(
            f"a document with identical content already exists for path '{virtual_path}'"
        )
    doc = Document(
        title=title,
        document_type=DocumentType.NOTE,
        document_metadata={"virtual_path": virtual_path},
        content=content,
        content_hash=content_hash,
        unique_identifier_hash=unique_identifier_hash,
        source_markdown=content,
        search_space_id=search_space_id,
        folder_id=folder_id,
        created_by_id=created_by_id,
        updated_at=datetime.now(UTC),
    )
    session.add(doc)
    # Flush so ``doc.id`` is available for the Chunk rows created below.
    await session.flush()
    # NOTE(review): ``embed_texts`` is called without ``await``, so it looks
    # synchronous and may block the event loop while embedding — confirm.
    summary_embedding = embed_texts([content])[0]
    doc.embedding = summary_embedding
    chunks = chunk_text(content)
    if chunks:
        chunk_embeddings = embed_texts(chunks)
        session.add_all(
            [
                Chunk(document_id=doc.id, content=text, embedding=embedding)
                for text, embedding in zip(chunks, chunk_embeddings, strict=True)
            ]
        )
    return doc
async def _update_document(
    session: AsyncSession,
    *,
    doc_id: int,
    content: str,
    virtual_path: str,
    search_space_id: int,
) -> Document | None:
    """Overwrite an existing document's body and rebuild its chunks.

    The document is looked up by ``doc_id`` within ``search_space_id``. Its
    content, markdown source, content hash, ``virtual_path`` metadata,
    identifier hash and embedding are replaced, all existing chunks are
    deleted, and new chunks are created from ``content``.

    Returns:
        The updated ``Document``, or ``None`` when no matching row exists.
    """
    lookup = select(Document).where(
        Document.id == doc_id,
        Document.search_space_id == search_space_id,
    )
    target = (await session.execute(lookup)).scalar_one_or_none()
    if target is None:
        return None

    target.content = content
    target.source_markdown = content
    target.content_hash = generate_content_hash(content, search_space_id)
    target.updated_at = datetime.now(UTC)

    # Assign a fresh dict rather than mutating the stored one in place.
    target.document_metadata = dict(
        target.document_metadata or {}, virtual_path=virtual_path
    )
    target.unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    target.embedding = embed_texts([content])[0]

    # Drop the old chunks before inserting the regenerated ones.
    await session.execute(delete(Chunk).where(Chunk.document_id == target.id))
    pieces = chunk_text(content)
    if pieces:
        vectors = embed_texts(pieces)
        new_chunks = [
            Chunk(document_id=target.id, content=text, embedding=embedding)
            for text, embedding in zip(pieces, vectors, strict=True)
        ]
        session.add_all(new_chunks)
    return target
# ---------------------------------------------------------------------------
# Move helpers
# ---------------------------------------------------------------------------
async def _apply_move(
    session: AsyncSession,
    *,
    search_space_id: int,
    created_by_id: str | None,
    move: dict[str, Any],
    doc_id_by_path: dict[str, int],
    doc_id_path_tombstones: dict[str, int | None],
) -> dict[str, Any] | None:
    """Apply a single staged move; updates the in-memory mapping for chain resolution.

    Args:
        session: Open async session used for lookups and folder creation.
        search_space_id: Search space the document must belong to.
        created_by_id: Creator recorded on any folders created for ``dest``.
        move: Staged move; only its ``source`` and ``dest`` keys are read.
        doc_id_by_path: Working path -> document id map. Mutated in place:
            the ``source`` key is removed and ``dest`` is set.
        doc_id_path_tombstones: Accumulated state delta. Mutated in place:
            ``source`` is set to ``None`` and ``dest`` to the document id.

    Returns:
        ``{"id", "source", "dest", "title"}`` for an applied move, or ``None``
        when the move is skipped (empty or identical paths, a path outside
        the documents root, source document not found, or a destination
        that yields no title).
    """
    source = str(move.get("source") or "")
    dest = str(move.get("dest") or "")
    if not source or not dest or source == dest:
        return None
    # Both endpoints must live strictly below the documents root.
    if not source.startswith(DOCUMENTS_ROOT + "/") or not dest.startswith(
        DOCUMENTS_ROOT + "/"
    ):
        return None
    # Prefer the in-memory mapping, which already reflects earlier moves
    # applied through this function.
    doc_id: int | None = doc_id_by_path.get(source)
    document: Document | None = None
    if doc_id is not None:
        result = await session.execute(
            select(Document).where(
                Document.id == doc_id,
                Document.search_space_id == search_space_id,
            )
        )
        document = result.scalar_one_or_none()
    # Fall back to resolving the source path when the mapping had no entry
    # or the mapped row was not found.
    if document is None:
        document = await virtual_path_to_doc(
            session,
            search_space_id=search_space_id,
            virtual_path=source,
        )
    if document is None:
        logger.info(
            "kb_persistence: skipping move %s -> %s (source not found)",
            source,
            dest,
        )
        return None
    folder_parts, new_title = parse_documents_path(dest)
    if not new_title:
        return None
    folder_id = await _ensure_folder_hierarchy(
        session,
        search_space_id=search_space_id,
        created_by_id=created_by_id,
        folder_parts=folder_parts,
    )
    document.title = new_title
    document.folder_id = folder_id
    metadata = dict(document.document_metadata or {})
    metadata["virtual_path"] = dest
    document.document_metadata = metadata
    # NOTE(review): the hash is always derived with DocumentType.NOTE,
    # whatever the document's own type is — confirm this is intended for
    # non-NOTE documents.
    document.unique_identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        dest,
        search_space_id,
    )
    document.updated_at = datetime.now(UTC)
    # Update the working map so a later move whose source is this ``dest``
    # resolves to the same document.
    doc_id_by_path.pop(source, None)
    doc_id_by_path[dest] = document.id
    doc_id_path_tombstones[source] = None
    doc_id_path_tombstones[dest] = document.id
    return {"id": document.id, "source": source, "dest": dest, "title": new_title}
# ---------------------------------------------------------------------------
# Commit body
# ---------------------------------------------------------------------------
async def commit_staged_filesystem_state(
    state: dict[str, Any] | AgentState,
    *,
    search_space_id: int,
    created_by_id: str | None,
    filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
    dispatch_events: bool = True,
) -> dict[str, Any] | None:
    """Commit all staged filesystem changes; return the state delta for reducers.

    Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent`
    and the optional stream-task fallback.

    Work is applied inside a single DB session in this order: staged
    directories, pending moves, then dirty files (created or updated), and
    finally one ``session.commit()``.

    Args:
        state: Agent state, either a plain dict or an object exposing the
            state under a ``values`` attribute.
        search_space_id: Search space the documents/folders belong to.
        created_by_id: Owner recorded on newly created folders/documents.
        filesystem_mode: Anything other than ``FilesystemMode.CLOUD`` is a no-op.
        dispatch_events: When true, emit ``document_created`` /
            ``document_updated`` custom events after a successful commit.

    Returns:
        A state delta that clears the staging lists (and drops temp files,
        refreshes ``doc_id_by_path`` and bumps ``tree_version`` when
        relevant), or ``None`` when there is nothing to commit, the mode is
        not cloud, or the DB work failed (staged state is then left
        untouched).
    """
    if filesystem_mode != FilesystemMode.CLOUD:
        return None

    state_dict: dict[str, Any] = (
        dict(state)
        if isinstance(state, dict)
        else dict(getattr(state, "values", {}) or {})
    )
    files: dict[str, Any] = state_dict.get("files") or {}
    staged_dirs: list[str] = list(state_dict.get("staged_dirs") or [])
    # Non-dict entries cannot be applied and would raise inside
    # ``_apply_move``, aborting the whole commit; drop them up front.
    pending_moves: list[dict[str, Any]] = [
        m for m in (state_dict.get("pending_moves") or []) if isinstance(m, dict)
    ]
    dirty_paths: list[str] = list(state_dict.get("dirty_paths") or [])
    doc_id_by_path: dict[str, int] = dict(state_dict.get("doc_id_by_path") or {})

    # Scratch files are never persisted; they are removed from state instead.
    temp_paths = [
        p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
    ]

    if state_dict.get("kb_anon_doc"):
        # Anonymous sessions never write to the DB: just reset the staging state.
        return {
            "dirty_paths": [_CLEAR],
            "staged_dirs": [_CLEAR],
            "pending_moves": [_CLEAR],
            "files": dict.fromkeys(temp_paths),
        }

    if not (staged_dirs or pending_moves or dirty_paths):
        return None

    documents_prefix = DOCUMENTS_ROOT + "/"

    # source -> dest for every well-formed staged move (applied or not), so a
    # dirty path recorded before a move is committed at its final location.
    # A move lacking ``dest`` is skipped instead of raising ``KeyError``.
    move_alias: dict[str, str] = {}
    for move in pending_moves:
        src, dst = move.get("source"), move.get("dest")
        if isinstance(src, str) and isinstance(dst, str) and src and dst:
            move_alias[src] = dst

    def _final_path(path: str) -> str:
        """Follow chained moves to the last destination (cycle-safe)."""
        seen: set[str] = set()
        while path in move_alias and path not in seen:
            seen.add(path)
            path = move_alias[path]
        return path

    def _event_payload(doc: Any, path: str) -> dict[str, Any]:
        """Build the event/summary payload for a created or updated note."""
        return {
            "id": doc.id,
            "title": doc.title,
            "documentType": DocumentType.NOTE.value,
            "searchSpaceId": search_space_id,
            "folderId": doc.folder_id,
            "createdById": str(created_by_id) if created_by_id else None,
            "virtualPath": path,
        }

    committed_creates: list[dict[str, Any]] = []
    committed_updates: list[dict[str, Any]] = []
    discarded: list[str] = []
    applied_moves: list[dict[str, Any]] = []
    doc_id_path_tombstones: dict[str, int | None] = {}
    tree_changed = False

    try:
        async with shielded_async_session() as session:
            # 1) Staged directories.
            for folder_path in staged_dirs:
                if not isinstance(folder_path, str):
                    continue
                # Require a real child of the documents root; a bare prefix
                # test would also accept siblings such as "/documents_x".
                if not folder_path.startswith(documents_prefix):
                    continue
                rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/")
                folder_parts_full = [p for p in rel.split("/") if p]
                if not folder_parts_full:
                    continue
                await _ensure_folder_hierarchy(
                    session,
                    search_space_id=search_space_id,
                    created_by_id=created_by_id,
                    folder_parts=folder_parts_full,
                )
                tree_changed = True

            # 2) Pending moves, in staging order so chains resolve correctly.
            for move in pending_moves:
                applied = await _apply_move(
                    session,
                    search_space_id=search_space_id,
                    created_by_id=created_by_id,
                    move=move,
                    doc_id_by_path=doc_id_by_path,
                    doc_id_path_tombstones=doc_id_path_tombstones,
                )
                if applied:
                    applied_moves.append(applied)
                    tree_changed = True

            # 3) Dirty files, de-duplicated by their post-move path.
            kb_dirty_seen: set[str] = set()
            kb_dirty: list[str] = []
            for raw in dirty_paths:
                if not isinstance(raw, str):
                    continue
                final = _final_path(raw)
                if not final.startswith(documents_prefix):
                    continue
                if final in kb_dirty_seen:
                    continue
                kb_dirty_seen.add(final)
                kb_dirty.append(final)

            for path in kb_dirty:
                if _basename(path).startswith(_TEMP_PREFIX):
                    discarded.append(path)
                    continue
                file_data = files.get(path)
                if not isinstance(file_data, dict):
                    continue
                content = "\n".join(file_data.get("content") or [])
                doc_id = doc_id_by_path.get(path)
                if doc_id is None:
                    # The in-memory ``doc_id_by_path`` is per-thread and starts
                    # empty in every new chat. If the agent writes to a path
                    # that already exists in the DB (e.g. a previous chat's
                    # ``notes.md``), we must NOT try to INSERT — it would hit
                    # ``unique_identifier_hash`` (path-derived). Look up the
                    # existing doc and update it in place instead.
                    existing = await virtual_path_to_doc(
                        session,
                        search_space_id=search_space_id,
                        virtual_path=path,
                    )
                    if existing is not None:
                        doc_id = existing.id
                        doc_id_by_path[path] = existing.id

                if doc_id is not None:
                    updated = await _update_document(
                        session,
                        doc_id=doc_id,
                        content=content,
                        virtual_path=path,
                        search_space_id=search_space_id,
                    )
                    if updated is not None:
                        committed_updates.append(_event_payload(updated, path))
                else:
                    try:
                        new_doc = await _create_document(
                            session,
                            virtual_path=path,
                            content=content,
                            search_space_id=search_space_id,
                            created_by_id=created_by_id,
                        )
                    except ValueError as exc:
                        logger.warning(
                            "kb_persistence: skipping %s create: %s", path, exc
                        )
                        continue
                    doc_id_by_path[path] = new_doc.id
                    committed_creates.append(_event_payload(new_doc, path))
                    tree_changed = True

            await session.commit()
    except Exception:  # pragma: no cover - rollback safety net
        logger.exception(
            "kb_persistence: commit failed (search_space=%s)", search_space_id
        )
        return None

    if dispatch_events:
        # Event delivery is best-effort: a failing listener must not undo a
        # commit that has already succeeded.
        for event_name, payloads in (
            ("document_created", committed_creates),
            ("document_updated", committed_updates),
        ):
            for payload in payloads:
                try:
                    dispatch_custom_event(event_name, payload)
                except Exception:
                    logger.exception(
                        "kb_persistence: failed to dispatch %s event", event_name
                    )

    doc_id_update: dict[str, int | None] = {**doc_id_path_tombstones}
    for payload in committed_creates:
        doc_id_update[str(payload.get("virtualPath") or "")] = int(payload["id"])

    delta: dict[str, Any] = {
        "dirty_paths": [_CLEAR],
        "staged_dirs": [_CLEAR],
        "pending_moves": [_CLEAR],
    }
    if temp_paths:
        delta["files"] = dict.fromkeys(temp_paths)
    if doc_id_update:
        delta["doc_id_by_path"] = doc_id_update
    if tree_changed:
        delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1

    logger.info(
        "kb_persistence: commit (search_space=%s) creates=%d updates=%d "
        "moves=%d staged_dirs=%d discarded=%d",
        search_space_id,
        len(committed_creates),
        len(committed_updates),
        len(applied_moves),
        len(staged_dirs),
        len(discarded),
    )
    return delta
# ---------------------------------------------------------------------------
# Middleware
# ---------------------------------------------------------------------------
class KnowledgeBasePersistenceMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Persist the agent's staged filesystem changes once the turn is over.

    The middleware contributes no tools; it only hooks ``aafter_agent`` and
    delegates to :func:`commit_staged_filesystem_state` when running in
    cloud mode.
    """

    tools = ()
    state_schema = SurfSenseFilesystemState

    def __init__(
        self,
        *,
        search_space_id: int,
        created_by_id: str | None,
        filesystem_mode: FilesystemMode,
    ) -> None:
        """Remember the search space, owner, and filesystem mode for later commits."""
        self.filesystem_mode = filesystem_mode
        self.created_by_id = created_by_id
        self.search_space_id = search_space_id

    async def aafter_agent(  # type: ignore[override]
        self,
        state: AgentState,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Commit staged changes in cloud mode; return ``None`` in any other mode."""
        del runtime  # unused
        if self.filesystem_mode == FilesystemMode.CLOUD:
            return await commit_staged_filesystem_state(
                state,
                search_space_id=self.search_space_id,
                created_by_id=self.created_by_id,
                filesystem_mode=self.filesystem_mode,
            )
        return None
# Public API of this module: the end-of-turn middleware and the commit helper
# it shares with other callers.
__all__ = [
"KnowledgeBasePersistenceMiddleware",
"commit_staged_filesystem_state",
]

View file

@ -0,0 +1,963 @@
"""Postgres-backed virtual filesystem for the SurfSense agent (cloud mode).
The backend is **strictly conforming** to deepagents'
:class:`BackendProtocol`. It returns ``WriteResult`` / ``EditResult`` / list
shapes exactly as upstream expects (no extra fields). All side-state
plumbing — ``dirty_paths``, ``doc_id_by_path``, ``staged_dirs``,
``pending_moves``, ``files`` cache — is appended by the overridden tool
wrappers in :class:`SurfSenseFilesystemMiddleware` via ``Command.update``.
The backend never writes to Postgres. End-of-turn persistence is handled by
:class:`KnowledgeBasePersistenceMiddleware`. This module is purely a
read-side and a state-merging helper.
"""
from __future__ import annotations
import asyncio
import contextlib
import fnmatch
import logging
import re
from datetime import UTC
from typing import Any
from deepagents.backends.protocol import (
BackendProtocol,
EditResult,
FileDownloadResponse,
FileInfo,
FileUploadResponse,
GrepMatch,
WriteResult,
)
from deepagents.backends.utils import (
create_file_data,
file_data_to_string,
format_read_response,
perform_string_replacement,
update_file_data,
)
from langchain.tools import ToolRuntime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.document_xml import build_document_xml
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
build_path_index,
doc_to_virtual_path,
virtual_path_to_doc,
)
from app.db import Chunk, Document, shielded_async_session
# Module-level logger; DB errors in the read paths below are reported through it.
logger = logging.getLogger(__name__)
# Basename prefix of scratch files: such state files are listed even when they
# live outside the documents root.
_TEMP_PREFIX = "temp_"
# Hard cap on the number of matches returned by a single grep call.
_GREP_MAX_TOTAL_MATCHES = 50
# Cap on the number of DB chunk hits reported per document by a grep call.
_GREP_MAX_PER_DOC = 5
def _basename(path: str) -> str:
return path.rsplit("/", 1)[-1]
def _is_under(child: str, parent: str) -> bool:
"""Return True iff ``child`` is at-or-under ``parent`` (directory semantics)."""
if parent == "/":
return child.startswith("/")
return child == parent or child.startswith(parent.rstrip("/") + "/")
def paginate_listing(
    infos: list[FileInfo],
    *,
    offset: int = 0,
    limit: int | None = None,
) -> list[FileInfo]:
    """Return one page of a listing produced by :meth:`KBPostgresBackend.als_info`.

    A negative ``offset`` is treated as ``0``. A ``limit`` of ``None`` or any
    negative value means "everything from ``offset`` onwards". The result is
    always a new list.
    """
    start = max(offset, 0)
    if limit is None or limit < 0:
        return list(infos[start:])
    return list(infos[start : start + limit])
class KBPostgresBackend(BackendProtocol):
"""Lazy, read-only Postgres view for ``/documents/*`` virtual paths.
The backend exposes a virtual ``/documents/`` namespace mirroring the
``Folder``/``Document`` graph. Reads materialize XML on first access and
cache it via the overriding tool wrappers (NOT here). Writes never touch
the DB — they return ``files_update`` deltas that the wrappers turn into
Command updates, and the persistence middleware commits them at end of
turn.
"""
_IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp"})
def __init__(self, search_space_id: int, runtime: ToolRuntime) -> None:
self.search_space_id = search_space_id
self.runtime = runtime
@property
def state(self) -> dict[str, Any]:
    """Current agent state taken from the runtime; an empty dict when absent or falsy."""
    current = getattr(self.runtime, "state", {})
    return current if current else {}
# ------------------------------------------------------------------ helpers
def _state_files(self) -> dict[str, Any]:
return dict(self.state.get("files") or {})
def _staged_dirs(self) -> list[str]:
return list(self.state.get("staged_dirs") or [])
def _pending_moves(self) -> list[dict[str, Any]]:
return list(self.state.get("pending_moves") or [])
def _kb_anon_doc(self) -> dict[str, Any] | None:
anon = self.state.get("kb_anon_doc")
return anon if isinstance(anon, dict) else None
def _matched_chunk_ids(self, doc_id: int) -> set[int]:
mapping = self.state.get("kb_matched_chunk_ids") or {}
try:
return set(mapping.get(doc_id, []) or [])
except TypeError:
return set()
@staticmethod
def _file_data_size(file_data: dict[str, Any]) -> int:
try:
return len("\n".join(file_data.get("content") or []))
except Exception:
return 0
def _normalize_listing_path(self, path: str) -> str:
if not path:
return DOCUMENTS_ROOT
if path == "/":
return path
return path.rstrip("/") if path != "/" else path
def _moved_view_paths(
self,
existing: dict[str, dict[str, Any]],
) -> tuple[set[str], dict[str, str]]:
"""Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
Removed paths should disappear from listings; ``alias[source] = dest``
means a virtual entry should appear at ``dest`` even if no DB row is
yet there.
"""
removed: set[str] = set()
alias: dict[str, str] = {}
for move in self._pending_moves():
src = move.get("source")
dst = move.get("dest")
if not src or not dst:
continue
removed.add(src)
alias[src] = dst
existing.pop(src, None)
return removed, alias
# ------------------------------------------------------------------ ls/read
async def als_info(self, path: str) -> list[FileInfo]: # type: ignore[override]
"""List the immediate children of ``path`` in the virtual filesystem.

The listing merges, in this order and de-duplicated by path: the
anonymous document from ``state["kb_anon_doc"]``, folders/documents read
from the database, destinations of ``pending_moves``, ``staged_dirs`` and
files cached in ``state["files"]``. Directories are sorted before files,
then entries are ordered by path.
"""
normalized = self._normalize_listing_path(path)
infos: list[FileInfo] = []
# Paths already emitted; used to de-duplicate across the sources below.
seen: set[str] = set()
anon = self._kb_anon_doc()
if anon:
anon_path = str(anon.get("path") or "")
if (
anon_path
and _is_under(anon_path, normalized)
and anon_path != normalized
and anon_path not in seen
):
infos.append(
FileInfo(
path=anon_path,
is_dir=False,
size=len(str(anon.get("content") or "")),
modified_at="",
)
)
seen.add(anon_path)
files = self._state_files()
# NOTE: this call also pops every move source out of ``files``.
moved_removed, moved_alias = self._moved_view_paths(files)
# The database is only consulted for the root and for paths starting with the documents root.
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
try:
async with shielded_async_session() as session:
db_infos, subdir_paths = await self._list_db_directory(
session, normalized
)
except Exception as exc: # pragma: no cover - defensive
# A DB failure is logged and degrades to an empty DB listing.
logger.warning("KBPostgresBackend.als_info DB error: %s", exc)
db_infos, subdir_paths = [], set()
# DB rows whose path is the source of a pending move are hidden.
for info in db_infos:
p = info.get("path", "")
if not p or p in seen or p in moved_removed:
continue
infos.append(info)
seen.add(p)
# Destinations of pending moves show up as virtual files, or as a
# subdirectory when the destination is nested deeper than ``normalized``.
for src, dst in moved_alias.items():
if src not in seen:
if not _is_under(dst, normalized):
continue
rel = (
dst[len(normalized) :].lstrip("/")
if normalized != "/"
else dst.lstrip("/")
)
if "/" in rel:
subdir_paths.add(
(normalized.rstrip("/") + "/" + rel.split("/", 1)[0])
if normalized != "/"
else "/" + rel.split("/", 1)[0]
)
continue
if dst in seen:
continue
fd = files.get(dst)
size = self._file_data_size(fd) if isinstance(fd, dict) else 0
infos.append(
FileInfo(
path=dst,
is_dir=False,
size=int(size),
modified_at=fd.get("modified_at", "")
if isinstance(fd, dict)
else "",
)
)
seen.add(dst)
# Staged directories contribute their first path segment below ``normalized``.
for staged in self._staged_dirs():
if not staged or not staged.startswith(DOCUMENTS_ROOT):
continue
if staged == normalized:
continue
if not _is_under(staged, normalized):
continue
rel = (
staged[len(normalized) :].lstrip("/")
if normalized != "/"
else staged.lstrip("/")
)
if not rel:
continue
first = rel.split("/", 1)[0]
immediate = (
normalized.rstrip("/") + "/" + first
if normalized != "/"
else "/" + first
)
subdir_paths.add(immediate)
for sub in sorted(subdir_paths):
if sub in seen:
continue
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
seen.add(sub)
# Files cached in state: nested ones surface as a directory entry for
# their first segment, direct children are listed as files.
for path_key, fd in files.items():
if not isinstance(path_key, str) or path_key in seen:
continue
if not _is_under(path_key, normalized) or path_key == normalized:
continue
if normalized == "/":
rel = path_key.lstrip("/")
else:
rel = path_key[len(normalized) :].lstrip("/")
if not rel:
continue
if "/" in rel:
first = rel.split("/", 1)[0]
immediate = (
normalized.rstrip("/") + "/" + first
if normalized != "/"
else "/" + first
)
if immediate not in seen:
infos.append(
FileInfo(path=immediate, is_dir=True, size=0, modified_at="")
)
seen.add(immediate)
continue
# Only documents-root paths or temp-prefixed basenames are listed as files.
include = path_key.startswith(DOCUMENTS_ROOT) or _basename(
path_key
).startswith(_TEMP_PREFIX)
if not include:
continue
size = self._file_data_size(fd) if isinstance(fd, dict) else 0
infos.append(
FileInfo(
path=path_key,
is_dir=False,
size=int(size),
modified_at=fd.get("modified_at", "")
if isinstance(fd, dict)
else "",
)
)
seen.add(path_key)
# Directories first (False sorts before True), then alphabetical by path.
infos.sort(key=lambda fi: (not fi.get("is_dir", False), fi.get("path", "")))
return infos
def ls_info(self, path: str) -> list[FileInfo]:  # type: ignore[override]
    """Blocking counterpart of :meth:`als_info`, driven by ``asyncio.run``."""
    pending = self.als_info(path)
    return asyncio.run(pending)
async def _list_db_directory(
    self,
    session: AsyncSession,
    normalized_path: str,
) -> tuple[list[FileInfo], set[str]]:
    """Read the direct children of ``normalized_path`` from the database.

    Returns ``(file_infos, subdirectory_paths)``. For ``/`` nothing is
    queried and the documents root is reported as the only subdirectory.
    Paths outside the documents root, and paths that match no known folder,
    give ``([], set())``. Document entries always carry ``size=0``.
    """
    if normalized_path == "/":
        return [], {DOCUMENTS_ROOT}
    if not normalized_path.startswith(DOCUMENTS_ROOT):
        return [], set()

    index = await build_path_index(session, self.search_space_id)

    # ``None`` stands for the documents root itself (documents without a folder).
    folder_id: int | None = None
    if normalized_path != DOCUMENTS_ROOT:
        folder_id = next(
            (
                fid
                for fid, fpath in index.folder_paths.items()
                if fpath == normalized_path
            ),
            None,
        )
        if folder_id is None:
            return [], set()

    if folder_id is None:
        folder_clause = Document.folder_id.is_(None)
    else:
        folder_clause = Document.folder_id == folder_id

    query = (
        select(Document.id, Document.title, Document.folder_id, Document.updated_at)
        .where(Document.search_space_id == self.search_space_id)
        .where(folder_clause)
    )
    rows = (await session.execute(query)).all()

    file_infos: list[FileInfo] = []
    for row in rows:
        modified = ""
        if row.updated_at is not None:
            # Timestamp formatting is best-effort; failures leave it blank.
            with contextlib.suppress(Exception):
                modified = row.updated_at.astimezone(UTC).isoformat()
        file_infos.append(
            FileInfo(
                path=doc_to_virtual_path(
                    doc_id=row.id,
                    title=str(row.title or "untitled"),
                    folder_id=row.folder_id,
                    index=index,
                ),
                is_dir=False,
                size=0,
                modified_at=modified,
            )
        )

    # Direct child folders: exactly one more segment than ``normalized_path``.
    child_prefix = normalized_path.rstrip("/") + "/"
    subdirs: set[str] = {
        fpath
        for fpath in index.folder_paths.values()
        if fpath != normalized_path
        and fpath.startswith(child_prefix)
        and "/" not in fpath[len(child_prefix) :]
    }
    return file_infos, subdirs
async def aread(  # type: ignore[override]
    self,
    file_path: str,
    offset: int = 0,
    limit: int = 2000,
) -> str:
    """Read ``file_path``, preferring the state cache over a lazy load.

    Returns the formatted read response, or an ``Error: File ... not found``
    string when the path is neither cached in state nor loadable.
    """
    cached = self._state_files().get(file_path)
    if cached is None:
        loaded = await self._load_file_data(file_path)
        if loaded is None:
            return f"Error: File '{file_path}' not found"
        cached, _doc_id = loaded
    return format_read_response(cached, offset, limit)
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str:  # type: ignore[override]
    """Blocking counterpart of :meth:`aread`, driven by ``asyncio.run``."""
    pending = self.aread(file_path, offset, limit)
    return asyncio.run(pending)
async def _load_file_data(
    self,
    path: str,
) -> tuple[dict[str, Any], int | None] | None:
    """Build the ``FileData`` for a virtual KB path on demand.

    Returns ``(file_data, doc_id)``, or ``None`` when ``path`` maps to no
    known document. For the synthetic anonymous document ``doc_id`` is
    ``None`` so callers do not track it as a DB-backed file.
    """
    anon = self._kb_anon_doc()
    if anon and str(anon.get("path") or "") == path:
        anon_payload = {
            "document_id": -1,
            "chunks": list(anon.get("chunks") or []),
            "matched_chunk_ids": [],
            "document": {
                "id": -1,
                "title": anon.get("title") or "uploaded_document",
                "document_type": "FILE",
                "metadata": {"source": "anonymous_upload"},
            },
            "source": "FILE",
        }
        anon_xml = build_document_xml(anon_payload, matched_chunk_ids=set())
        return create_file_data(anon_xml), None

    if not path.startswith(DOCUMENTS_ROOT):
        return None

    async with shielded_async_session() as session:
        document = await virtual_path_to_doc(
            session,
            search_space_id=self.search_space_id,
            virtual_path=path,
        )
        if document is None:
            return None

        chunk_query = (
            select(Chunk.id, Chunk.content)
            .where(Chunk.document_id == document.id)
            .order_by(Chunk.id)
        )
        chunk_result = await session.execute(chunk_query)
        chunks = [
            {"chunk_id": row.id, "content": row.content}
            for row in chunk_result.all()
        ]

        # The type label is used both as ``document_type`` and as ``source``.
        doc_type = getattr(document, "document_type", None)
        type_label = doc_type.value if doc_type is not None else "UNKNOWN"
        matched = self._matched_chunk_ids(document.id)

        payload = {
            "document_id": document.id,
            "chunks": chunks,
            "matched_chunk_ids": list(matched),
            "document": {
                "id": document.id,
                "title": document.title,
                "document_type": type_label,
                "metadata": dict(document.document_metadata or {}),
            },
            "source": type_label,
        }
        xml = build_document_xml(payload, matched_chunk_ids=matched)
        return create_file_data(xml), document.id
# ------------------------------------------------------------------ writes
async def awrite(self, file_path: str, content: str) -> WriteResult:  # type: ignore[override]
    """Stage a new file in state; the database is not touched.

    Writing to a path already present in ``state["files"]`` is refused with
    an error result. Otherwise the result carries a ``files_update`` delta
    for the new file.
    """
    if file_path in self._state_files():
        message = (
            f"Cannot write to {file_path} because it already exists. "
            "Read and then make an edit, or write to a new path."
        )
        return WriteResult(error=message)
    return WriteResult(
        path=file_path,
        files_update={file_path: create_file_data(content)},
    )
def write(self, file_path: str, content: str) -> WriteResult:  # type: ignore[override]
    """Blocking counterpart of :meth:`awrite`, driven by ``asyncio.run``."""
    pending = self.awrite(file_path, content)
    return asyncio.run(pending)
async def aedit(  # type: ignore[override]
    self,
    file_path: str,
    old_string: str,
    new_string: str,
    replace_all: bool = False,
) -> EditResult:
    """Apply a string replacement to a file and return the resulting delta.

    The file comes from ``state["files"]`` when cached there, otherwise it
    is lazy-loaded. Nothing is written to the database: the edited content
    is returned as a ``files_update`` delta. An unknown path or a failed
    replacement yields an error result.
    """
    current = self._state_files().get(file_path)
    if current is None:
        loaded = await self._load_file_data(file_path)
        if loaded is None:
            return EditResult(error=f"Error: File '{file_path}' not found")
        current, _doc_id = loaded

    outcome = perform_string_replacement(
        file_data_to_string(current), old_string, new_string, replace_all
    )
    # A plain string outcome is the helper's error message.
    if isinstance(outcome, str):
        return EditResult(error=outcome)

    replaced_text, occurrences = outcome
    return EditResult(
        path=file_path,
        files_update={file_path: update_file_data(current, replaced_text)},
        occurrences=int(occurrences),
    )
def edit(  # type: ignore[override]
    self,
    file_path: str,
    old_string: str,
    new_string: str,
    replace_all: bool = False,
) -> EditResult:
    """Blocking counterpart of :meth:`aedit`, driven by ``asyncio.run``."""
    pending = self.aedit(file_path, old_string, new_string, replace_all)
    return asyncio.run(pending)
# ------------------------------------------------------------------ glob/grep
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:  # type: ignore[override]
    """Return files under ``path`` whose name matches the glob ``pattern``.

    A candidate matches when the pattern fits either its path relative to
    ``path`` or its full path. Files cached in state are checked first, then
    documents in the database (only for ``/`` and paths starting with the
    documents root). Sources of pending moves are excluded, DB errors are
    logged and ignored, and the result is sorted by path.
    """
    root = self._normalize_listing_path(path)
    state_files = self._state_files()
    hidden, _alias = self._moved_view_paths(state_files)
    matcher = re.compile(fnmatch.translate(pattern))

    def _relative(full: str) -> str:
        """Strip the listing root (and leading slashes) from ``full``."""
        if root == "/":
            return full.lstrip("/")
        return full[len(root) :].lstrip("/")

    def _pattern_hits(full: str) -> bool:
        """Check the pattern against the relative and the absolute path."""
        return bool(matcher.match(_relative(full)) or matcher.match(full))

    found: list[FileInfo] = []
    emitted: set[str] = set()

    for cached_path, data in state_files.items():
        if cached_path in hidden:
            continue
        if not _is_under(cached_path, root):
            continue
        if not _pattern_hits(cached_path):
            continue
        if cached_path in emitted:
            continue
        is_file_data = isinstance(data, dict)
        found.append(
            FileInfo(
                path=cached_path,
                is_dir=False,
                size=int(self._file_data_size(data) if is_file_data else 0),
                modified_at=data.get("modified_at", "") if is_file_data else "",
            )
        )
        emitted.add(cached_path)

    if normalized_for_db := (root.startswith(DOCUMENTS_ROOT) or root == "/"):
        try:
            async with shielded_async_session() as session:
                index = await build_path_index(session, self.search_space_id)
                doc_result = await session.execute(
                    select(Document.id, Document.title, Document.folder_id).where(
                        Document.search_space_id == self.search_space_id
                    )
                )
                for row in doc_result.all():
                    db_path = doc_to_virtual_path(
                        doc_id=row.id,
                        title=str(row.title or "untitled"),
                        folder_id=row.folder_id,
                        index=index,
                    )
                    if db_path in emitted or db_path in hidden:
                        continue
                    if not _is_under(db_path, root):
                        continue
                    if not _pattern_hits(db_path):
                        continue
                    found.append(
                        FileInfo(path=db_path, is_dir=False, size=0, modified_at="")
                    )
                    emitted.add(db_path)
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("KBPostgresBackend.aglob_info DB error: %s", exc)

    found.sort(key=lambda fi: fi.get("path", ""))
    return found
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]:  # type: ignore[override]
    """Blocking counterpart of :meth:`aglob_info`, driven by ``asyncio.run``."""
    pending = self.aglob_info(pattern, path)
    return asyncio.run(pending)
async def agrep_raw(  # type: ignore[override]
    self,
    pattern: str,
    path: str | None = None,
    glob: str | None = None,
) -> list[GrepMatch] | str:
    """Search for the literal text ``pattern`` under ``path``.

    Files cached in state are scanned line by line (case-sensitive substring
    test). Afterwards, for ``/`` and paths starting with the documents root,
    chunk contents in the database are searched with a case-insensitive
    ``ILIKE``; those hits are reported with ``line=0`` and a snippet of at
    most 240 characters.

    Args:
        pattern: Text to look for; must be non-empty.
        path: Directory to search in; defaults to ``/``.
        glob: Optional glob applied to each candidate's basename.

    Returns:
        Up to ``_GREP_MAX_TOTAL_MATCHES`` matches (at most
        ``_GREP_MAX_PER_DOC`` DB hits per document), or an error string when
        ``pattern`` is empty. DB failures are logged and the matches
        collected so far are returned.
    """
    if not pattern:
        return "Error: pattern cannot be empty"

    normalized = self._normalize_listing_path(path or "/")
    matches: list[GrepMatch] = []
    files = self._state_files()
    moved_removed, _ = self._moved_view_paths(files)
    glob_re = re.compile(fnmatch.translate(glob)) if glob else None

    for path_key, fd in files.items():
        if path_key in moved_removed:
            continue
        if not _is_under(path_key, normalized):
            continue
        if glob_re is not None and not glob_re.match(_basename(path_key)):
            continue
        if not isinstance(fd, dict):
            continue
        for line_no, line in enumerate(fd.get("content") or [], 1):
            if pattern in line:
                matches.append(
                    GrepMatch(path=path_key, line=int(line_no), text=str(line))
                )
                if len(matches) >= _GREP_MAX_TOTAL_MATCHES:
                    return matches

    if not (normalized.startswith(DOCUMENTS_ROOT) or normalized == "/"):
        return matches

    # The pattern is literal text: escape LIKE metacharacters so ``%`` and
    # ``_`` in the search term are not treated as wildcards by the database.
    like_literal = (
        pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    )
    # Room left for DB hits; at least 1 because the loop above returns as
    # soon as the total cap is reached.
    budget = _GREP_MAX_TOTAL_MATCHES - len(matches)

    try:
        async with shielded_async_session() as session:
            index = await build_path_index(session, self.search_space_id)
            sub = (
                select(Chunk.document_id, Chunk.id, Chunk.content)
                .join(Document, Document.id == Chunk.document_id)
                .where(Document.search_space_id == self.search_space_id)
                .where(Chunk.content.ilike(f"%{like_literal}%", escape="\\"))
                .order_by(Chunk.document_id, Chunk.id)
            )
            chunk_rows = await session.execute(sub)

            per_doc: dict[int, int] = {}
            needed_doc_ids: set[int] = set()
            chunk_buffer: list[tuple[int, int, str]] = []
            for row in chunk_rows.all():
                taken = per_doc.get(row.document_id, 0)
                if taken >= _GREP_MAX_PER_DOC:
                    continue
                per_doc[row.document_id] = taken + 1
                chunk_buffer.append((row.document_id, row.id, row.content))
                needed_doc_ids.add(row.document_id)
                # Every buffered row consumed one unit of the budget.
                if len(chunk_buffer) >= budget:
                    break

            doc_id_to_path: dict[int, str] = {}
            if needed_doc_ids:
                doc_rows = await session.execute(
                    select(Document.id, Document.title, Document.folder_id).where(
                        Document.id.in_(list(needed_doc_ids))
                    )
                )
                for row in doc_rows.all():
                    doc_id_to_path[row.id] = doc_to_virtual_path(
                        doc_id=row.id,
                        title=str(row.title or "untitled"),
                        folder_id=row.folder_id,
                        index=index,
                    )

            for doc_id, chunk_id, content in chunk_buffer:
                candidate = doc_id_to_path.get(doc_id)
                if not candidate or candidate in moved_removed:
                    continue
                if not _is_under(candidate, normalized):
                    continue
                if glob_re is not None and not glob_re.match(_basename(candidate)):
                    continue
                # Collapse whitespace so the snippet stays on one line.
                snippet = " ".join(str(content).split())[:240]
                matches.append(
                    GrepMatch(
                        path=candidate,
                        line=0,
                        text=(
                            f"<chunk-match in {candidate} chunk_id={chunk_id}>: "
                            f"{snippet}"
                        ),
                    )
                )
                if len(matches) >= _GREP_MAX_TOTAL_MATCHES:
                    break
    except Exception as exc:  # pragma: no cover - defensive
        logger.warning("KBPostgresBackend.agrep_raw DB error: %s", exc)

    return matches
def grep_raw(  # type: ignore[override]
    self,
    pattern: str,
    path: str | None = None,
    glob: str | None = None,
) -> list[GrepMatch] | str:
    """Blocking counterpart of :meth:`agrep_raw`, driven by ``asyncio.run``."""
    pending = self.agrep_raw(pattern, path, glob)
    return asyncio.run(pending)
# ------------------------------------------------------------------ list_tree (helper)
async def alist_tree_listing(
self,
path: str = DOCUMENTS_ROOT,
*,
max_depth: int | None = 8,
page_size: int = 500,
include_files: bool = True,
include_dirs: bool = True,
) -> dict[str, Any]:
"""Recursive tree listing for cloud mode.
Mirrors the shape returned by :class:`MultiRootLocalFolderBackend.list_tree`:
``{"entries": [{path, is_dir, size, modified_at, depth}, ...], "truncated": bool}``.

``path`` must be ``/`` or start with the documents root; otherwise a
``{"error": ...}`` dict is returned. Entries deeper than ``max_depth``
(``None`` disables the check) are skipped, and the listing stops with
``truncated=True`` once ``page_size`` entries were collected and another
one would be added. A database failure is logged and produces an empty,
non-truncated listing.
"""
normalized = self._normalize_listing_path(path or DOCUMENTS_ROOT)
if not normalized.startswith(DOCUMENTS_ROOT) and normalized != "/":
return {"error": "Error: path must be under /documents/"}
entries: list[dict[str, Any]] = []
truncated = False
# Load the folder index and every document row of the search space up front.
try:
async with shielded_async_session() as session:
index = await build_path_index(session, self.search_space_id)
doc_rows_raw = await session.execute(
select(
Document.id,
Document.title,
Document.folder_id,
Document.updated_at,
).where(Document.search_space_id == self.search_space_id)
)
doc_rows = list(doc_rows_raw.all())
except Exception as exc: # pragma: no cover
logger.warning("KBPostgresBackend.alist_tree_listing DB error: %s", exc)
return {"entries": [], "truncated": False}
files = self._state_files()
# NOTE: this call also pops every move source out of ``files``.
moved_removed, _ = self._moved_view_paths(files)
anon = self._kb_anon_doc()
anon_path = str(anon.get("path") or "") if anon else ""
# Depth is the number of path segments below the documents root (0 for the
# root itself); when listing from "/" all segments of the path are counted.
def _depth_of(p: str) -> int:
if p == DOCUMENTS_ROOT:
return 0
rel_root = (
p[len(DOCUMENTS_ROOT) :].lstrip("/")
if normalized.startswith(DOCUMENTS_ROOT)
else p.lstrip("/")
)
return len([part for part in rel_root.split("/") if part])
# Append an entry unless the page is full; a full page flips ``truncated``.
def _add_entry(entry: dict[str, Any]) -> bool:
nonlocal truncated
if len(entries) >= page_size:
truncated = True
return False
entries.append(entry)
return True
if include_dirs:
# Folders known to the DB index, ordered by path.
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
if not _is_under(fpath, normalized):
continue
depth = _depth_of(fpath)
if max_depth is not None and depth > max_depth:
continue
if not _add_entry(
{
"path": fpath,
"is_dir": True,
"size": 0,
"modified_at": "",
"depth": depth,
}
):
return {"entries": entries, "truncated": True}
# Directories staged in state, skipping paths that are already listed.
for staged in self._staged_dirs():
if not _is_under(staged, normalized):
continue
depth = _depth_of(staged)
if max_depth is not None and depth > max_depth:
continue
if any(e["path"] == staged for e in entries):
continue
if not _add_entry(
{
"path": staged,
"is_dir": True,
"size": 0,
"modified_at": "",
"depth": depth,
}
):
return {"entries": entries, "truncated": True}
if include_files:
# DB documents ordered by title; sources of pending moves are hidden.
for row in sorted(doc_rows, key=lambda r: str(r.title or "")):
candidate = doc_to_virtual_path(
doc_id=row.id,
title=str(row.title or "untitled"),
folder_id=row.folder_id,
index=index,
)
if candidate in moved_removed:
continue
if not _is_under(candidate, normalized):
continue
depth = _depth_of(candidate)
if max_depth is not None and depth > max_depth:
continue
modified = ""
if row.updated_at is not None:
with contextlib.suppress(Exception):
modified = row.updated_at.astimezone(UTC).isoformat()
if not _add_entry(
{
"path": candidate,
"is_dir": False,
"size": 0,
"modified_at": modified,
"depth": depth,
}
):
return {"entries": entries, "truncated": True}
# The anonymous document from state, when it lies under the listing root.
if anon_path and _is_under(anon_path, normalized):
depth = _depth_of(anon_path)
if (max_depth is None or depth <= max_depth) and not _add_entry(
{
"path": anon_path,
"is_dir": False,
"size": len(str(anon.get("content") or "")),
"modified_at": "",
"depth": depth,
}
):
return {"entries": entries, "truncated": True}
# Files cached in state that are not listed yet; only documents-root paths
# or temp-prefixed basenames qualify.
for path_key, fd in files.items():
if not isinstance(path_key, str):
continue
if not _is_under(path_key, normalized):
continue
if any(e["path"] == path_key for e in entries):
continue
if not (
path_key.startswith(DOCUMENTS_ROOT)
or _basename(path_key).startswith(_TEMP_PREFIX)
):
continue
depth = _depth_of(path_key)
if max_depth is not None and depth > max_depth:
continue
size = self._file_data_size(fd) if isinstance(fd, dict) else 0
if not _add_entry(
{
"path": path_key,
"is_dir": False,
"size": int(size),
"modified_at": fd.get("modified_at", "")
if isinstance(fd, dict)
else "",
"depth": depth,
}
):
return {"entries": entries, "truncated": True}
return {"entries": entries, "truncated": truncated}
# ------------------------------------------------------------------ uploads (unsupported)
def upload_files(  # type: ignore[override]
    self, files: list[tuple[str, bytes]]
) -> list[FileUploadResponse]:
    """Uploads are not supported by this backend; always raises ``NotImplementedError``."""
    raise NotImplementedError("KBPostgresBackend does not support upload_files.")
def download_files(  # type: ignore[override]
    self, paths: list[str]
) -> list[FileDownloadResponse]:
    """Return the UTF-8 encoded content of each requested path.

    Only files present in ``state["files"]`` can be served; any other path
    yields a response with ``error="file_not_found"``. Responses are in the
    same order as ``paths``.
    """
    cached = self._state_files()
    responses: list[FileDownloadResponse] = []
    for requested in paths:
        data = cached.get(requested)
        if data is None:
            response = FileDownloadResponse(
                path=requested, content=None, error="file_not_found"
            )
        else:
            response = FileDownloadResponse(
                path=requested,
                content=file_data_to_string(data).encode("utf-8"),
                error=None,
            )
        responses.append(response)
    return responses
# --- module-level small helpers ---------------------------------------------
async def list_tree_listing(
    backend: KBPostgresBackend,
    path: str,
    *,
    max_depth: int | None = 8,
    page_size: int = 500,
    include_files: bool = True,
    include_dirs: bool = True,
) -> dict[str, Any]:
    """Forward to ``backend.alist_tree_listing``; used by the ``list_tree`` tool wrapper."""
    options: dict[str, Any] = {
        "max_depth": max_depth,
        "page_size": page_size,
        "include_files": include_files,
        "include_dirs": include_dirs,
    }
    return await backend.alist_tree_listing(path, **options)
# Public API of this module: the backend class plus the listing helpers.
__all__ = [
"KBPostgresBackend",
"list_tree_listing",
"paginate_listing",
]

View file

@ -1,10 +1,24 @@
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
"""Hybrid-search priority middleware for the SurfSense new chat agent.
This middleware runs before the main agent loop and seeds a virtual filesystem
(`files` state) with relevant documents retrieved via hybrid search. On each
turn the filesystem is *expanded* new results merge with documents loaded
during prior turns and a synthetic ``ls`` result is injected into the message
history so the LLM is immediately aware of the current filesystem structure.
This middleware runs ``before_agent`` on every turn and writes:
* ``state["kb_priority"]`` — the top-K most relevant documents for the
current user message, used to render a ``<priority_documents>`` system
message immediately before the user turn.
* ``state["kb_matched_chunk_ids"]`` — internal hand-off mapping
(``Document.id`` -> matched chunk IDs) consumed by
:class:`KBPostgresBackend._load_file_data` when the agent first reads each
document, so the XML wrapper can flag matched sections in
``<chunk_index>``.
The previous "scoped filesystem" behaviour (synthetic ``ls`` + state
``files`` seeding) is intentionally removed: documents are now lazy-loaded
from Postgres on demand, with the full workspace tree rendered separately
by :class:`KnowledgeTreeMiddleware`.
In anonymous mode the middleware skips hybrid search entirely and emits a
single-entry priority list pointing at the Redis-loaded document
(``state["kb_anon_doc"]``).
"""
from __future__ import annotations
@ -13,27 +27,30 @@ import asyncio
import json
import logging
import re
import uuid
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langgraph.runtime import Runtime
from litellm import token_counter
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
from app.agents.new_chat.path_resolver import (
PathIndex,
build_path_index,
doc_to_virtual_path,
)
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk,
Document,
Folder,
shielded_async_session,
)
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
@ -70,7 +87,6 @@ class KBSearchPlan(BaseModel):
def _extract_text_from_message(message: BaseMessage) -> str:
"""Extract plain text from a message content."""
content = getattr(message, "content", "")
if isinstance(content, str):
return content
@ -85,19 +101,6 @@ def _extract_text_from_message(message: BaseMessage) -> str:
return str(content)
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
"""Convert arbitrary text into a filesystem-safe filename."""
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
name = re.sub(r"\s+", " ", name)
if not name:
name = fallback
if len(name) > 180:
name = name[:180].rstrip()
if not name.lower().endswith(".xml"):
name = f"{name}.xml"
return name
def _render_recent_conversation(
messages: Sequence[BaseMessage],
*,
@ -107,10 +110,9 @@ def _render_recent_conversation(
) -> str:
"""Render recent dialogue for internal planning under a token budget.
Prefers the latest messages and uses the project's existing model-aware
token budgeting hooks when available on the LLM (`_count_tokens`,
`_get_max_input_tokens`). Falls back to the prior fixed-message heuristic
if token counting is unavailable.
Filters to ``HumanMessage`` and ``AIMessage`` (without tool_calls) so that
injected ``SystemMessage`` artefacts (priority list, workspace tree,
file-write contract) don't pollute the planner prompt.
"""
rendered: list[tuple[str, str]] = []
for message in messages:
@ -133,8 +135,6 @@ def _render_recent_conversation(
if not rendered:
return ""
# Exclude the latest user message from "recent conversation" because it is
# already passed separately as "Latest user message" in the planner prompt.
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
rendered = rendered[:-1]
@ -216,8 +216,6 @@ def _render_recent_conversation(
selected_lines = candidate_lines
continue
# If the full message does not fit, keep as much of this most-recent
# older message as possible via binary search.
lo, hi = 1, len(text)
best_line: str | None = None
while lo <= hi:
@ -249,7 +247,6 @@ def _build_kb_planner_prompt(
recent_conversation: str,
user_text: str,
) -> str:
"""Build a compact internal prompt for KB query rewriting and date scoping."""
today = datetime.now(UTC).date().isoformat()
return (
"You optimize internal knowledge-base search inputs for document retrieval.\n"
@ -275,12 +272,10 @@ def _build_kb_planner_prompt(
def _extract_json_payload(text: str) -> str:
"""Extract a JSON object from a raw LLM response."""
stripped = text.strip()
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
if fenced:
return fenced.group(1)
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
@ -289,7 +284,6 @@ def _extract_json_payload(text: str) -> str:
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
    """Decode the planner reply into a validated ``KBSearchPlan``.

    The JSON text is isolated with ``_extract_json_payload``, decoded with
    ``json.loads`` and checked by ``KBSearchPlan.model_validate``. Nothing is
    caught here, so decoding and validation errors propagate to the caller.
    """
    raw_json = _extract_json_payload(response_text)
    return KBSearchPlan.model_validate(json.loads(raw_json))
@ -298,212 +292,19 @@ def _normalize_optional_date_range(
start_date: str | None,
end_date: str | None,
) -> tuple[datetime | None, datetime | None]:
"""Normalize optional planner dates into a UTC datetime range."""
parsed_start = parse_date_or_datetime(start_date) if start_date else None
parsed_end = parse_date_or_datetime(end_date) if end_date else None
if parsed_start is None and parsed_end is None:
return None, None
resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end)
return resolved_start, resolved_end
def _build_document_xml(
document: dict[str, Any],
matched_chunk_ids: set[int] | None = None,
) -> str:
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
The ``<chunk_index>`` at the top of each document lists every chunk with its
line range inside ``<document_content>`` and flags chunks that directly
matched the search query (``matched="true"``). This lets the LLM jump
straight to the most relevant section via ``read_file(offset=, limit=)``
instead of reading sequentially from the start.
"""
matched = matched_chunk_ids or set()
doc_meta = document.get("document") or {}
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
url = (
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
)
metadata_json = json.dumps(metadata, ensure_ascii=False)
# --- 1. Metadata header (fixed structure) ---
metadata_lines: list[str] = [
"<document>",
"<document_metadata>",
f" <document_id>{document_id}</document_id>",
f" <document_type>{document_type}</document_type>",
f" <title><![CDATA[{title}]]></title>",
f" <url><![CDATA[{url}]]></url>",
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
"</document_metadata>",
"",
]
# --- 2. Pre-build chunk XML strings to compute line counts ---
chunks = document.get("chunks") or []
chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string)
if isinstance(chunks, list):
for chunk in chunks:
if not isinstance(chunk, dict):
continue
chunk_id = chunk.get("chunk_id") or chunk.get("id")
chunk_content = str(chunk.get("content", "")).strip()
if not chunk_content:
continue
if chunk_id is None:
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
else:
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
chunk_entries.append((chunk_id, xml))
# --- 3. Compute line numbers for every chunk ---
# Layout (1-indexed lines for read_file):
# metadata_lines -> len(metadata_lines) lines
# <chunk_index> -> 1 line
# index entries -> len(chunk_entries) lines
# </chunk_index> -> 1 line
# (empty line) -> 1 line
# <document_content> -> 1 line
# chunk xml lines…
# </document_content> -> 1 line
# </document> -> 1 line
index_overhead = (
1 + len(chunk_entries) + 1 + 1 + 1
) # tags + empty + <document_content>
first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed
current_line = first_chunk_line
index_entry_lines: list[str] = []
for cid, xml_str in chunk_entries:
num_lines = xml_str.count("\n") + 1
end_line = current_line + num_lines - 1
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
if cid is not None:
index_entry_lines.append(
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
)
else:
index_entry_lines.append(
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
)
current_line = end_line + 1
# --- 4. Assemble final XML ---
lines = metadata_lines.copy()
lines.append("<chunk_index>")
lines.extend(index_entry_lines)
lines.append("</chunk_index>")
lines.append("")
lines.append("<document_content>")
for _, xml_str in chunk_entries:
lines.append(xml_str)
lines.extend(["</document_content>", "</document>"])
return "\n".join(lines)
async def _get_folder_paths(
    session: AsyncSession, search_space_id: int
) -> dict[int, str]:
    """Map every folder id in the search space to its path under /documents.

    Each folder's path is built by walking ``parent_id`` links up to the root.
    Folder names are passed through ``_safe_filename`` (fallback ``"folder"``)
    with the ``.xml`` suffix removed. The walk stops at a missing parent or at
    a folder already visited, so a parent cycle cannot loop forever.
    """
    query = select(Folder.id, Folder.name, Folder.parent_id).where(
        Folder.search_space_id == search_space_id
    )
    folder_rows = (await session.execute(query)).all()
    folders = {
        row.id: {"name": row.name, "parent_id": row.parent_id} for row in folder_rows
    }

    resolved: dict[int, str] = {}
    for folder_id in folders:
        # Collect names leaf-first, then flip them into root-first order.
        segments: list[str] = []
        seen: set[int] = set()
        current: int | None = folder_id
        while current is not None and current in folders and current not in seen:
            seen.add(current)
            node = folders[current]
            safe_name = _safe_filename(str(node["name"]), fallback="folder")
            segments.append(safe_name.removesuffix(".xml"))
            current = node["parent_id"]
        if segments:
            resolved[folder_id] = "/documents/" + "/".join(reversed(segments))
        else:
            resolved[folder_id] = "/documents"
    return resolved
def _build_synthetic_ls(
    existing_files: dict[str, Any] | None,
    new_files: dict[str, Any],
    *,
    mentioned_paths: set[str] | None = None,
) -> tuple[AIMessage, ToolMessage]:
    """Fabricate an ``ls("/documents")`` tool call and its result message.

    ``existing_files`` and ``new_files`` are merged (new entries win). Only
    paths under ``/documents/`` with a non-``None`` value are listed, ordered
    as: mentioned paths, then new non-mentioned paths, then paths absent from
    ``new_files``. When any mentioned path exists, a header block names them
    first; the path list itself is the ``str()`` of the ordered list so paths
    can be passed to ``read_file`` unchanged.
    """
    highlighted = mentioned_paths or set()
    combined: dict[str, Any] = {**(existing_files or {}), **new_files}

    pinned: list[str] = []
    fresh: list[str] = []
    stale: list[str] = []
    for file_path, file_data in combined.items():
        if file_data is None or not file_path.startswith("/documents/"):
            continue
        is_pinned = file_path in highlighted
        is_fresh = file_path in new_files
        if is_pinned:
            pinned.append(file_path)
        if is_fresh and not is_pinned:
            fresh.append(file_path)
        # The three buckets are independent tests, not an if/elif chain: a
        # mentioned path that is absent from new_files lands in both buckets.
        if not is_fresh:
            stale.append(file_path)
    ordered = pinned + fresh + stale

    report: list[str] = []
    if pinned:
        report.append(
            "USER-MENTIONED documents (read these thoroughly before answering):"
        )
        report.extend(f" {file_path}" for file_path in pinned)
        report.append("")
    report.append(str(ordered) if ordered else "No documents found.")

    call_id = "auto_ls_" + uuid.uuid4().hex[:12]
    request = AIMessage(
        content="",
        tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": call_id}],
    )
    response = ToolMessage(content="\n".join(report), tool_call_id=call_id)
    return request, response
return resolve_date_range(parsed_start, parsed_end)
def _resolve_search_types(
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> list[str] | None:
"""Build a flat list of document-type strings for the chunk retriever.
Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
old documents indexed under Composio names are still found.
Returns ``None`` when no filtering is desired (search all types).
"""
types: set[str] = set()
if available_document_types:
types.update(available_document_types)
@ -531,13 +332,8 @@ async def browse_recent_documents(
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> list[dict[str, Any]]:
"""Return documents ordered by recency (newest first), no relevance ranking.
Used when the user's intent is temporal ("latest file", "most recent upload")
and hybrid search would produce poor results because the query has no
meaningful topical signal.
"""
from sqlalchemy import func, select
"""Return documents ordered by recency (newest first), no relevance ranking."""
from sqlalchemy import func
from app.db import DocumentType
@ -581,7 +377,6 @@ async def browse_recent_documents(
return []
doc_ids = [d.id for d in documents]
numbered = (
select(
Chunk.id.label("chunk_id"),
@ -632,6 +427,7 @@ async def browse_recent_documents(
else None
),
"metadata": metadata,
"folder_id": getattr(doc, "folder_id", None),
},
"source": (
doc.document_type.value
@ -640,12 +436,6 @@ async def browse_recent_documents(
),
}
)
logger.info(
"browse_recent_documents: %d docs returned for space=%d",
len(results),
search_space_id,
)
return results
@ -659,17 +449,11 @@ async def search_knowledge_base(
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> list[dict[str, Any]]:
"""Run a single unified hybrid search against the knowledge base.
Uses one ``ChucksHybridSearchRetriever`` call across all document types
instead of fanning out per-connector. This reduces the number of DB
queries from ~10 to 2 (one RRF query + one chunk fetch).
"""
"""Run a single unified hybrid search against the knowledge base."""
if not query:
return []
[embedding] = embed_texts([query])
doc_types = _resolve_search_types(available_connectors, available_document_types)
retriever_top_k = min(top_k * 3, 30)
@ -693,14 +477,7 @@ async def fetch_mentioned_documents(
document_ids: list[int],
search_space_id: int,
) -> list[dict[str, Any]]:
"""Fetch explicitly mentioned documents with *all* their chunks.
Returns the same dict structure as ``search_knowledge_base`` so results
can be merged directly into ``build_scoped_filesystem``. Unlike search
results, every chunk is included (no top-K limiting) and none are marked
as ``matched`` since the entire document is relevant by virtue of the
user's explicit mention.
"""
"""Fetch explicitly mentioned documents."""
if not document_ids:
return []
@ -750,6 +527,7 @@ async def fetch_mentioned_documents(
else None
),
"metadata": metadata,
"folder_id": getattr(doc, "folder_id", None),
},
"source": (
doc.document_type.value
@ -762,96 +540,36 @@ async def fetch_mentioned_documents(
return results
async def build_scoped_filesystem(
    *,
    documents: Sequence[dict[str, Any]],
    search_space_id: int,
) -> tuple[dict[str, dict[str, str]], dict[int, str]]:
    """Build a StateBackend-compatible files dict from search results.

    Returns ``(files, doc_id_to_path)`` so callers can map a document id back
    to its filesystem path without guessing by title. Each document lands at
    ``<folder path>/<safe title>.xml`` (``/documents`` when its folder is
    unknown). If that path is already taken, the document id is appended in
    parentheses to disambiguate.
    """
    async with shielded_async_session() as session:
        folder_paths = await _get_folder_paths(session, search_space_id)

        # Only integer ids can be looked up in the Document table.
        requested_ids = [
            candidate
            for candidate in (
                (entry.get("document") or {}).get("id")
                for entry in documents
                if isinstance(entry, dict)
            )
            if isinstance(candidate, int)
        ]
        folder_by_doc_id: dict[int, int | None] = {}
        if requested_ids:
            rows = await session.execute(
                select(Document.id, Document.folder_id).where(
                    Document.search_space_id == search_space_id,
                    Document.id.in_(requested_ids),
                )
            )
            folder_by_doc_id = {
                row.id: row.folder_id for row in rows.all() if row.id is not None
            }

    # NOTE(review): SOURCE has lost its indentation; this loop is assumed to
    # run after the session closes (it performs no DB access) — confirm.
    files: dict[str, dict[str, str]] = {}
    doc_id_to_path: dict[int, str] = {}
    for document in documents:
        doc_meta = document.get("document") or {}
        doc_id = doc_meta.get("id")
        file_name = _safe_filename(str(doc_meta.get("title") or "untitled"))
        folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None
        base_folder = folder_paths.get(folder_id, "/documents")

        path = f"{base_folder}/{file_name}"
        if path in files:
            path = f"{base_folder}/{file_name.removesuffix('.xml')} ({doc_id}).xml"

        flagged = set(document.get("matched_chunk_ids") or [])
        xml_text = _build_document_xml(document, matched_chunk_ids=flagged)
        files[path] = {
            "content": xml_text.split("\n"),
            "encoding": "utf-8",
            "created_at": "",
            "modified_at": "",
        }
        if isinstance(doc_id, int):
            doc_id_to_path[doc_id] = path
    return files, doc_id_to_path
def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
    """Wrap the priority list in one ``<priority_documents>`` system message.

    Each entry becomes a ``- <path> (score=<x.xxx>)`` line; a non-numeric
    score is shown as ``n/a`` and entries with a truthy ``mentioned`` value
    get a ``[USER-MENTIONED]`` marker. An empty list yields a placeholder line.
    """

    def _describe(entry: dict[str, Any]) -> str:
        # Format a single priority entry as one bullet line.
        score = entry.get("score")
        score_text = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a"
        marker = " [USER-MENTIONED]" if entry.get("mentioned") else ""
        return f"- {entry.get('path', '')} (score={score_text}){marker}"

    if priority:
        body = "\n".join(_describe(entry) for entry in priority)
    else:
        body = "(no priority documents for this turn)"

    sections = [
        "<priority_documents>",
        "These documents are most relevant to the latest user message; "
        "read them first. Matched sections are flagged inside each "
        "document's <chunk_index>.",
        body,
        "</priority_documents>",
    ]
    return SystemMessage(content="\n".join(sections))
def _build_anon_scoped_filesystem(
    documents: Sequence[dict[str, Any]],
) -> dict[str, dict[str, str]]:
    """Build a scoped filesystem for anonymous documents without DB queries.

    Every file is placed directly under ``/documents``. When two documents
    resolve to the same path, the later one gets its document id (or ``dup``
    when the id is missing) appended in parentheses.
    """
    scoped: dict[str, dict[str, str]] = {}
    for entry in documents:
        meta = entry.get("document") or {}
        file_name = _safe_filename(str(meta.get("title") or "untitled"))

        target = f"/documents/{file_name}"
        if target in scoped:
            suffix_id = meta.get("id", "dup")
            target = f"/documents/{file_name.removesuffix('.xml')} ({suffix_id}).xml"

        flagged = set(entry.get("matched_chunk_ids") or [])
        xml_text = _build_document_xml(entry, matched_chunk_ids=flagged)
        scoped[target] = {
            "content": xml_text.split("\n"),
            "encoding": "utf-8",
            "created_at": "",
            "modified_at": "",
        }
    return scoped
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Compute hybrid-search priority hints for the current turn."""
tools = ()
state_schema = SurfSenseFilesystemState
def __init__(
self,
@ -863,7 +581,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
available_document_types: list[str] | None = None,
top_k: int = 10,
mentioned_document_ids: list[int] | None = None,
anon_session_id: str | None = None,
) -> None:
self.llm = llm
self.search_space_id = search_space_id
@ -872,7 +589,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
self.available_document_types = available_document_types
self.top_k = top_k
self.mentioned_document_ids = mentioned_document_ids or []
self.anon_session_id = anon_session_id
async def _plan_search_inputs(
self,
@ -880,10 +596,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
messages: Sequence[BaseMessage],
user_text: str,
) -> tuple[str, datetime | None, datetime | None, bool]:
"""Rewrite the KB query and infer optional date filters with the LLM.
Returns (optimized_query, start_date, end_date, is_recency_query).
"""
if self.llm is None:
return user_text, None, None, False
@ -914,7 +626,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
)
is_recency = plan.is_recency_query
_perf_log.info(
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r "
"[kb_priority] planner in %.3fs query=%r optimized=%r "
"start=%s end=%s recency=%s",
loop.time() - t0,
user_text[:80],
@ -946,106 +658,68 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
pass
return asyncio.run(self.abefore_agent(state, runtime))
async def _load_anon_document(self) -> dict[str, Any] | None:
    """Load the anonymous user's uploaded document from Redis.

    Reads the JSON stored under ``anon:doc:<anon_session_id>`` and reshapes it
    into a search-result-style dict with a single chunk (ids are ``-1``).
    Returns ``None`` when no session id is set, when the key is empty, or when
    any error occurs (the error is logged as a warning). The Redis client is
    always closed once it has been created.
    """
    session_id = self.anon_session_id
    if not session_id:
        return None

    try:
        # Imported lazily inside the try so an import failure is also handled
        # by the warning-and-None path below.
        import redis.asyncio as aioredis

        from app.config import config

        client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
        try:
            payload = await client.get(f"anon:doc:{session_id}")
            if not payload:
                return None
            uploaded = json.loads(payload)
            body_text = uploaded.get("content", "")
            return {
                "document_id": -1,
                "content": body_text,
                "score": 1.0,
                "chunks": [
                    {
                        "chunk_id": -1,
                        "content": body_text,
                    }
                ],
                "matched_chunk_ids": [-1],
                "document": {
                    "id": -1,
                    "title": uploaded.get("filename", "uploaded_document"),
                    "document_type": "FILE",
                    "metadata": {"source": "anonymous_upload"},
                },
                "source": "FILE",
                "_user_mentioned": True,
            }
        finally:
            await client.aclose()
    except Exception as exc:
        logger.warning("Failed to load anonymous document from Redis: %s", exc)
        return None
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
if self.filesystem_mode != FilesystemMode.CLOUD:
return None
messages = state.get("messages") or []
if not messages:
return None
if self.filesystem_mode != FilesystemMode.CLOUD:
# Local-folder mode should not seed cloud KB documents into filesystem.
return None
last_human = None
last_human: HumanMessage | None = None
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
last_human = msg
break
if last_human is None:
return None
user_text = _extract_text_from_message(last_human).strip()
if not user_text:
return None
t0 = _perf_log and asyncio.get_event_loop().time()
existing_files = state.get("files")
anon_doc = state.get("kb_anon_doc")
if anon_doc:
return self._anon_priority(state, anon_doc)
# --- Anonymous session: load Redis doc and skip DB queries ---
if self.anon_session_id:
merged: list[dict[str, Any]] = []
anon_doc = await self._load_anon_document()
if anon_doc:
merged.append(anon_doc)
return await self._authenticated_priority(state, messages, user_text)
if merged:
new_files = _build_anon_scoped_filesystem(merged)
mentioned_paths = set(new_files.keys())
else:
new_files = {}
mentioned_paths = set()
def _anon_priority(
    self,
    state: AgentState,
    anon_doc: dict[str, Any],
) -> dict[str, Any]:
    """Build the state update for an anonymous session with one document.

    The priority list holds a single user-mentioned entry (score ``1.0``, no
    document id) taken from ``anon_doc``. The rendered priority message is
    inserted just before the last message of a copy of ``state["messages"]``.
    ``kb_matched_chunk_ids`` is returned empty.
    """
    entry = {
        "path": str(anon_doc.get("path") or ""),
        "score": 1.0,
        "document_id": None,
        "title": str(anon_doc.get("title") or "uploaded_document"),
        "mentioned": True,
    }
    priority = [entry]

    history = list(state.get("messages") or [])
    # Slot the hint in front of the final message; index 0 when history is empty.
    position = max(len(history) - 1, 0)
    history.insert(position, _render_priority_message(priority))

    return {
        "kb_priority": priority,
        "kb_matched_chunk_ids": {},
        "messages": history,
    }
ai_msg, tool_msg = _build_synthetic_ls(
existing_files,
new_files,
mentioned_paths=mentioned_paths,
)
if t0 is not None:
_perf_log.info(
"[kb_fs_middleware] anon completed in %.3fs new_files=%d",
asyncio.get_event_loop().time() - t0,
len(new_files),
)
return {"files": new_files, "messages": [ai_msg, tool_msg]}
# --- Authenticated session: full KB search ---
async def _authenticated_priority(
self,
state: AgentState,
messages: Sequence[BaseMessage],
user_text: str,
) -> dict[str, Any]:
t0 = asyncio.get_event_loop().time()
(
planned_query,
start_date,
@ -1056,7 +730,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
user_text=user_text,
)
# --- 1. Fetch mentioned documents (user-selected, all chunks) ---
mentioned_results: list[dict[str, Any]] = []
if self.mentioned_document_ids:
mentioned_results = await fetch_mentioned_documents(
@ -1065,7 +738,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
)
self.mentioned_document_ids = []
# --- 2. Run KB search (recency browse or hybrid) ---
if is_recency:
doc_types = _resolve_search_types(
self.available_connectors, self.available_document_types
@ -1088,48 +760,108 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
end_date=end_date,
)
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
seen_doc_ids: set[int] = set()
merged_auth: list[dict[str, Any]] = []
merged: list[dict[str, Any]] = []
for doc in mentioned_results:
doc_id = (doc.get("document") or {}).get("id")
if doc_id is not None:
if isinstance(doc_id, int):
seen_doc_ids.add(doc_id)
merged_auth.append(doc)
merged.append(doc)
for doc in search_results:
doc_id = (doc.get("document") or {}).get("id")
if doc_id is not None and doc_id in seen_doc_ids:
if isinstance(doc_id, int) and doc_id in seen_doc_ids:
continue
merged_auth.append(doc)
merged.append(doc)
# --- 4. Build scoped filesystem ---
new_files, doc_id_to_path = await build_scoped_filesystem(
documents=merged_auth,
search_space_id=self.search_space_id,
priority, matched_chunk_ids = await self._materialize_priority(merged)
new_messages = list(messages)
insert_at = max(len(new_messages) - 1, 0)
new_messages.insert(insert_at, _render_priority_message(priority))
_perf_log.info(
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d",
asyncio.get_event_loop().time() - t0,
user_text[:80],
len(priority),
len(mentioned_results),
)
mentioned_doc_ids = {
(d.get("document") or {}).get("id") for d in mentioned_results
}
mentioned_paths = {
doc_id_to_path[did] for did in mentioned_doc_ids if did in doc_id_to_path
return {
"kb_priority": priority,
"kb_matched_chunk_ids": matched_chunk_ids,
"messages": new_messages,
}
ai_msg, tool_msg = _build_synthetic_ls(
existing_files,
new_files,
mentioned_paths=mentioned_paths,
)
async def _materialize_priority(
self, merged: list[dict[str, Any]]
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:
"""Resolve canonical paths and matched chunk ids for the priority list."""
priority: list[dict[str, Any]] = []
matched_chunk_ids: dict[int, list[int]] = {}
if t0 is not None:
_perf_log.info(
"[kb_fs_middleware] completed in %.3fs query=%r optimized=%r "
"mentioned=%d new_files=%d total=%d",
asyncio.get_event_loop().time() - t0,
user_text[:80],
planned_query[:120],
len(mentioned_results),
len(new_files),
len(new_files) + len(existing_files or {}),
if not merged:
return priority, matched_chunk_ids
async with shielded_async_session() as session:
index: PathIndex = await build_path_index(session, self.search_space_id)
doc_ids = [
(doc.get("document") or {}).get("id")
for doc in merged
if isinstance(doc, dict)
]
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
folder_by_doc_id: dict[int, int | None] = {}
if doc_ids:
folder_rows = await session.execute(
select(Document.id, Document.folder_id).where(
Document.search_space_id == self.search_space_id,
Document.id.in_(doc_ids),
)
)
folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()}
for doc in merged:
doc_meta = doc.get("document") or {}
doc_id = doc_meta.get("id")
title = doc_meta.get("title") or "untitled"
folder_id = (
folder_by_doc_id.get(doc_id)
if isinstance(doc_id, int)
else doc_meta.get("folder_id")
)
return {"files": new_files, "messages": [ai_msg, tool_msg]}
path = doc_to_virtual_path(
doc_id=doc_id if isinstance(doc_id, int) else None,
title=str(title),
folder_id=folder_id if isinstance(folder_id, int) else None,
index=index,
)
priority.append(
{
"path": path,
"score": float(doc.get("score") or 0.0),
"document_id": doc_id if isinstance(doc_id, int) else None,
"title": str(title),
"mentioned": bool(doc.get("_user_mentioned")),
}
)
if isinstance(doc_id, int):
chunk_ids = doc.get("matched_chunk_ids") or []
if chunk_ids:
matched_chunk_ids[doc_id] = [
int(cid) for cid in chunk_ids if isinstance(cid, (int, str))
]
return priority, matched_chunk_ids
# Backwards-compatible alias for any external imports.
KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware
__all__ = [
"KnowledgeBaseSearchMiddleware",
"KnowledgePriorityMiddleware",
"browse_recent_documents",
"fetch_mentioned_documents",
"search_knowledge_base",
]

View file

@ -0,0 +1,272 @@
"""Workspace-tree middleware for the SurfSense agent.
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
injects the result as a ``<workspace_tree>`` system message immediately
before the latest human turn.
The render is bounded by two truncation layers:
1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
replaced with a "use ls" hint.
2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
token-count profile when available). If the entry-truncated tree still
exceeds the token cap we fall back to a root-only summary.
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
This middleware also performs a one-time initialization of ``state['cwd']``
to ``"/documents"`` so subsequent middlewares and tools always see a valid
cwd in cloud mode.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage
from langgraph.runtime import Runtime
from sqlalchemy import select
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
PathIndex,
build_path_index,
doc_to_virtual_path,
)
from app.db import Document, shielded_async_session
try:
from litellm import token_counter
except Exception: # pragma: no cover - optional dep
token_counter = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
MAX_TREE_ENTRIES = 500
MAX_TREE_TOKENS = 4000
def _approx_tokens(text: str) -> int:
"""Cheap fallback token estimate (1 token ~= 4 chars)."""
return max(1, (len(text) + 3) // 4)
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
if llm is None:
return _approx_tokens(text)
count_fn = getattr(llm, "_count_tokens", None)
if callable(count_fn):
try:
return int(count_fn([{"role": "user", "content": text}]))
except Exception:
pass
profile = getattr(llm, "profile", None)
model_names: list[str] = []
if isinstance(profile, dict):
tcms = profile.get("token_count_models")
if isinstance(tcms, list):
model_names.extend(name for name in tcms if isinstance(name, str) and name)
tcm = profile.get("token_count_model")
if isinstance(tcm, str) and tcm and tcm not in model_names:
model_names.append(tcm)
model_name = model_names[0] if model_names else getattr(llm, "model", None)
if not isinstance(model_name, str) or not model_name or token_counter is None:
return _approx_tokens(text)
try:
return int(
token_counter(
messages=[{"role": "user", "content": text}],
model=model_name,
)
)
except Exception:
return _approx_tokens(text)
class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Inject the workspace folder/document tree into the agent's context."""
tools = ()
state_schema = SurfSenseFilesystemState
def __init__(
    self,
    *,
    search_space_id: int,
    filesystem_mode: FilesystemMode,
    llm: BaseChatModel | None = None,
    max_entries: int = MAX_TREE_ENTRIES,
    max_tokens: int = MAX_TREE_TOKENS,
) -> None:
    """Store the middleware configuration and start with an empty cache.

    Args:
        search_space_id: Search space whose tree is rendered.
        filesystem_mode: Active filesystem mode.
        llm: Optional chat model stored for later use.
        max_entries: Stored entry cap (defaults to ``MAX_TREE_ENTRIES``).
        max_tokens: Stored token cap (defaults to ``MAX_TREE_TOKENS``).
    """
    # Rendered-tree cache; keys are (int, int, bool) tuples, values are text.
    self._cache: dict[tuple[int, int, bool], str] = {}
    self.llm = llm
    self.filesystem_mode = filesystem_mode
    self.search_space_id = search_space_id
    self.max_entries, self.max_tokens = max_entries, max_tokens
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
if self.filesystem_mode != FilesystemMode.CLOUD:
return None
update: dict[str, Any] = {}
if not state.get("cwd"):
update["cwd"] = DOCUMENTS_ROOT
anon_doc = state.get("kb_anon_doc")
if anon_doc:
tree_msg = self._render_anon_tree(anon_doc)
else:
tree_msg = await self._render_kb_tree(state)
messages = list(state.get("messages") or [])
insert_at = max(len(messages) - 1, 0)
messages.insert(insert_at, SystemMessage(content=tree_msg))
update["messages"] = messages
return update
def before_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
try:
loop = asyncio.get_running_loop()
if loop.is_running():
return None
except RuntimeError:
pass
return asyncio.run(self.abefore_agent(state, runtime))
# ------------------------------------------------------------------ render
def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
path = str(anon_doc.get("path") or "")
title = str(anon_doc.get("title") or "uploaded_document")
return (
"<workspace_tree>\n"
"Anonymous session — only one read-only document is available.\n"
f"{DOCUMENTS_ROOT}/\n"
f" {path}{title}\n"
"</workspace_tree>"
)
async def _render_kb_tree(self, state: AgentState) -> str:
version = int(state.get("tree_version") or 0)
cache_key = (self.search_space_id, version, False)
cached = self._cache.get(cache_key)
if cached is not None:
return cached
try:
async with shielded_async_session() as session:
index = await build_path_index(session, self.search_space_id)
doc_rows = await session.execute(
select(Document.id, Document.title, Document.folder_id).where(
Document.search_space_id == self.search_space_id
)
)
docs = list(doc_rows.all())
except Exception as exc: # pragma: no cover - defensive
logger.warning("knowledge_tree: DB error %s", exc)
return "<workspace_tree>\n(unavailable)\n</workspace_tree>"
rendered = self._format_tree(index, docs)
self._cache[cache_key] = rendered
return rendered
def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
folder_paths = sorted(set(index.folder_paths.values()))
doc_paths = sorted(
doc_to_virtual_path(
doc_id=row.id,
title=str(row.title or "untitled"),
folder_id=row.folder_id,
index=index,
)
for row in docs
)
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
lines: list[str] = []
for path in all_paths:
depth = (
0
if path == DOCUMENTS_ROOT
else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
)
indent = " " * depth
is_dir = path == DOCUMENTS_ROOT or path in folder_paths
display = (
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
)
if is_dir:
lines.append(f"{indent}{display}/")
else:
lines.append(f"{indent}{display}")
if len(lines) >= self.max_entries:
remaining = len(all_paths) - len(lines)
if remaining > 0:
lines.append(
f"... {remaining} more entries — use "
"ls('/documents/<folder>', offset, limit) to expand"
)
break
body = "\n".join(lines)
rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"
token_count = _count_tokens(rendered, llm=self.llm)
if token_count <= self.max_tokens:
return rendered
return self._format_root_summary(folder_paths, doc_paths)
def _format_root_summary(
self, folder_paths: list[str], doc_paths: list[str]
) -> str:
top_level: dict[str, int] = {}
loose_docs = 0
for path in doc_paths:
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
if "/" in rel:
top = rel.split("/", 1)[0]
top_level[top] = top_level.get(top, 0) + 1
else:
loose_docs += 1
for path in folder_paths:
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
if not rel:
continue
top = rel.split("/", 1)[0]
top_level.setdefault(top, 0)
lines = [DOCUMENTS_ROOT + "/"]
for name in sorted(top_level):
count = top_level[name]
lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
if loose_docs:
lines.append(
f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
)
lines.append(
"Tree is large; use list_tree('/documents/<folder>') to drill in "
"or ls('/documents/<folder>', offset, limit) for paginated listings."
)
return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
# Names exported by ``from <this module> import *``.
__all__ = ["KnowledgeTreeMiddleware"]

View file

@ -0,0 +1,351 @@
"""Canonical virtual-path resolver for SurfSense knowledge-base documents.
This module is the single source of truth for mapping ``Document`` rows to
virtual paths under ``/documents/`` and back. It is used by:
* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree)
* :class:`KnowledgePriorityMiddleware` (computing priority paths)
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations)
* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates)
Centralising the logic ensures that title-collision suffixes, folder paths,
and ``unique_identifier_hash`` lookups never drift between renders and
commits.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType, Folder
from app.utils.document_converters import generate_unique_identifier_hash
DOCUMENTS_ROOT = "/documents"
"""Root virtual folder for all KB documents."""
# Runs of characters that are not allowed in a virtual file or folder name.
_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+")
# Any run of whitespace; collapsed to a single space when sanitising.
_WHITESPACE_RUN = re.compile(r"\s+")


def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
    """Turn arbitrary text into a filesystem-safe ``.xml`` filename.

    Each run of forbidden characters becomes one underscore, surrounding
    whitespace is stripped and inner whitespace runs collapse to a single
    space. An empty result is replaced by ``fallback``; the name is capped
    at 180 characters before a ``.xml`` extension is ensured
    (case-insensitive check).
    """
    cleaned = _INVALID_FILENAME_CHARS.sub("_", value).strip()
    candidate = _WHITESPACE_RUN.sub(" ", cleaned) or fallback
    if len(candidate) > 180:
        candidate = candidate[:180].rstrip()
    if candidate.lower().endswith(".xml"):
        return candidate
    return f"{candidate}.xml"
def safe_folder_segment(value: str, *, fallback: str = "folder") -> str:
    """Sanitise one folder name so it can be used as a single path segment.

    Applies the same character and whitespace clean-up as
    :func:`safe_filename` but adds no extension. An empty result yields
    ``fallback``; longer names are capped at 180 characters.
    """
    cleaned = _INVALID_FILENAME_CHARS.sub("_", value).strip()
    segment = _WHITESPACE_RUN.sub(" ", cleaned)
    if not segment:
        return fallback
    return segment[:180].rstrip() if len(segment) > 180 else segment
def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str:
if doc_id is None:
return filename
if not filename.lower().endswith(".xml"):
return f"{filename} ({doc_id}).xml"
stem = filename[:-4]
return f"{stem} ({doc_id}).xml"
# Matches a trailing " (<digits>).xml" disambiguation suffix (any case).
_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE)


def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]:
    """Split a trailing ``" (<doc_id>).xml"`` suffix off ``filename``.

    Returns ``(stem, doc_id)`` when the suffix is present. Otherwise the
    result is ``(filename_without_xml_extension, None)``.
    """
    found = _SUFFIX_PATTERN.search(filename)
    if found is not None:
        return filename[: found.start()], int(found.group(1))
    without_ext = filename[:-4] if filename.lower().endswith(".xml") else filename
    return without_ext, None
@dataclass
class PathIndex:
    """In-memory occupancy snapshot used by :func:`doc_to_virtual_path`.

    Built once per call site so collision handling is deterministic and so
    we don't perform N folder lookups per render.

    :func:`doc_to_virtual_path` records each path it hands out in
    ``occupants``, so an instance also reflects every assignment made
    through it so far.
    """

    # Looked up with ``.get(folder_id, DOCUMENTS_ROOT)`` by the resolver, so
    # unknown or ``None`` folder ids fall back to the documents root.
    folder_paths: dict[int, str] = field(default_factory=dict)
    """``Folder.id`` -> absolute virtual folder path under ``/documents``."""
    # Consulted to decide whether a document needs an id-suffixed filename.
    occupants: dict[str, int] = field(default_factory=dict)
    """virtual path -> ``Document.id`` already occupying that path (this render)."""
async def _build_folder_paths(
    session: AsyncSession,
    search_space_id: int,
) -> dict[int, str]:
    """Map every ``Folder.id`` in the search space to its virtual path.

    Paths are rooted at ``DOCUMENTS_ROOT`` and built from the sanitised
    names of the folder and its ancestors. Walking up the parent chain
    stops at a missing parent or when a folder is seen twice, so a cyclic
    parent chain cannot loop forever.
    """
    query = select(Folder.id, Folder.name, Folder.parent_id).where(
        Folder.search_space_id == search_space_id
    )
    rows = (await session.execute(query)).all()
    folders = {row.id: (row.name, row.parent_id) for row in rows}
    resolved: dict[int, str] = {}

    def path_for(folder_id: int) -> str:
        """Resolve one folder id, memoising the result in ``resolved``."""
        if folder_id in resolved:
            return resolved[folder_id]
        segments: list[str] = []
        seen: set[int] = set()
        current: int | None = folder_id
        # Collect segment names leaf-first while climbing towards the root.
        while current is not None and current in folders and current not in seen:
            seen.add(current)
            name, parent = folders[current]
            segments.append(safe_folder_segment(str(name)))
            current = parent
        if segments:
            full_path = f"{DOCUMENTS_ROOT}/" + "/".join(reversed(segments))
        else:
            full_path = DOCUMENTS_ROOT
        resolved[folder_id] = full_path
        return full_path

    for known_id in folders:
        path_for(known_id)
    return resolved
async def build_path_index(
    session: AsyncSession,
    search_space_id: int,
    *,
    populate_occupants: bool = True,
) -> PathIndex:
    """Build a :class:`PathIndex` for a search space.

    ``populate_occupants`` controls whether the occupancy map is pre-seeded
    from existing ``Document`` rows. Most callers want this so that
    :func:`doc_to_virtual_path` can detect collisions across the whole space;
    the persistence middleware sets this to ``False`` when it is iterating to
    decide where to place fresh documents.

    Documents are visited in ascending ``Document.id`` order, so when several
    documents map to the same filename in a folder it is always the one with
    the lowest id that keeps the plain path; the others get an id suffix.

    Args:
        session: Open async session used for the folder and document queries.
        search_space_id: Search space to index.
        populate_occupants: Pre-seed ``occupants`` from existing documents.

    Returns:
        The populated :class:`PathIndex`.
    """
    folder_paths = await _build_folder_paths(session, search_space_id)
    occupants: dict[str, int] = {}
    if populate_occupants:
        rows = await session.execute(
            select(Document.id, Document.title, Document.folder_id)
            .where(
                Document.search_space_id == search_space_id,
            )
            # Without an explicit ordering the database may return rows in
            # any order, which would make the owner of an unsuffixed path
            # vary between calls.
            .order_by(Document.id)
        )
        for row in rows.all():
            base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT)
            filename = safe_filename(str(row.title or "untitled"))
            path = f"{base}/{filename}"
            if path in occupants and occupants[path] != row.id:
                path = f"{base}/{_suffix_with_doc_id(filename, row.id)}"
            occupants[path] = row.id
    return PathIndex(folder_paths=folder_paths, occupants=occupants)
def doc_to_virtual_path(
    *,
    doc_id: int | None,
    title: str,
    folder_id: int | None,
    index: PathIndex,
) -> str:
    """Compute the canonical virtual path of a document.

    The path is ``<folder path>/<safe filename>``, with the folder path
    falling back to ``DOCUMENTS_ROOT`` for unknown folders. If that path is
    already held by a different document in ``index.occupants``, the
    filename gets a ``" (<doc_id>)"`` suffix instead.

    Side effect: when ``doc_id`` is not ``None`` the chosen path is recorded
    in ``index.occupants`` so later calls see the assignment.
    """
    folder_base = index.folder_paths.get(folder_id, DOCUMENTS_ROOT)
    filename = safe_filename(str(title or "untitled"))
    candidate = f"{folder_base}/{filename}"

    current_owner = index.occupants.get(candidate)
    taken_by_other = current_owner is not None and current_owner != doc_id
    if taken_by_other:
        candidate = f"{folder_base}/{_suffix_with_doc_id(filename, doc_id)}"

    if doc_id is not None:
        index.occupants[candidate] = doc_id
    return candidate
async def virtual_path_to_doc(
    session: AsyncSession,
    *,
    search_space_id: int,
    virtual_path: str,
) -> Document | None:
    """Resolve a virtual path back to a ``Document`` row.

    Resolution order:

    1. ``Document.unique_identifier_hash`` lookup (fast path for paths created
       by SurfSense itself - every NOTE write goes through this hash).
    2. If the basename carries a ``" (<doc_id>).xml"`` disambiguation suffix,
       try a direct id lookup constrained to the search space.
    3. Title-from-basename lookup inside the folder named by the path.
    4. Scan of that folder matching ``safe_filename(title)`` against the
       basename, for titles that the filename encoding altered.

    Steps 3 and 4 only run when the folder part of the path resolves (or the
    path points directly under the documents root); a path through a folder
    that does not exist never matches a root-level document.

    Args:
        session: Open async session used for all lookups.
        search_space_id: Search space the document must belong to.
        virtual_path: Absolute virtual path under ``DOCUMENTS_ROOT``.

    Returns:
        The matching ``Document``, or ``None`` when nothing matches or the
        path is empty / outside ``DOCUMENTS_ROOT``.
    """
    if not virtual_path:
        return None
    # Require a whole-component prefix match so that e.g. "/documents-old/x"
    # is not mistaken for something under "/documents".
    if virtual_path != DOCUMENTS_ROOT and not virtual_path.startswith(
        f"{DOCUMENTS_ROOT}/"
    ):
        return None

    unique_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        virtual_path,
        search_space_id,
    )
    result = await session.execute(
        select(Document).where(
            Document.search_space_id == search_space_id,
            Document.unique_identifier_hash == unique_hash,
        )
    )
    document = result.scalar_one_or_none()
    if document is not None:
        return document

    rel = virtual_path[len(DOCUMENTS_ROOT) :].lstrip("/")
    parts = [p for p in rel.split("/") if p]
    if not parts:
        return None
    basename = parts[-1]
    folder_parts = parts[:-1]

    stem, suffix_doc_id = parse_doc_id_suffix(basename)
    if suffix_doc_id is not None:
        result = await session.execute(
            select(Document).where(
                Document.search_space_id == search_space_id,
                Document.id == suffix_doc_id,
            )
        )
        document = result.scalar_one_or_none()
        if document is not None:
            return document

    folder_id = await _resolve_folder_id(
        session, search_space_id=search_space_id, folder_parts=folder_parts
    )
    if folder_parts and folder_id is None:
        # The path names a folder chain that does not exist. ``None`` would
        # otherwise be read as "root level" below and could return an
        # unrelated root document with the same title.
        return None

    title_candidates: list[str] = [stem]
    if stem.endswith(".xml"):
        title_candidates.append(stem[:-4])

    for candidate in dict.fromkeys(title_candidates):
        if not candidate:
            continue
        query = select(Document).where(
            Document.search_space_id == search_space_id,
            Document.title == candidate,
        )
        if folder_id is None:
            query = query.where(Document.folder_id.is_(None))
        else:
            query = query.where(Document.folder_id == folder_id)
        result = await session.execute(query)
        document = result.scalars().first()
        if document is not None:
            return document

    # Fallback: title-as-string lookup misses when the real DB title contains
    # characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``,
    # etc.) — common for connector-imported docs (Google Calendar/Drive etc.).
    # The workspace tree shows the lossy filename, so the agent passes that
    # filename back here. Scan all documents in the resolved folder and match
    # by ``safe_filename(title)`` to recover the original document.
    folder_scan = select(Document).where(
        Document.search_space_id == search_space_id,
    )
    if folder_id is None:
        folder_scan = folder_scan.where(Document.folder_id.is_(None))
    else:
        folder_scan = folder_scan.where(Document.folder_id == folder_id)
    result = await session.execute(folder_scan)
    for candidate_doc in result.scalars().all():
        encoded = safe_filename(str(candidate_doc.title or "untitled"))
        if encoded == basename:
            return candidate_doc
    return None
async def _resolve_folder_id(
    session: AsyncSession,
    *,
    search_space_id: int,
    folder_parts: list[str],
) -> int | None:
    """Walk a chain of folder names from the root and return the leaf id.

    Each segment is sanitised with :func:`safe_folder_segment` and looked up
    by name under the folder found for the previous segment (or among the
    parentless folders for the first one). Returns ``None`` when the chain
    is empty or any segment has no matching folder.
    """
    if not folder_parts:
        return None

    current_id: int | None = None
    for segment in folder_parts:
        if current_id is None:
            parent_filter = Folder.parent_id.is_(None)
        else:
            parent_filter = Folder.parent_id == current_id
        lookup = (
            select(Folder.id)
            .where(
                Folder.search_space_id == search_space_id,
                Folder.name == safe_folder_segment(segment),
            )
            .where(parent_filter)
        )
        found = (await session.execute(lookup)).first()
        if found is None:
            return None
        current_id = found[0]
    return current_id
def parse_documents_path(virtual_path: str) -> tuple[list[str], str]:
    """Parse a ``/documents/...`` path into ``(folder_parts, document_title)``.

    The title has any ``.xml`` extension and trailing ``" (<doc_id>)"``
    disambiguation suffix stripped.

    Args:
        virtual_path: Absolute virtual path, expected under ``DOCUMENTS_ROOT``.

    Returns:
        ``(folder_parts, title)``. The result is ``([], "")`` when the path
        is empty, is not under ``DOCUMENTS_ROOT`` or names no file.
    """
    if not virtual_path:
        return [], ""
    # Require a whole-component prefix match so that e.g. "/documents-old/x"
    # is not parsed as if it were under "/documents".
    if virtual_path != DOCUMENTS_ROOT and not virtual_path.startswith(
        f"{DOCUMENTS_ROOT}/"
    ):
        return [], ""

    parts = [p for p in virtual_path[len(DOCUMENTS_ROOT) :].split("/") if p]
    if not parts:
        return [], ""

    *folder_parts, basename = parts
    stem, _ = parse_doc_id_suffix(basename)
    # ``parse_doc_id_suffix`` removes one extension; drop a second,
    # lower-case ``.xml`` if the stem still ends with it.
    title = stem[:-4] if stem.endswith(".xml") else stem
    return folder_parts, title
# Names exported by ``from <this module> import *``. Helpers with a leading
# underscore defined above are not listed.
__all__ = [
    "DOCUMENTS_ROOT",
    "PathIndex",
    "build_path_index",
    "doc_to_virtual_path",
    "parse_doc_id_suffix",
    "parse_documents_path",
    "safe_filename",
    "safe_folder_segment",
    "virtual_path_to_doc",
]

View file

@ -0,0 +1,201 @@
"""Reducers and sentinels for SurfSense filesystem state.
These reducers back the extra state fields used by the cloud-mode filesystem
agent (`cwd`, `staged_dirs`, `pending_moves`, `dirty_paths`, `doc_id_by_path`,
`kb_priority`, `kb_matched_chunk_ids`, `kb_anon_doc`, `tree_version`).
Tools mutate these fields ONLY via `Command(update={...})` returns; the
reducers are responsible for merging successive updates atomically and for
honouring an explicit reset sentinel (`_CLEAR`) so that a single update can
both reset and reseed a list (used by `move_file` / `aafter_agent`).
The sentinel is intentionally a plain string constant rather than a custom
object so that LangGraph's checkpointer (which serializes raw `Command.update`
deltas via ``ormsgpack`` BEFORE reducers are applied) can round-trip writes
that contain it. The token uses a NUL-bracketed form that cannot collide with
any real virtual path, document title, or dict key produced by the agent.
"""
from __future__ import annotations
from typing import Any, Final, TypeVar
_CLEAR: Final[str] = "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
"""Reset sentinel; pass it inside a list/dict update to request a reset.
For list reducers: ``[_CLEAR, *items]`` resets the field then appends ``items``.
For dict reducers: ``{_CLEAR: True, **items}`` resets the field then merges ``items``.
Because the value is a plain string with embedded NUL bytes, it is natively
serializable by ``ormsgpack`` (used by LangGraph's PostgreSQL checkpointer)
yet still distinct from any real path / key produced by application code.
"""
# NOTE(review): ``_replace_reducer`` below declares its own ``T`` through the
# ``def name[T](...)`` type-parameter syntax, so this module-level TypeVar
# looks unused in this module - confirm before removing.
T = TypeVar("T")
def _replace_reducer[T](left: T | None, right: T | None) -> T | None:
"""Replace `left` outright with `right`. ``None`` on the right is honored as a reset."""
return right
def _is_clear(value: Any) -> bool:
    """Return ``True`` only when ``value`` is a string equal to ``_CLEAR``."""
    if not isinstance(value, str):
        return False
    return value == _CLEAR
def _add_unique_reducer(
    left: list[Any] | None,
    right: list[Any] | None,
) -> list[Any]:
    """Append items from ``right`` to ``left`` while preserving uniqueness.

    Semantics:

    - If ``right`` is ``None`` or empty, return a copy of ``left`` unchanged.
    - If ``right`` contains the ``_CLEAR`` sentinel anywhere, the result is
      reseeded with only the items that appear AFTER the LAST occurrence of
      ``_CLEAR`` (deduplicated, preserving first-seen order). This gives a
      single-update "reset and reseed" capability.
    - Otherwise, items from ``right`` are appended to ``left`` (order preserved
      from first seen) while skipping values that are already present.

    Unhashable items (e.g. dicts) are supported: they are deduplicated by
    equality, and their presence in ``left`` does not disable deduplication
    of hashable items.

    Args:
        left: Current value of the field, or ``None``.
        right: Incoming update, or ``None``.

    Returns:
        A new list; neither input is mutated.
    """
    if not right:
        return list(left or [])

    last_clear = -1
    for index, item in enumerate(right):
        if _is_clear(item):
            last_clear = index

    if last_clear >= 0:
        # Reset requested: start empty and keep only what follows the sentinel.
        merged: list[Any] = []
        incoming = right[last_clear + 1 :]
    else:
        merged = list(left or [])
        incoming = right

    # Hashable members are tracked in a set for O(1) checks. Each element is
    # added individually so one unhashable member cannot empty the set.
    seen_hashable: set[Any] = set()
    for existing in merged:
        try:
            seen_hashable.add(existing)
        except TypeError:
            continue

    for item in incoming:
        if _is_clear(item):
            continue
        try:
            if item in seen_hashable:
                continue
            seen_hashable.add(item)
        except TypeError:
            # Unhashable item: fall back to an equality scan of the result.
            if item in merged:
                continue
        merged.append(item)
    return merged
def _list_append_reducer(
    left: list[Any] | None,
    right: list[Any] | None,
) -> list[Any]:
    """Concatenate ``right`` onto ``left``, keeping order and duplicates.

    A ``None`` or empty ``right`` yields a copy of ``left``. When ``right``
    contains the ``_CLEAR`` sentinel, ``left`` is dropped and only the items
    after the last sentinel are kept. Sentinels never appear in the result.
    Intended for queues where repeated entries are meaningful (e.g.
    ``pending_moves``).
    """
    if not right:
        return list(left or [])

    clear_positions = [pos for pos, item in enumerate(right) if _is_clear(item)]
    if clear_positions:
        tail = right[clear_positions[-1] + 1 :]
        return [item for item in tail if not _is_clear(item)]

    return [*(left or []), *(item for item in right if not _is_clear(item))]
def _dict_merge_with_tombstones_reducer(
    left: dict[Any, Any] | None,
    right: dict[Any, Any] | None,
) -> dict[Any, Any]:
    """Merge ``right`` into ``left`` with two extra capabilities:

    * Keys whose value is ``None`` are removed from the merged result
      (tombstone semantics, matching the deepagents file-data reducer).
    * The special key ``_CLEAR`` resets ``left`` to ``{}`` before merging the
      remaining keys from ``right``. Only the presence of the key matters;
      its value is ignored. This makes it possible to atomically clear and
      reseed the dictionary in a single update.

    Args:
        left: Current value of the field, or ``None``.
        right: Incoming update, or ``None`` for "no change".

    Returns:
        A new dict; neither input is mutated. The ``_CLEAR`` key itself is
        never part of the result.
    """
    if right is None:
        return dict(left or {})

    reset_requested = any(_is_clear(key) for key in right)
    merged: dict[Any, Any] = {} if reset_requested or left is None else dict(left)

    for key, value in right.items():
        if _is_clear(key):
            continue
        if value is None:
            # Tombstone: drop the key if it is present.
            merged.pop(key, None)
        else:
            merged[key] = value
    return merged
def _initial_filesystem_state() -> dict[str, Any]:
"""Default empty values for SurfSense filesystem state fields.
Consumers should always treat these fields as ``state.get(key) or
DEFAULT`` so that fresh threads (without checkpointed state) work
correctly.
"""
return {
"cwd": "/documents",
"staged_dirs": [],
"pending_moves": [],
"doc_id_by_path": {},
"dirty_paths": [],
"kb_priority": [],
"kb_matched_chunk_ids": {},
"kb_anon_doc": None,
"tree_version": 0,
}
# Every name here starts with an underscore, so a star-import would skip
# them unless they are listed explicitly.
__all__ = [
    "_CLEAR",
    "_add_unique_reducer",
    "_dict_merge_with_tombstones_reducer",
    "_initial_filesystem_state",
    "_list_append_reducer",
    "_replace_reducer",
]

View file

@ -332,7 +332,7 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
* When preloaded `/documents/` data is insufficient and the user wants more
* When `/documents/` knowledge-base data is insufficient and the user wants more
- Trigger scenarios:
* "Read this article and summarize it"
* "What does this page say about X?"

View file

@ -20,7 +20,12 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace
from app.db import (
ImageGeneration,
ImageGenerationConfig,
SearchSpace,
shielded_async_session,
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
@ -70,8 +75,13 @@ def create_generate_image_tool(
Args:
search_space_id: The search space ID (for config resolution)
db_session: Async database session
db_session: Reserved for compatibility with the tool registry.
The streaming task's ``AsyncSession`` is shared by every tool;
because AsyncSession is not concurrency-safe, parallel tool calls
would interleave flushes (e.g. podcast + image in the same step)
and poison the transaction. This tool opens its own session.
"""
del db_session # use a fresh per-call session, see below
@tool
async def generate_image(
@ -93,110 +103,119 @@ def create_generate_image_tool(
A dictionary containing the generated image(s) for display in the chat.
"""
try:
# Resolve the image generation config from the search space preference
result = await db_session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
return {"error": "Search space not found"}
config_id = (
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
)
# Build generation kwargs
# NOTE: size, quality, and style are intentionally NOT passed.
# Different models support different values for these params
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
# gpt-image-1 wants "high"/"medium"/"low"; size options also
# differ). Letting the model use its own defaults avoids errors.
gen_kwargs: dict[str, Any] = {}
if n is not None and n > 1:
gen_kwargs["n"] = n
# Call litellm based on config type
if is_image_gen_auto_mode(config_id):
if not ImageGenRouterService.is_initialized():
return {
"error": "No image generation models configured. "
"Please add an image model in Settings > Image Models."
}
response = await ImageGenRouterService.aimage_generation(
prompt=prompt, model="auto", **gen_kwargs
# Use a per-call session so concurrent tool calls don't share an
# AsyncSession (which is not concurrency-safe). The streaming
# task's session is shared across every tool; without isolation,
# autoflushes from a concurrent writer poison this tool too.
async with shielded_async_session() as session:
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
elif config_id < 0:
cfg = _get_global_image_gen_config(config_id)
if not cfg:
return {"error": f"Image generation config {config_id} not found"}
search_space = result.scalars().first()
if not search_space:
return {"error": "Search space not found"}
model_string = _build_model_string(
cfg.get("provider", ""),
cfg["model_name"],
cfg.get("custom_provider"),
config_id = (
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
)
gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"):
gen_kwargs["api_base"] = cfg["api_base"]
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
gen_kwargs.update(cfg["litellm_params"])
response = await aimage_generation(
prompt=prompt, model=model_string, **gen_kwargs
)
else:
# Positive ID = user-created ImageGenerationConfig
cfg_result = await db_session.execute(
select(ImageGenerationConfig).filter(
ImageGenerationConfig.id == config_id
# Build generation kwargs
# NOTE: size, quality, and style are intentionally NOT passed.
# Different models support different values for these params
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
# gpt-image-1 wants "high"/"medium"/"low"; size options also
# differ). Letting the model use its own defaults avoids errors.
gen_kwargs: dict[str, Any] = {}
if n is not None and n > 1:
gen_kwargs["n"] = n
# Call litellm based on config type
if is_image_gen_auto_mode(config_id):
if not ImageGenRouterService.is_initialized():
return {
"error": "No image generation models configured. "
"Please add an image model in Settings > Image Models."
}
response = await ImageGenRouterService.aimage_generation(
prompt=prompt, model="auto", **gen_kwargs
)
)
db_cfg = cfg_result.scalars().first()
if not db_cfg:
return {"error": f"Image generation config {config_id} not found"}
elif config_id < 0:
cfg = _get_global_image_gen_config(config_id)
if not cfg:
return {
"error": f"Image generation config {config_id} not found"
}
model_string = _build_model_string(
db_cfg.provider.value,
db_cfg.model_name,
db_cfg.custom_provider,
)
gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base:
gen_kwargs["api_base"] = db_cfg.api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
gen_kwargs.update(db_cfg.litellm_params)
model_string = _build_model_string(
cfg.get("provider", ""),
cfg["model_name"],
cfg.get("custom_provider"),
)
gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"):
gen_kwargs["api_base"] = cfg["api_base"]
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
gen_kwargs.update(cfg["litellm_params"])
response = await aimage_generation(
prompt=prompt, model=model_string, **gen_kwargs
response = await aimage_generation(
prompt=prompt, model=model_string, **gen_kwargs
)
else:
# Positive ID = user-created ImageGenerationConfig
cfg_result = await session.execute(
select(ImageGenerationConfig).filter(
ImageGenerationConfig.id == config_id
)
)
db_cfg = cfg_result.scalars().first()
if not db_cfg:
return {
"error": f"Image generation config {config_id} not found"
}
model_string = _build_model_string(
db_cfg.provider.value,
db_cfg.model_name,
db_cfg.custom_provider,
)
gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base:
gen_kwargs["api_base"] = db_cfg.api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
gen_kwargs.update(db_cfg.litellm_params)
response = await aimage_generation(
prompt=prompt, model=model_string, **gen_kwargs
)
# Parse the response and store in DB
response_dict = (
response.model_dump()
if hasattr(response, "model_dump")
else dict(response)
)
# Parse the response and store in DB
response_dict = (
response.model_dump()
if hasattr(response, "model_dump")
else dict(response)
)
# Generate a random access token for this image
access_token = generate_image_token()
# Generate a random access token for this image
access_token = generate_image_token()
# Save to image_generations table for history
db_image_gen = ImageGeneration(
prompt=prompt,
model=getattr(response, "_hidden_params", {}).get("model"),
n=n,
image_generation_config_id=config_id,
response_data=response_dict,
search_space_id=search_space_id,
access_token=access_token,
)
db_session.add(db_image_gen)
await db_session.commit()
await db_session.refresh(db_image_gen)
# Save to image_generations table for history
db_image_gen = ImageGeneration(
prompt=prompt,
model=getattr(response, "_hidden_params", {}).get("model"),
n=n,
image_generation_config_id=config_id,
response_data=response_dict,
search_space_id=search_space_id,
access_token=access_token,
)
session.add(db_image_gen)
await session.commit()
await session.refresh(db_image_gen)
db_image_gen_id = db_image_gen.id
# Extract image URLs from response
images = response_dict.get("data", [])
@ -217,7 +236,7 @@ def create_generate_image_tool(
backend_url = config.BACKEND_URL or "http://localhost:8000"
image_url = (
f"{backend_url}/api/v1/image-generations/"
f"{db_image_gen.id}/image?token={access_token}"
f"{db_image_gen_id}/image?token={access_token}"
)
else:
return {"error": "No displayable image data in the response"}

View file

@ -11,7 +11,7 @@ from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Podcast, PodcastStatus
from app.db import Podcast, PodcastStatus, shielded_async_session
def create_generate_podcast_tool(
@ -27,12 +27,16 @@ def create_generate_podcast_tool(
Args:
search_space_id: The user's search space ID
db_session: Database session for creating the podcast record
db_session: Reserved for future read-side use; the row is written via a
fresh, tool-local session so parallel tool calls (e.g. podcast +
video presentation in the same agent step) don't share an
``AsyncSession`` (which is not concurrency-safe).
thread_id: The chat thread ID for associating the podcast
Returns:
A configured tool function for generating podcasts
"""
del db_session # writes use a fresh tool-local session, see below
@tool
async def generate_podcast(
@ -64,32 +68,40 @@ def create_generate_podcast_tool(
- message: Status message (or "error" field if status is failed)
"""
try:
podcast = Podcast(
title=podcast_title,
status=PodcastStatus.PENDING,
search_space_id=search_space_id,
thread_id=thread_id,
)
db_session.add(podcast)
await db_session.commit()
await db_session.refresh(podcast)
# Open a fresh session per call. The streaming task's session is
# shared between every tool, and ``AsyncSession`` is NOT safe for
# concurrent use: when the LLM emits parallel tool calls, two
# concurrent ``add()`` / ``commit()`` paths interleave and the
# second one hits "Session.add() during flush" → the transaction
# is poisoned for both tools.
async with shielded_async_session() as session:
podcast = Podcast(
title=podcast_title,
status=PodcastStatus.PENDING,
search_space_id=search_space_id,
thread_id=thread_id,
)
session.add(podcast)
await session.commit()
await session.refresh(podcast)
podcast_id = podcast.id
from app.tasks.celery_tasks.podcast_tasks import (
generate_content_podcast_task,
)
task = generate_content_podcast_task.delay(
podcast_id=podcast.id,
podcast_id=podcast_id,
source_content=source_content,
search_space_id=search_space_id,
user_prompt=user_prompt,
)
print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")
print(f"[generate_podcast] Created podcast {podcast_id}, task: {task.id}")
return {
"status": PodcastStatus.PENDING.value,
"podcast_id": podcast.id,
"podcast_id": podcast_id,
"title": podcast_title,
"message": "Podcast generation started. This may take a few minutes.",
}

View file

@ -11,7 +11,7 @@ from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import VideoPresentation, VideoPresentationStatus
from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session
def create_generate_video_presentation_tool(
@ -23,8 +23,11 @@ def create_generate_video_presentation_tool(
Factory function to create the generate_video_presentation tool with injected dependencies.
Pre-creates video presentation record with pending status so the ID is available
immediately for frontend polling.
immediately for frontend polling. The row is written via a fresh, tool-local
session so parallel tool calls (e.g. video + podcast in the same agent step)
don't share an ``AsyncSession`` (which is not concurrency-safe).
"""
del db_session # writes use a fresh tool-local session, see below
@tool
async def generate_video_presentation(
@ -42,34 +45,40 @@ def create_generate_video_presentation_tool(
user_prompt: Optional style/tone instructions.
"""
try:
video_pres = VideoPresentation(
title=video_title,
status=VideoPresentationStatus.PENDING,
search_space_id=search_space_id,
thread_id=thread_id,
)
db_session.add(video_pres)
await db_session.commit()
await db_session.refresh(video_pres)
# See podcast.py for the rationale: parallel tool calls share the
# streaming session, and AsyncSession is not concurrency-safe —
# interleaved flushes produce "Session.add() during flush" and
# poison the transaction for every concurrent tool.
async with shielded_async_session() as session:
video_pres = VideoPresentation(
title=video_title,
status=VideoPresentationStatus.PENDING,
search_space_id=search_space_id,
thread_id=thread_id,
)
session.add(video_pres)
await session.commit()
await session.refresh(video_pres)
video_pres_id = video_pres.id
from app.tasks.celery_tasks.video_presentation_tasks import (
generate_video_presentation_task,
)
task = generate_video_presentation_task.delay(
video_presentation_id=video_pres.id,
video_presentation_id=video_pres_id,
source_content=source_content,
search_space_id=search_space_id,
user_prompt=user_prompt,
)
print(
f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}"
f"[generate_video_presentation] Created video presentation {video_pres_id}, task: {task.id}"
)
return {
"status": VideoPresentationStatus.PENDING.value,
"video_presentation_id": video_pres.id,
"video_presentation_id": video_pres_id,
"title": video_title,
"message": "Video presentation generation started. This may take a few minutes.",
}

View file

@ -30,7 +30,7 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
@ -42,6 +42,9 @@ from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.agents.new_chat.middleware.kb_persistence import (
commit_staged_filesystem_state,
)
from app.db import (
ChatVisibility,
NewChatMessage,
@ -258,6 +261,10 @@ async def _stream_agent_events(
initial_step_id: str | None = None,
initial_step_title: str = "",
initial_step_items: list[str] | None = None,
*,
fallback_commit_search_space_id: int | None = None,
fallback_commit_created_by_id: str | None = None,
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.
@ -1280,6 +1287,40 @@ async def _stream_agent_events(
state = await agent.aget_state(config)
state_values = getattr(state, "values", {}) or {}
# Safety net: if astream_events was cancelled before
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
# (dirty_paths / staged_dirs / pending_moves) will still be in the
# checkpointed state. Run the SAME shared commit helper here so the
# turn's writes don't get lost on client disconnect, then push the
# delta back into the graph using `as_node=...` so reducers fire as if
# the after_agent hook produced it.
if (
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
and fallback_commit_search_space_id is not None
and (
(state_values.get("dirty_paths") or [])
or (state_values.get("staged_dirs") or [])
or (state_values.get("pending_moves") or [])
)
):
try:
delta = await commit_staged_filesystem_state(
state_values,
search_space_id=fallback_commit_search_space_id,
created_by_id=fallback_commit_created_by_id,
filesystem_mode=fallback_commit_filesystem_mode,
dispatch_events=False,
)
if delta:
await agent.aupdate_state(
config,
delta,
as_node="KnowledgeBasePersistenceMiddleware.after_agent",
)
except Exception as exc:
_perf_log.warning("[stream_new_chat] safety-net commit failed: %s", exc)
contract_state = state_values.get("file_operation_contract") or {}
contract_turn_id = contract_state.get("turn_id")
current_turn_id = config.get("configurable", {}).get("turn_id", "")
@ -1814,6 +1855,13 @@ async def stream_new_chat(
initial_step_id=initial_step_id,
initial_step_title=initial_title,
initial_step_items=initial_items,
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
):
if not _first_event_logged:
_perf_log.info(
@ -2251,6 +2299,13 @@ async def stream_resume_chat(
streaming_service=streaming_service,
result=stream_result,
step_prefix="thinking-resume",
fallback_commit_search_space_id=search_space_id,
fallback_commit_created_by_id=user_id,
fallback_commit_filesystem_mode=(
filesystem_selection.mode
if filesystem_selection
else FilesystemMode.CLOUD
),
):
if not _first_event_logged:
_perf_log.info(

View file

@ -0,0 +1,198 @@
"""Tests for canonical virtual-path resolver helpers."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
PathIndex,
doc_to_virtual_path,
parse_doc_id_suffix,
parse_documents_path,
safe_filename,
safe_folder_segment,
virtual_path_to_doc,
)
pytestmark = pytest.mark.unit
class TestSafeFilename:
    """Checks on the basename string returned by ``safe_filename``."""

    def test_appends_xml_extension(self):
        produced = safe_filename("notes")
        assert produced.endswith(".xml")

    def test_strips_invalid_chars(self):
        produced = safe_filename("a/b\\c.xml")
        assert "/" not in produced

    def test_falls_back_when_empty(self):
        from_empty = safe_filename("")
        assert from_empty.endswith(".xml")

        # The exact fallback stem is intentionally not pinned here:
        # "untitled.xml" is one accepted answer, but any ``.xml`` name passes.
        from_separators = safe_filename("///")
        assert from_separators == "untitled.xml" or from_separators.endswith(".xml")
class TestSafeFolderSegment:
    """Checks on single folder-name segments from ``safe_folder_segment``."""

    def test_strips_path_separators(self):
        segment = safe_folder_segment("a/b")
        assert "/" not in segment

    def test_falls_back(self):
        # An empty name must come back as the literal placeholder "folder".
        segment = safe_folder_segment("")
        assert segment == "folder"
class TestParseDocIdSuffix:
    """``parse_doc_id_suffix`` splits a basename into ``(stem, doc_id)``."""

    def test_parses_suffix(self):
        # A trailing " (<int>)" before the extension is read as the doc id.
        stem, doc_id = parse_doc_id_suffix("My Doc (42).xml")
        assert (stem, doc_id) == ("My Doc", 42)

    def test_no_suffix_returns_none(self):
        stem, doc_id = parse_doc_id_suffix("My Doc.xml")
        assert stem == "My Doc"
        assert doc_id is None

    def test_no_xml_extension(self):
        # A bare name without ``.xml`` is returned untouched as the stem.
        stem, doc_id = parse_doc_id_suffix("plain")
        assert stem == "plain"
        assert doc_id is None
class TestDocToVirtualPath:
    """Path assignment for documents, including collision handling."""

    def test_root_when_no_folder(self):
        index = PathIndex()
        expected = f"{DOCUMENTS_ROOT}/Hello.xml"

        resolved = doc_to_virtual_path(
            doc_id=1, title="Hello", folder_id=None, index=index
        )

        assert resolved == expected
        # The index must record which document now occupies the path.
        assert index.occupants[resolved] == 1

    def test_collision_picks_doc_id_suffix(self):
        # Path is already taken by document 7, so document 8 needs a suffix.
        taken = {f"{DOCUMENTS_ROOT}/Hello.xml": 7}
        index = PathIndex(occupants=taken)

        resolved = doc_to_virtual_path(
            doc_id=8, title="Hello", folder_id=None, index=index
        )

        assert resolved == f"{DOCUMENTS_ROOT}/Hello (8).xml"
        assert index.occupants[resolved] == 8

    def test_uses_folder_path_when_known(self):
        known_folders = {5: f"{DOCUMENTS_ROOT}/notes"}
        index = PathIndex(folder_paths=known_folders)

        resolved = doc_to_virtual_path(doc_id=2, title="A", folder_id=5, index=index)

        assert resolved == f"{DOCUMENTS_ROOT}/notes/A.xml"
class TestParseDocumentsPath:
    """``parse_documents_path`` yields ``(folder_parts, title)`` for a path."""

    def test_extracts_folder_parts_and_title(self):
        nested_path = f"{DOCUMENTS_ROOT}/foo/bar/baz.xml"
        folder_parts, title = parse_documents_path(nested_path)
        assert folder_parts == ["foo", "bar"]
        assert title == "baz"

    def test_strips_doc_id_suffix(self):
        # The " (12)" collision suffix must not leak into the title.
        suffixed_path = f"{DOCUMENTS_ROOT}/foo/My Doc (12).xml"
        folder_parts, title = parse_documents_path(suffixed_path)
        assert folder_parts == ["foo"]
        assert title == "My Doc"

    def test_non_documents_returns_empty(self):
        outcome = parse_documents_path("/other/x.xml")
        assert outcome == ([], "")
def _result_from_scalars(rows: list):
"""Build a fake SQLAlchemy ``Result`` whose ``.scalars().all()`` and
``.scalars().first()`` yield ``rows``."""
scalars = MagicMock()
scalars.all.return_value = list(rows)
scalars.first.return_value = rows[0] if rows else None
result = MagicMock()
result.scalars.return_value = scalars
result.scalar_one_or_none.return_value = None
result.first.return_value = None
return result
def _result_from_one(value):
result = MagicMock()
result.scalar_one_or_none.return_value = value
return result
class TestVirtualPathToDoc:
    """Resolving workspace basenames back to their documents.

    Basenames shown in the workspace tree come from ``safe_filename(title)``,
    which is lossy (for example ``:`` becomes ``_``). When a tool hands such
    a basename back, ``virtual_path_to_doc`` still has to find the document
    whose original title produced it.
    """

    @pytest.mark.asyncio
    async def test_falls_back_to_safe_filename_match_when_title_lossy(self):
        # Calendar-style title containing a colon: the encoded basename no
        # longer equals the stored title, so a literal title match cannot hit.
        original_title = "Calendar: Happy birthday!"
        encoded_basename = safe_filename(original_title)
        assert encoded_basename == "Calendar_ Happy birthday!.xml"

        target_doc = SimpleNamespace(id=42, title=original_title, folder_id=None)

        # One canned result per awaited ``session.execute`` call, in the
        # order the resolver is expected to query:
        #   1) unique_identifier_hash lookup -> miss
        #   2) literal title lookup          -> miss (lossy encoding)
        #   3) folder scan                   -> row whose title encodes to
        #                                       the requested basename
        canned_results = [
            _result_from_one(None),
            _result_from_scalars([]),
            _result_from_scalars([target_doc]),
        ]
        session = MagicMock()
        session.execute = AsyncMock(side_effect=canned_results)

        requested_path = f"{DOCUMENTS_ROOT}/{encoded_basename}"
        document = await virtual_path_to_doc(
            session,
            search_space_id=5,
            virtual_path=requested_path,
        )

        assert document is target_doc

    @pytest.mark.asyncio
    async def test_returns_none_when_no_doc_matches_safe_filename(self):
        # The folder scan only offers a row with an unrelated title.
        unrelated_doc = SimpleNamespace(id=1, title="Something else", folder_id=None)
        canned_results = [
            _result_from_one(None),
            _result_from_scalars([]),
            _result_from_scalars([unrelated_doc]),
        ]
        session = MagicMock()
        session.execute = AsyncMock(side_effect=canned_results)

        document = await virtual_path_to_doc(
            session,
            search_space_id=5,
            virtual_path=f"{DOCUMENTS_ROOT}/Calendar_ Happy birthday!.xml",
        )

        assert document is None

    @pytest.mark.asyncio
    async def test_literal_title_match_short_circuits_fallback(self):
        # A hit on the literal title query must stop the lookup there: the
        # folder-scan fallback would cost an extra query and could pick the
        # wrong row when two titles share a lossy encoding.
        target_doc = SimpleNamespace(id=7, title="Plain Note", folder_id=None)
        canned_results = [
            _result_from_one(None),
            _result_from_scalars([target_doc]),
        ]
        session = MagicMock()
        session.execute = AsyncMock(side_effect=canned_results)

        document = await virtual_path_to_doc(
            session,
            search_space_id=5,
            virtual_path=f"{DOCUMENTS_ROOT}/Plain Note.xml",
        )

        assert document is target_doc
        # Exactly the two canned queries were awaited; no third lookup ran.
        assert session.execute.await_count == 2

View file

@ -0,0 +1,107 @@
"""Tests for SurfSense filesystem state reducers."""
from __future__ import annotations
import pytest
from app.agents.new_chat.state_reducers import (
_CLEAR,
_add_unique_reducer,
_dict_merge_with_tombstones_reducer,
_initial_filesystem_state,
_list_append_reducer,
_replace_reducer,
)
pytestmark = pytest.mark.unit
class TestReplaceReducer:
    """``_replace_reducer`` always takes the right-hand value."""

    def test_right_wins_outright(self):
        merged = _replace_reducer("a", "b")
        assert merged == "b"

    def test_none_right_returns_none(self):
        # A right-hand ``None`` overwrites the previous value too.
        merged = _replace_reducer("a", None)
        assert merged is None

    def test_none_left_returns_right(self):
        merged = _replace_reducer(None, "b")
        assert merged == "b"
class TestAddUniqueReducer:
    """Order-preserving, de-duplicating list merge with ``_CLEAR`` resets."""

    def test_appends_unique_items(self):
        merged = _add_unique_reducer(["a"], ["b", "c"])
        assert merged == ["a", "b", "c"]

    def test_dedupes_against_left(self):
        merged = _add_unique_reducer(["a", "b"], ["b", "c"])
        assert merged == ["a", "b", "c"]

    def test_dedupes_within_right(self):
        merged = _add_unique_reducer([], ["a", "a", "b"])
        assert merged == ["a", "b"]

    def test_clear_anywhere_resets_and_reseeds_with_after_items(self):
        # Only what follows the LAST ``_CLEAR`` survives; the left side and
        # anything before the marker are dropped.
        incoming = ["a", _CLEAR, "b", "c"]
        merged = _add_unique_reducer(["x", "y"], incoming)
        assert merged == ["b", "c"]

    def test_multiple_clears_use_last(self):
        incoming = [_CLEAR, "a", _CLEAR, "b"]
        merged = _add_unique_reducer(["x"], incoming)
        assert merged == ["b"]

    def test_clear_only_resets_to_empty(self):
        merged = _add_unique_reducer(["x", "y"], [_CLEAR])
        assert merged == []

    def test_empty_right_keeps_left(self):
        # Both an empty list and ``None`` on the right leave the left as is.
        for nothing in ([], None):
            assert _add_unique_reducer(["a"], nothing) == ["a"]
class TestListAppendReducer:
    """Plain append semantics: duplicates stay, ``_CLEAR`` resets."""

    def test_preserves_order_and_duplicates(self):
        existing = [{"a": 1}]
        incoming = [{"b": 2}, {"a": 1}]
        merged = _list_append_reducer(existing, incoming)
        assert merged == [{"a": 1}, {"b": 2}, {"a": 1}]

    def test_clear_resets_keeping_after_items(self):
        # Everything up to and including the marker is discarded.
        incoming = [{"old": 1}, _CLEAR, {"new": 2}]
        merged = _list_append_reducer([{"a": 1}], incoming)
        assert merged == [{"new": 2}]
class TestDictMergeWithTombstones:
    """Dict merge where ``None`` values delete keys and ``_CLEAR`` resets."""

    def test_merges_keys(self):
        merged = _dict_merge_with_tombstones_reducer({"a": 1}, {"b": 2})
        assert merged == {"a": 1, "b": 2}

    def test_none_value_deletes_key(self):
        # ``None`` acts as a tombstone for the matching key on the left.
        merged = _dict_merge_with_tombstones_reducer({"a": 1, "b": 2}, {"a": None})
        assert merged == {"b": 2}

    def test_clear_resets_then_merges(self):
        existing = {"a": 1, "b": 2}
        incoming = {_CLEAR: True, "c": 3}
        merged = _dict_merge_with_tombstones_reducer(existing, incoming)
        assert merged == {"c": 3}

    def test_clear_keeps_only_post_clear_non_none(self):
        incoming = {_CLEAR: True, "b": 2, "c": None}
        merged = _dict_merge_with_tombstones_reducer({"a": 1}, incoming)
        assert merged == {"b": 2}

    def test_none_left_handled(self):
        # A missing left-hand dict behaves like an empty one.
        merged = _dict_merge_with_tombstones_reducer(None, {"a": 1, "b": None})
        assert merged == {"a": 1}
class TestInitialFilesystemState:
    """Pins the default values returned by ``_initial_filesystem_state``."""

    def test_default_shape(self):
        state = _initial_filesystem_state()

        # Compared with ``==``, in the listed order.
        expected_by_equality = (
            ("cwd", "/documents"),
            ("staged_dirs", []),
            ("pending_moves", []),
            ("doc_id_by_path", {}),
            ("dirty_paths", []),
            ("kb_priority", []),
            ("kb_matched_chunk_ids", {}),
        )
        for key, expected in expected_by_equality:
            assert state[key] == expected

        assert state["kb_anon_doc"] is None
        assert state["tree_version"] == 0

View file

@ -36,11 +36,18 @@ def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: P
def test_backend_resolver_uses_cloud_mode_by_default():
resolver = build_backend_resolver(FilesystemSelection())
backend = resolver(_RuntimeStub())
# StateBackend class name check keeps this test decoupled
# from internal deepagents runtime class identity.
# When no search_space_id is provided we fall back to StateBackend so
# sub-agents / tests without DB access still work.
assert backend.__class__.__name__ == "StateBackend"
def test_backend_resolver_uses_kb_postgres_in_cloud_with_search_space():
    # Default (cloud) selection plus a concrete search space id should
    # resolve to the KB-backed backend; matched by class name only.
    resolve = build_backend_resolver(FilesystemSelection(), search_space_id=42)
    resolved_backend = resolve(_RuntimeStub())

    backend_class_name = resolved_backend.__class__.__name__
    assert backend_class_name == "KBPostgresBackend"
    assert resolved_backend.search_space_id == 42
def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path):
root_one = tmp_path / "resume"
root_two = tmp_path / "notes"

View file

@ -0,0 +1,204 @@
"""Unit tests for the SurfSense filesystem middleware new behaviors.
Covers:
* cloud cwd defaults to ``/documents`` and relative paths resolve under it
* cloud writes outside ``/documents/`` are rejected unless basename starts
with ``temp_``
* cloud writes/edits to the anonymous document are rejected (read-only)
* helper methods on the middleware (``_resolve_relative``,
``_check_cloud_write_namespace``, ``_default_cwd``)
These tests use ``__new__`` to bypass the heavy ``__init__`` and exercise
the helper methods directly so the test surface stays narrow and fast.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
_build_filesystem_system_prompt,
_build_tool_descriptions,
)
pytestmark = pytest.mark.unit
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
    """Build a middleware instance without running ``__init__``.

    Only ``_filesystem_mode`` is populated, which is the single attribute
    this helper sets before handing the instance to a test.
    """
    middleware_cls = SurfSenseFilesystemMiddleware
    instance = middleware_cls.__new__(middleware_cls)
    instance._filesystem_mode = mode
    return instance
def _runtime(state: dict | None = None) -> SimpleNamespace:
return SimpleNamespace(state=state or {})
class TestCloudCwdDefaults:
    """Default and state-driven working directory per filesystem mode."""

    def test_default_cwd_in_cloud_is_documents_root(self):
        middleware = _make_middleware()
        assert middleware._default_cwd() == "/documents"

    def test_default_cwd_in_desktop_is_root(self):
        middleware = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER)
        assert middleware._default_cwd() == "/"

    def test_current_cwd_uses_state_when_set(self):
        middleware = _make_middleware()
        runtime = _runtime({"cwd": "/documents/notes"})
        observed = middleware._current_cwd(runtime)
        assert observed == "/documents/notes"

    def test_current_cwd_falls_back_to_default(self):
        middleware = _make_middleware()
        observed = middleware._current_cwd(_runtime({}))
        assert observed == "/documents"

    def test_current_cwd_ignores_invalid(self):
        # A non-absolute cwd in state is not honoured; the default applies.
        middleware = _make_middleware()
        runtime = _runtime({"cwd": "not-absolute"})
        observed = middleware._current_cwd(runtime)
        assert observed == "/documents"
class TestRelativePathResolution:
    """``_resolve_relative`` anchors relative paths at the state's cwd."""

    def test_relative_path_resolves_against_cwd(self):
        middleware = _make_middleware()
        runtime = _runtime({"cwd": "/documents/projects"})
        resolved = middleware._resolve_relative("notes.md", runtime)
        assert resolved == "/documents/projects/notes.md"

    def test_relative_path_with_dotdot(self):
        # ``..`` climbs one level above the cwd before descending.
        middleware = _make_middleware()
        runtime = _runtime({"cwd": "/documents/a/b"})
        resolved = middleware._resolve_relative("../c.md", runtime)
        assert resolved == "/documents/a/c.md"

    def test_absolute_path_is_kept(self):
        middleware = _make_middleware()
        runtime = _runtime({"cwd": "/documents"})
        resolved = middleware._resolve_relative("/other/x.md", runtime)
        assert resolved == "/other/x.md"

    def test_empty_path_returns_cwd(self):
        middleware = _make_middleware()
        runtime = _runtime({"cwd": "/documents/projects"})
        resolved = middleware._resolve_relative("", runtime)
        assert resolved == "/documents/projects"
class TestCloudWriteNamespacePolicy:
    """Which write targets ``_check_cloud_write_namespace`` accepts.

    The helper returns ``None`` when a write is allowed and an error
    message string when it is rejected.
    """

    def test_documents_path_allowed(self):
        middleware = _make_middleware()
        verdict = middleware._check_cloud_write_namespace(
            "/documents/foo.md", _runtime()
        )
        assert verdict is None

    def test_documents_root_allowed(self):
        middleware = _make_middleware()
        verdict = middleware._check_cloud_write_namespace("/documents", _runtime())
        assert verdict is None

    def test_temp_basename_anywhere_allowed(self):
        # A ``temp_`` basename is accepted regardless of its directory.
        middleware = _make_middleware()
        runtime = _runtime()
        scratch_paths = (
            "/temp_scratch.md",
            "/foo/temp_x.md",
            "/documents/temp_x.md",
        )
        for candidate in scratch_paths:
            assert middleware._check_cloud_write_namespace(candidate, runtime) is None

    def test_other_paths_rejected(self):
        middleware = _make_middleware()
        err = middleware._check_cloud_write_namespace("/foo/bar.md", _runtime())
        assert err is not None
        assert "must target /documents" in err

    def test_anon_doc_path_is_read_only(self):
        # The anonymous document lives under /documents but is still
        # off-limits for writes.
        anon_doc = {
            "path": "/documents/uploaded.xml",
            "title": "uploaded",
            "content": "",
            "chunks": [],
        }
        middleware = _make_middleware()
        runtime = _runtime({"kb_anon_doc": anon_doc})
        err = middleware._check_cloud_write_namespace(
            "/documents/uploaded.xml", runtime
        )
        assert err is not None
        assert "read-only" in err

    def test_desktop_mode_skips_namespace_policy(self):
        middleware = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER)
        verdict = middleware._check_cloud_write_namespace(
            "/random/path.md", _runtime()
        )
        assert verdict is None
class TestModeSpecificPrompts:
    """Prompts and tool descriptions must describe the active mode only.

    Text about the other mode costs tokens and presents the model with
    rules it cannot apply in the current session.
    """

    # Tool descriptions inspected for cross-mode wording.
    _MODE_SENSITIVE_TOOLS = (
        "write_file",
        "edit_file",
        "move_file",
        "mkdir",
        "list_tree",
        "grep",
    )

    def test_cloud_prompt_omits_desktop_section(self):
        prompt = _build_filesystem_system_prompt(
            FilesystemMode.CLOUD, sandbox_available=False
        )
        # Desktop-only wording must be absent ...
        assert "Local Folder Mode" not in prompt
        assert "mount-prefixed" not in prompt
        # ... while the cloud-specific guidance is present.
        assert "Persistence Rules" in prompt
        assert "/documents" in prompt
        assert "temp_" in prompt

    def test_desktop_prompt_omits_cloud_persistence_rules(self):
        prompt = _build_filesystem_system_prompt(
            FilesystemMode.DESKTOP_LOCAL_FOLDER, sandbox_available=False
        )
        assert "Persistence Rules" not in prompt
        assert "Workspace Tree" not in prompt
        assert "<priority_documents>" not in prompt
        assert "Local Folder Mode" in prompt
        assert "mount-prefixed" in prompt

    def test_cloud_tool_descs_omit_desktop_phrases(self):
        descs = _build_tool_descriptions(FilesystemMode.CLOUD)
        for name in self._MODE_SENSITIVE_TOOLS:
            text = descs[name]
            assert "Desktop" not in text, f"{name} leaks desktop hints"
            assert "Cloud mode:" not in text, f"{name} qualifies a cloud-only desc"

    def test_desktop_tool_descs_omit_cloud_phrases(self):
        descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
        for name in self._MODE_SENSITIVE_TOOLS:
            text = descs[name]
            assert "Cloud" not in text, f"{name} leaks cloud hints"
            assert "/documents/" not in text, f"{name} mentions cloud namespace"
            assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"

    def test_sandbox_addendum_appended_when_available(self):
        prompt = _build_filesystem_system_prompt(
            FilesystemMode.CLOUD, sandbox_available=True
        )
        assert "execute_code" in prompt
        assert "Code Execution" in prompt

    def test_sandbox_addendum_absent_when_unavailable(self):
        prompt = _build_filesystem_system_prompt(
            FilesystemMode.CLOUD, sandbox_available=False
        )
        assert "execute_code" not in prompt

View file

@ -11,25 +11,6 @@ from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
pytestmark = pytest.mark.unit
class _BackendWithRawRead:
def __init__(self, content: str) -> None:
self._content = content
def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
del file_path, offset, limit
return " 1\tline1\n 2\tline2"
async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
return self.read(file_path, offset, limit)
def read_raw(self, file_path: str) -> str:
del file_path
return self._content
async def aread_raw(self, file_path: str) -> str:
return self.read_raw(file_path)
class _RuntimeNoSuggestedPath:
state = {"file_operation_contract": {}}
@ -39,40 +20,19 @@ class _RuntimeWithSuggestedPath:
self.state = {"file_operation_contract": {"suggested_path": suggested_path}}
def test_verify_written_content_prefers_raw_sync() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
expected = "line1\nline2"
backend = _BackendWithRawRead(expected)
verify_error = middleware._verify_written_content_sync(
backend=backend,
path="/note.md",
expected_content=expected,
)
assert verify_error is None
def test_contract_suggested_path_falls_back_to_notes_md() -> None:
def test_contract_suggested_path_falls_back_to_documents_notes_md() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
middleware._filesystem_mode = FilesystemMode.CLOUD
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
assert suggested == "/notes.md"
# Cloud default cwd is /documents so the fallback lands in the KB.
assert suggested == "/documents/notes.md"
@pytest.mark.asyncio
async def test_verify_written_content_prefers_raw_async() -> None:
def test_contract_suggested_path_falls_back_to_root_notes_md_in_desktop() -> None:
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
expected = "line1\nline2"
backend = _BackendWithRawRead(expected)
verify_error = await middleware._verify_written_content_async(
backend=backend,
path="/note.md",
expected_content=expected,
)
assert verify_error is None
middleware._filesystem_mode = FilesystemMode.DESKTOP_LOCAL_FOLDER
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
assert suggested == "/notes.md"
def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None:

View file

@ -5,10 +5,10 @@ import json
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.new_chat.document_xml import build_document_xml as _build_document_xml
from app.agents.new_chat.middleware.knowledge_search import (
KBSearchPlan,
KnowledgeBaseSearchMiddleware,
_build_document_xml,
_normalize_optional_date_range,
_parse_kb_search_plan_response,
_render_recent_conversation,
@ -248,17 +248,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
captured.update(kwargs)
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}, {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
llm = FakeLLM(
json.dumps(
@ -298,17 +291,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
captured.update(kwargs)
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}, {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
middleware = KnowledgeBaseSearchMiddleware(
llm=FakeLLM("not json"),
@ -334,17 +320,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
captured.update(kwargs)
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}, {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
middleware = KnowledgeBaseSearchMiddleware(
llm=FakeLLM(
@ -386,9 +365,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
search_called = True
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}, {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
fake_browse_recent_documents,
@ -397,10 +373,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
llm = FakeLLM(
json.dumps(
@ -440,9 +412,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
search_captured.update(kwargs)
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}, {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
fake_browse_recent_documents,
@ -451,10 +420,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
llm = FakeLLM(
json.dumps(