mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-30 21:59:46 +02:00
feat: updated file management for main agent
This commit is contained in:
parent
8d50f90060
commit
05ca4c0b9f
27 changed files with 5054 additions and 1803 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -7,4 +7,5 @@ node_modules/
|
||||||
.pnpm-store
|
.pnpm-store
|
||||||
.DS_Store
|
.DS_Store
|
||||||
deepagents/
|
deepagents/
|
||||||
debug.log
|
debug.log
|
||||||
|
opencode/
|
||||||
|
|
@ -28,13 +28,76 @@ from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.agents.new_chat.document_xml import build_document_xml
|
||||||
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||||
from app.agents.new_chat.middleware.knowledge_search import (
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
build_scoped_filesystem,
|
|
||||||
search_knowledge_base,
|
search_knowledge_base,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
build_path_index,
|
||||||
|
doc_to_virtual_path,
|
||||||
|
)
|
||||||
|
from app.db import shielded_async_session
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
|
||||||
|
try:
|
||||||
|
from deepagents.backends.utils import create_file_data
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
|
||||||
|
def create_file_data(content: str) -> dict[str, Any]:
|
||||||
|
return {"content": content.split("\n")}
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_autocomplete_filesystem(
|
||||||
|
*,
|
||||||
|
documents: Any,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> tuple[dict[str, Any], dict[int, str]]:
|
||||||
|
"""Build a ``state['files']``-shaped dict from KB search results.
|
||||||
|
|
||||||
|
This is the autocomplete-specific replacement for the previous
|
||||||
|
``build_scoped_filesystem`` helper. It uses the canonical path resolver
|
||||||
|
so paths line up with the rest of the system, including collision
|
||||||
|
suffixes for duplicate titles.
|
||||||
|
"""
|
||||||
|
files: dict[str, Any] = {}
|
||||||
|
doc_id_to_path: dict[int, str] = {}
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return files, doc_id_to_path
|
||||||
|
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
index = await build_path_index(session, search_space_id)
|
||||||
|
|
||||||
|
for document in documents:
|
||||||
|
if not isinstance(document, dict):
|
||||||
|
continue
|
||||||
|
meta = document.get("document") or {}
|
||||||
|
doc_id = meta.get("id")
|
||||||
|
if not isinstance(doc_id, int):
|
||||||
|
continue
|
||||||
|
title = str(meta.get("title") or "untitled")
|
||||||
|
folder_id = meta.get("folder_id")
|
||||||
|
path = doc_to_virtual_path(
|
||||||
|
doc_id=doc_id, title=title, folder_id=folder_id, index=index
|
||||||
|
)
|
||||||
|
chunk_ids = document.get("matched_chunk_ids") or []
|
||||||
|
try:
|
||||||
|
matched_set = {int(c) for c in chunk_ids}
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
matched_set = set()
|
||||||
|
xml = build_document_xml(document, matched_chunk_ids=matched_set)
|
||||||
|
files[path] = create_file_data(xml)
|
||||||
|
doc_id_to_path[doc_id] = path
|
||||||
|
|
||||||
|
if not files:
|
||||||
|
# Ensure the synthetic /documents folder is visible even when empty.
|
||||||
|
files.setdefault(f"{DOCUMENTS_ROOT}/.placeholder", create_file_data(""))
|
||||||
|
|
||||||
|
return files, doc_id_to_path
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
KB_TOP_K = 10
|
KB_TOP_K = 10
|
||||||
|
|
@ -174,7 +237,7 @@ async def precompute_kb_filesystem(
|
||||||
if not search_results:
|
if not search_results:
|
||||||
return _KBResult()
|
return _KBResult()
|
||||||
|
|
||||||
new_files, _ = await build_scoped_filesystem(
|
new_files, _ = await _build_autocomplete_filesystem(
|
||||||
documents=search_results,
|
documents=search_results,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
)
|
)
|
||||||
|
|
@ -215,13 +278,12 @@ async def precompute_kb_filesystem(
|
||||||
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
|
class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware):
|
||||||
"""Filesystem middleware for autocomplete — read-only exploration only.
|
"""Filesystem middleware for autocomplete — read-only exploration only.
|
||||||
|
|
||||||
Strips ``save_document`` (permanent KB persistence) and passes
|
Passes ``search_space_id=None`` so the new persistence pipeline is
|
||||||
``search_space_id=None`` so ``write_file`` / ``edit_file`` stay ephemeral.
|
bypassed; the autocomplete flow only reads, never commits to Postgres.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(search_space_id=None, created_by_id=None)
|
super().__init__(search_space_id=None, created_by_id=None)
|
||||||
self.tools = [t for t in self.tools if t.name != "save_document"]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -34,12 +34,15 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
from app.agents.new_chat.middleware import (
|
from app.agents.new_chat.middleware import (
|
||||||
|
AnonymousDocumentMiddleware,
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
FileIntentMiddleware,
|
FileIntentMiddleware,
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgeBasePersistenceMiddleware,
|
||||||
|
KnowledgePriorityMiddleware,
|
||||||
|
KnowledgeTreeMiddleware,
|
||||||
MemoryInjectionMiddleware,
|
MemoryInjectionMiddleware,
|
||||||
SurfSenseFilesystemMiddleware,
|
SurfSenseFilesystemMiddleware,
|
||||||
)
|
)
|
||||||
|
|
@ -246,7 +249,12 @@ async def create_surfsense_deep_agent(
|
||||||
"""
|
"""
|
||||||
_t_agent_total = time.perf_counter()
|
_t_agent_total = time.perf_counter()
|
||||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||||
backend_resolver = build_backend_resolver(filesystem_selection)
|
backend_resolver = build_backend_resolver(
|
||||||
|
filesystem_selection,
|
||||||
|
search_space_id=search_space_id
|
||||||
|
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
# Discover available connectors and document types for this search space
|
# Discover available connectors and document types for this search space
|
||||||
available_connectors: list[str] | None = None
|
available_connectors: list[str] | None = None
|
||||||
|
|
@ -299,7 +307,9 @@ async def create_surfsense_deep_agent(
|
||||||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||||
modified_disabled_tools.extend(get_connector_gated_tools(available_connectors))
|
modified_disabled_tools.extend(get_connector_gated_tools(available_connectors))
|
||||||
|
|
||||||
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
|
# Remove direct KB search tool; KnowledgePriorityMiddleware now runs hybrid
|
||||||
|
# search per turn and surfaces hits as a <priority_documents> hint plus
|
||||||
|
# `<chunk_index matched="true">` markers inside lazy-loaded XML.
|
||||||
if "search_knowledge_base" not in modified_disabled_tools:
|
if "search_knowledge_base" not in modified_disabled_tools:
|
||||||
modified_disabled_tools.append("search_knowledge_base")
|
modified_disabled_tools.append("search_knowledge_base")
|
||||||
|
|
||||||
|
|
@ -365,6 +375,11 @@ async def create_surfsense_deep_agent(
|
||||||
)
|
)
|
||||||
|
|
||||||
# General-purpose subagent middleware
|
# General-purpose subagent middleware
|
||||||
|
# Subagent omits AnonymousDocumentMiddleware, KnowledgeTreeMiddleware,
|
||||||
|
# KnowledgePriorityMiddleware, and KnowledgeBasePersistenceMiddleware - it
|
||||||
|
# inherits state and tools from the parent, but should not (a) re-load
|
||||||
|
# anon docs / re-render the tree / re-run hybrid search, or (b) commit at
|
||||||
|
# its own completion (only the top-level agent's aafter_agent commits).
|
||||||
gp_middleware = [
|
gp_middleware = [
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
_memory_middleware,
|
_memory_middleware,
|
||||||
|
|
@ -389,19 +404,35 @@ async def create_surfsense_deep_agent(
|
||||||
}
|
}
|
||||||
|
|
||||||
# Main agent middleware
|
# Main agent middleware
|
||||||
|
# Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ...
|
||||||
|
# before_agent hooks run in declared order; later injections sit closer to
|
||||||
|
# the latest human turn. Tree (large + cacheable) is injected earliest so
|
||||||
|
# provider-side prefix caching has more material to hit; FileIntent (most
|
||||||
|
# actionable per-turn contract) is injected closest to the user message.
|
||||||
deepagent_middleware = [
|
deepagent_middleware = [
|
||||||
TodoListMiddleware(),
|
TodoListMiddleware(),
|
||||||
_memory_middleware,
|
_memory_middleware,
|
||||||
FileIntentMiddleware(llm=llm),
|
AnonymousDocumentMiddleware(
|
||||||
KnowledgeBaseSearchMiddleware(
|
anon_session_id=anon_session_id,
|
||||||
|
)
|
||||||
|
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
KnowledgeTreeMiddleware(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
filesystem_mode=filesystem_selection.mode,
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
|
KnowledgePriorityMiddleware(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
filesystem_mode=filesystem_selection.mode,
|
filesystem_mode=filesystem_selection.mode,
|
||||||
available_connectors=available_connectors,
|
available_connectors=available_connectors,
|
||||||
available_document_types=available_document_types,
|
available_document_types=available_document_types,
|
||||||
mentioned_document_ids=mentioned_document_ids,
|
mentioned_document_ids=mentioned_document_ids,
|
||||||
anon_session_id=anon_session_id,
|
|
||||||
),
|
),
|
||||||
|
FileIntentMiddleware(llm=llm),
|
||||||
SurfSenseFilesystemMiddleware(
|
SurfSenseFilesystemMiddleware(
|
||||||
backend=backend_resolver,
|
backend=backend_resolver,
|
||||||
filesystem_mode=filesystem_selection.mode,
|
filesystem_mode=filesystem_selection.mode,
|
||||||
|
|
@ -409,12 +440,20 @@ async def create_surfsense_deep_agent(
|
||||||
created_by_id=user_id,
|
created_by_id=user_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
),
|
),
|
||||||
|
KnowledgeBasePersistenceMiddleware(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
filesystem_mode=filesystem_selection.mode,
|
||||||
|
)
|
||||||
|
if filesystem_selection.mode == FilesystemMode.CLOUD
|
||||||
|
else None,
|
||||||
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
||||||
create_safe_summarization_middleware(llm, StateBackend),
|
create_safe_summarization_middleware(llm, StateBackend),
|
||||||
PatchToolCallsMiddleware(),
|
PatchToolCallsMiddleware(),
|
||||||
DedupHITLToolCallsMiddleware(agent_tools=tools),
|
DedupHITLToolCallsMiddleware(agent_tools=tools),
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||||
]
|
]
|
||||||
|
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
|
||||||
|
|
||||||
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
|
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
|
||||||
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||||
|
|
|
||||||
103
surfsense_backend/app/agents/new_chat/document_xml.py
Normal file
103
surfsense_backend/app/agents/new_chat/document_xml.py
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
"""Shared XML builder for KB documents.
|
||||||
|
|
||||||
|
Produces the citation-friendly XML used by every read of a knowledge-base
|
||||||
|
document (lazy-loaded by :class:`KBPostgresBackend` and synthetic anonymous
|
||||||
|
files). The XML carries a ``<chunk_index>`` near the top so the LLM can jump
|
||||||
|
directly to matched-chunk line ranges via ``read_file(offset=…, limit=…)``.
|
||||||
|
|
||||||
|
Extracted from the original ``knowledge_search.py`` so the backend, the
|
||||||
|
priority middleware, and any future renderer share a single implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def build_document_xml(
|
||||||
|
document: dict[str, Any],
|
||||||
|
matched_chunk_ids: set[int] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document: Dict shape produced by hybrid search / lazy-load helpers.
|
||||||
|
Expected keys: ``document`` (with ``id``, ``title``,
|
||||||
|
``document_type``, ``metadata``) and ``chunks``
|
||||||
|
(list of ``{chunk_id, content}``).
|
||||||
|
matched_chunk_ids: Optional set of chunk IDs to flag as
|
||||||
|
``matched="true"`` in the chunk index.
|
||||||
|
"""
|
||||||
|
matched = matched_chunk_ids or set()
|
||||||
|
|
||||||
|
doc_meta = document.get("document") or {}
|
||||||
|
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
|
||||||
|
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
|
||||||
|
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
|
||||||
|
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
|
||||||
|
url = (
|
||||||
|
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
|
||||||
|
)
|
||||||
|
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||||
|
|
||||||
|
metadata_lines: list[str] = [
|
||||||
|
"<document>",
|
||||||
|
"<document_metadata>",
|
||||||
|
f" <document_id>{document_id}</document_id>",
|
||||||
|
f" <document_type>{document_type}</document_type>",
|
||||||
|
f" <title><![CDATA[{title}]]></title>",
|
||||||
|
f" <url><![CDATA[{url}]]></url>",
|
||||||
|
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||||
|
"</document_metadata>",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = document.get("chunks") or []
|
||||||
|
chunk_entries: list[tuple[int | None, str]] = []
|
||||||
|
if isinstance(chunks, list):
|
||||||
|
for chunk in chunks:
|
||||||
|
if not isinstance(chunk, dict):
|
||||||
|
continue
|
||||||
|
chunk_id = chunk.get("chunk_id") or chunk.get("id")
|
||||||
|
chunk_content = str(chunk.get("content", "")).strip()
|
||||||
|
if not chunk_content:
|
||||||
|
continue
|
||||||
|
if chunk_id is None:
|
||||||
|
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
|
||||||
|
else:
|
||||||
|
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
|
||||||
|
chunk_entries.append((chunk_id, xml))
|
||||||
|
|
||||||
|
index_overhead = 1 + len(chunk_entries) + 1 + 1 + 1
|
||||||
|
first_chunk_line = len(metadata_lines) + index_overhead + 1
|
||||||
|
|
||||||
|
current_line = first_chunk_line
|
||||||
|
index_entry_lines: list[str] = []
|
||||||
|
for cid, xml_str in chunk_entries:
|
||||||
|
num_lines = xml_str.count("\n") + 1
|
||||||
|
end_line = current_line + num_lines - 1
|
||||||
|
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
|
||||||
|
if cid is not None:
|
||||||
|
index_entry_lines.append(
|
||||||
|
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
index_entry_lines.append(
|
||||||
|
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
|
||||||
|
)
|
||||||
|
current_line = end_line + 1
|
||||||
|
|
||||||
|
lines = metadata_lines.copy()
|
||||||
|
lines.append("<chunk_index>")
|
||||||
|
lines.extend(index_entry_lines)
|
||||||
|
lines.append("</chunk_index>")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("<document_content>")
|
||||||
|
for _, xml_str in chunk_entries:
|
||||||
|
lines.append(xml_str)
|
||||||
|
lines.extend(["</document_content>", "</document>"])
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["build_document_xml"]
|
||||||
|
|
@ -5,10 +5,12 @@ from __future__ import annotations
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from deepagents.backends.protocol import BackendProtocol
|
||||||
from deepagents.backends.state import StateBackend
|
from deepagents.backends.state import StateBackend
|
||||||
from langgraph.prebuilt.tool_node import ToolRuntime
|
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||||
|
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
|
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||||
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||||
MultiRootLocalFolderBackend,
|
MultiRootLocalFolderBackend,
|
||||||
)
|
)
|
||||||
|
|
@ -23,8 +25,20 @@ def _cached_multi_root_backend(
|
||||||
|
|
||||||
def build_backend_resolver(
|
def build_backend_resolver(
|
||||||
selection: FilesystemSelection,
|
selection: FilesystemSelection,
|
||||||
) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]:
|
*,
|
||||||
"""Create deepagents backend resolver for the selected filesystem mode."""
|
search_space_id: int | None = None,
|
||||||
|
) -> Callable[[ToolRuntime], BackendProtocol]:
|
||||||
|
"""Create deepagents backend resolver for the selected filesystem mode.
|
||||||
|
|
||||||
|
In cloud mode the resolver returns a fresh :class:`KBPostgresBackend`
|
||||||
|
bound to the current ``runtime`` so the backend can read staging state
|
||||||
|
(``staged_dirs``, ``pending_moves``, ``files`` cache, ``kb_anon_doc``,
|
||||||
|
``kb_matched_chunk_ids``) for each tool call. When no ``search_space_id``
|
||||||
|
is provided, the resolver falls back to :class:`StateBackend` (used by
|
||||||
|
sub-agents and tests that don't need DB-backed reads).
|
||||||
|
|
||||||
|
Desktop-local mode unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:
|
if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts:
|
||||||
|
|
||||||
|
|
@ -36,7 +50,14 @@ def build_backend_resolver(
|
||||||
|
|
||||||
return _resolve_local
|
return _resolve_local
|
||||||
|
|
||||||
def _resolve_cloud(runtime: ToolRuntime) -> StateBackend:
|
if search_space_id is not None:
|
||||||
|
|
||||||
|
def _resolve_kb(runtime: ToolRuntime) -> BackendProtocol:
|
||||||
|
return KBPostgresBackend(search_space_id, runtime)
|
||||||
|
|
||||||
|
return _resolve_kb
|
||||||
|
|
||||||
|
def _resolve_state(runtime: ToolRuntime) -> StateBackend:
|
||||||
return StateBackend(runtime)
|
return StateBackend(runtime)
|
||||||
|
|
||||||
return _resolve_cloud
|
return _resolve_state
|
||||||
|
|
|
||||||
113
surfsense_backend/app/agents/new_chat/filesystem_state.py
Normal file
113
surfsense_backend/app/agents/new_chat/filesystem_state.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""LangGraph state schema additions used by the SurfSense filesystem agent.
|
||||||
|
|
||||||
|
This schema extends deepagents' upstream :class:`FilesystemState` with the
|
||||||
|
extra fields needed to implement Postgres-backed virtual filesystem semantics:
|
||||||
|
|
||||||
|
* ``cwd`` — current working directory (per-thread checkpointed).
|
||||||
|
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
||||||
|
* ``pending_moves`` — pending move_file requests (cloud only).
|
||||||
|
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
||||||
|
* ``dirty_paths`` — paths whose state file content differs from DB.
|
||||||
|
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
||||||
|
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
||||||
|
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
||||||
|
* ``tree_version`` — bumped by persistence; invalidates the tree render cache.
|
||||||
|
|
||||||
|
Tools mutate these fields ONLY via ``Command(update=...)`` returns; the
|
||||||
|
reducers in :mod:`app.agents.new_chat.state_reducers` handle merging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated, Any, NotRequired
|
||||||
|
|
||||||
|
from deepagents.middleware.filesystem import FilesystemState
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from app.agents.new_chat.state_reducers import (
|
||||||
|
_add_unique_reducer,
|
||||||
|
_dict_merge_with_tombstones_reducer,
|
||||||
|
_list_append_reducer,
|
||||||
|
_replace_reducer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PendingMove(TypedDict):
|
||||||
|
"""A staged move_file operation pending end-of-turn commit."""
|
||||||
|
|
||||||
|
source: str
|
||||||
|
dest: str
|
||||||
|
overwrite: bool
|
||||||
|
|
||||||
|
|
||||||
|
class KbPriorityEntry(TypedDict, total=False):
|
||||||
|
path: str
|
||||||
|
score: float
|
||||||
|
document_id: int | None
|
||||||
|
title: str
|
||||||
|
mentioned: bool
|
||||||
|
|
||||||
|
|
||||||
|
class KbAnonDoc(TypedDict, total=False):
|
||||||
|
"""In-memory anonymous-session document loaded from Redis."""
|
||||||
|
|
||||||
|
path: str
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
chunks: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class SurfSenseFilesystemState(FilesystemState):
|
||||||
|
"""Filesystem state used by the SurfSense agent (cloud + desktop).
|
||||||
|
|
||||||
|
Extends deepagents' :class:`FilesystemState` (which provides ``files``)
|
||||||
|
with cloud-mode staging fields and search-priority hints. All extra fields
|
||||||
|
are reducer-backed so that ``Command(update=...)`` payloads merge cleanly
|
||||||
|
across agent steps and across checkpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cwd: NotRequired[Annotated[str, _replace_reducer]]
|
||||||
|
"""Current working directory.
|
||||||
|
|
||||||
|
Defaults to ``"/documents"`` in cloud mode and ``"/"`` (or first mount) in
|
||||||
|
desktop mode. Initialized once per thread by ``KnowledgeTreeMiddleware``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
||||||
|
"""mkdir paths staged for end-of-turn folder creation (cloud only)."""
|
||||||
|
|
||||||
|
pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
|
||||||
|
"""move_file ops staged for end-of-turn commit (cloud only)."""
|
||||||
|
|
||||||
|
doc_id_by_path: NotRequired[
|
||||||
|
Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
|
||||||
|
]
|
||||||
|
"""virtual_path -> ``Document.id`` for lazily loaded files.
|
||||||
|
|
||||||
|
Populated on first read of a KB document. Used by edit_file/move_file/
|
||||||
|
aafter_agent to map paths back to a real DB row. ``None`` values delete
|
||||||
|
the key (tombstones).
|
||||||
|
"""
|
||||||
|
|
||||||
|
dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
||||||
|
"""Paths whose ``state["files"]`` content has been modified this turn."""
|
||||||
|
|
||||||
|
kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
|
||||||
|
"""Top-K priority hints rendered as a system message before the user turn."""
|
||||||
|
|
||||||
|
kb_matched_chunk_ids: NotRequired[Annotated[dict[int, list[int]], _replace_reducer]]
|
||||||
|
"""Internal: ``Document.id`` -> list of matched chunk IDs from hybrid search."""
|
||||||
|
|
||||||
|
kb_anon_doc: NotRequired[Annotated[KbAnonDoc | None, _replace_reducer]]
|
||||||
|
"""Anonymous-session document loaded from Redis (read-only, no DB row)."""
|
||||||
|
|
||||||
|
tree_version: NotRequired[Annotated[int, _replace_reducer]]
|
||||||
|
"""Monotonically increasing counter; bumped when commits change the KB tree."""
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"KbAnonDoc",
|
||||||
|
"KbPriorityEntry",
|
||||||
|
"PendingMove",
|
||||||
|
"SurfSenseFilesystemState",
|
||||||
|
]
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
"""Middleware components for the SurfSense new chat agent."""
|
"""Middleware components for the SurfSense new chat agent."""
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.anonymous_document import (
|
||||||
|
AnonymousDocumentMiddleware,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
)
|
)
|
||||||
|
|
@ -9,17 +12,30 @@ from app.agents.new_chat.middleware.file_intent import (
|
||||||
from app.agents.new_chat.middleware.filesystem import (
|
from app.agents.new_chat.middleware.filesystem import (
|
||||||
SurfSenseFilesystemMiddleware,
|
SurfSenseFilesystemMiddleware,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.kb_persistence import (
|
||||||
|
KnowledgeBasePersistenceMiddleware,
|
||||||
|
commit_staged_filesystem_state,
|
||||||
|
)
|
||||||
from app.agents.new_chat.middleware.knowledge_search import (
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgeBaseSearchMiddleware,
|
||||||
|
KnowledgePriorityMiddleware,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.middleware.knowledge_tree import (
|
||||||
|
KnowledgeTreeMiddleware,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.middleware.memory_injection import (
|
from app.agents.new_chat.middleware.memory_injection import (
|
||||||
MemoryInjectionMiddleware,
|
MemoryInjectionMiddleware,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AnonymousDocumentMiddleware",
|
||||||
"DedupHITLToolCallsMiddleware",
|
"DedupHITLToolCallsMiddleware",
|
||||||
"FileIntentMiddleware",
|
"FileIntentMiddleware",
|
||||||
|
"KnowledgeBasePersistenceMiddleware",
|
||||||
"KnowledgeBaseSearchMiddleware",
|
"KnowledgeBaseSearchMiddleware",
|
||||||
|
"KnowledgePriorityMiddleware",
|
||||||
|
"KnowledgeTreeMiddleware",
|
||||||
"MemoryInjectionMiddleware",
|
"MemoryInjectionMiddleware",
|
||||||
"SurfSenseFilesystemMiddleware",
|
"SurfSenseFilesystemMiddleware",
|
||||||
|
"commit_staged_filesystem_state",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
"""Lightweight middleware that loads the anonymous-session document into state.
|
||||||
|
|
||||||
|
Anonymous chats receive a single uploaded document via Redis (no DB row,
|
||||||
|
read-only). This middleware loads it once on the first turn into
|
||||||
|
``state['kb_anon_doc']`` so:
|
||||||
|
|
||||||
|
* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents``
|
||||||
|
view without touching the DB.
|
||||||
|
* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a
|
||||||
|
degenerate priority list.
|
||||||
|
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``)
|
||||||
|
recognises the synthetic path.
|
||||||
|
|
||||||
|
The middleware is a no-op when ``anon_session_id`` is not provided or when
|
||||||
|
the document is already cached in state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
|
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, safe_filename
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AnonymousDocumentMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
"""Load the anonymous user's uploaded document from Redis into state."""
|
||||||
|
|
||||||
|
tools = ()
|
||||||
|
state_schema = SurfSenseFilesystemState
|
||||||
|
|
||||||
|
def __init__(self, *, anon_session_id: str | None) -> None:
|
||||||
|
self.anon_session_id = anon_session_id
|
||||||
|
|
||||||
|
async def abefore_agent( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
del runtime
|
||||||
|
if not self.anon_session_id:
|
||||||
|
return None
|
||||||
|
if state.get("kb_anon_doc"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
anon_doc = await self._load_anon_document()
|
||||||
|
if anon_doc is None:
|
||||||
|
return None
|
||||||
|
return {"kb_anon_doc": anon_doc}
|
||||||
|
|
||||||
|
async def _load_anon_document(self) -> dict[str, Any] | None:
|
||||||
|
"""Read ``anon:doc:<session_id>`` from Redis."""
|
||||||
|
try:
|
||||||
|
import redis.asyncio as aioredis # local import to keep cold paths cheap
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
redis_client = aioredis.from_url(
|
||||||
|
config.REDIS_APP_URL, decode_responses=True
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
redis_key = f"anon:doc:{self.anon_session_id}"
|
||||||
|
data = await redis_client.get(redis_key)
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
payload = json.loads(data)
|
||||||
|
finally:
|
||||||
|
await redis_client.aclose()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to load anonymous document from Redis: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
title = str(payload.get("filename") or "uploaded_document")
|
||||||
|
content = str(payload.get("content") or "")
|
||||||
|
path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}"
|
||||||
|
return {
|
||||||
|
"path": path,
|
||||||
|
"title": title,
|
||||||
|
"content": content,
|
||||||
|
"chunks": [{"chunk_id": -1, "content": content}] if content else [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["AnonymousDocumentMiddleware"]
|
||||||
|
|
@ -21,7 +21,7 @@ from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
|
|
@ -217,8 +217,19 @@ def _build_recent_conversation(
|
||||||
messages: list[BaseMessage], *, max_messages: int = 6
|
messages: list[BaseMessage], *, max_messages: int = 6
|
||||||
) -> str:
|
) -> str:
|
||||||
rows: list[str] = []
|
rows: list[str] = []
|
||||||
for msg in messages[-max_messages:]:
|
filtered: list[tuple[str, BaseMessage]] = []
|
||||||
role = "user" if isinstance(msg, HumanMessage) else "assistant"
|
for msg in messages:
|
||||||
|
role: str | None = None
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
role = "user"
|
||||||
|
elif isinstance(msg, AIMessage):
|
||||||
|
if getattr(msg, "tool_calls", None):
|
||||||
|
continue
|
||||||
|
role = "assistant"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
filtered.append((role, msg))
|
||||||
|
for role, msg in filtered[-max_messages:]:
|
||||||
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
|
text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip()
|
||||||
if text:
|
if text:
|
||||||
rows.append(f"{role}: {text[:280]}")
|
rows.append(f"{role}: {text[:280]}")
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,622 @@
|
||||||
|
"""End-of-turn persistence for the cloud-mode SurfSense filesystem.
|
||||||
|
|
||||||
|
This middleware runs ``aafter_agent`` once per turn (cloud only). It commits
|
||||||
|
all staged folder creations, file moves, and content writes/edits to
|
||||||
|
Postgres in a single ordered pass:
|
||||||
|
|
||||||
|
1. Materialize ``staged_dirs`` into ``Folder`` rows.
|
||||||
|
2. Apply ``pending_moves`` in order (chained moves resolved via
|
||||||
|
``doc_id_by_path``).
|
||||||
|
3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move
|
||||||
|
sequences commit at the final path.
|
||||||
|
4. Commit content writes / edits for ``/documents/*`` paths, skipping
|
||||||
|
``temp_*`` basenames.
|
||||||
|
|
||||||
|
The commit body is exposed as a free function ``commit_staged_filesystem_state``
|
||||||
|
so the optional stream-task fallback (``stream_new_chat.py``) can call the
|
||||||
|
exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fractional_indexing import generate_key_between
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langchain_core.callbacks import dispatch_custom_event
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
parse_documents_path,
|
||||||
|
safe_folder_segment,
|
||||||
|
virtual_path_to_doc,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.state_reducers import _CLEAR
|
||||||
|
from app.db import (
|
||||||
|
Chunk,
|
||||||
|
Document,
|
||||||
|
DocumentType,
|
||||||
|
Folder,
|
||||||
|
shielded_async_session,
|
||||||
|
)
|
||||||
|
from app.indexing_pipeline.document_chunker import chunk_text
|
||||||
|
from app.utils.document_converters import (
|
||||||
|
embed_texts,
|
||||||
|
generate_content_hash,
|
||||||
|
generate_unique_identifier_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_TEMP_PREFIX = "temp_"
|
||||||
|
|
||||||
|
|
||||||
|
def _basename(path: str) -> str:
|
||||||
|
return path.rsplit("/", 1)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Folder helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensure_folder_hierarchy(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
created_by_id: str | None,
|
||||||
|
folder_parts: list[str],
|
||||||
|
) -> int | None:
|
||||||
|
"""Ensure a chain of folder names exists under the search space.
|
||||||
|
|
||||||
|
Returns the leaf folder id, or ``None`` if ``folder_parts`` is empty
|
||||||
|
(i.e. a document directly under ``/documents/``).
|
||||||
|
"""
|
||||||
|
if not folder_parts:
|
||||||
|
return None
|
||||||
|
parent_id: int | None = None
|
||||||
|
for raw in folder_parts:
|
||||||
|
name = safe_folder_segment(str(raw))
|
||||||
|
query = select(Folder).where(
|
||||||
|
Folder.search_space_id == search_space_id,
|
||||||
|
Folder.name == name,
|
||||||
|
)
|
||||||
|
if parent_id is None:
|
||||||
|
query = query.where(Folder.parent_id.is_(None))
|
||||||
|
else:
|
||||||
|
query = query.where(Folder.parent_id == parent_id)
|
||||||
|
result = await session.execute(query)
|
||||||
|
folder = result.scalar_one_or_none()
|
||||||
|
if folder is None:
|
||||||
|
sibling_query = (
|
||||||
|
select(Folder.position).order_by(Folder.position.desc()).limit(1)
|
||||||
|
)
|
||||||
|
sibling_query = sibling_query.where(
|
||||||
|
Folder.search_space_id == search_space_id
|
||||||
|
)
|
||||||
|
if parent_id is None:
|
||||||
|
sibling_query = sibling_query.where(Folder.parent_id.is_(None))
|
||||||
|
else:
|
||||||
|
sibling_query = sibling_query.where(Folder.parent_id == parent_id)
|
||||||
|
sibling_result = await session.execute(sibling_query)
|
||||||
|
last_position = sibling_result.scalar_one_or_none()
|
||||||
|
folder = Folder(
|
||||||
|
name=name,
|
||||||
|
position=generate_key_between(last_position, None),
|
||||||
|
parent_id=parent_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=created_by_id,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(folder)
|
||||||
|
await session.flush()
|
||||||
|
parent_id = folder.id
|
||||||
|
return parent_id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Document helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_document(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
virtual_path: str,
|
||||||
|
content: str,
|
||||||
|
search_space_id: int,
|
||||||
|
created_by_id: str | None,
|
||||||
|
) -> Document:
|
||||||
|
"""Create a NOTE Document + Chunks for ``virtual_path``."""
|
||||||
|
folder_parts, title = parse_documents_path(virtual_path)
|
||||||
|
if not title:
|
||||||
|
raise ValueError(f"invalid /documents path '{virtual_path}'")
|
||||||
|
folder_id = await _ensure_folder_hierarchy(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=created_by_id,
|
||||||
|
folder_parts=folder_parts,
|
||||||
|
)
|
||||||
|
unique_identifier_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
virtual_path,
|
||||||
|
search_space_id,
|
||||||
|
)
|
||||||
|
# Guard against the unique_identifier_hash constraint: another row at the
|
||||||
|
# same virtual_path (this search space) already owns the hash. Callers are
|
||||||
|
# expected to upsert via the wrapper, but this defends against bypasses
|
||||||
|
# and gives a clean ValueError instead of a session-poisoning IntegrityError.
|
||||||
|
path_collision = await session.execute(
|
||||||
|
select(Document.id).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
Document.unique_identifier_hash == unique_identifier_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if path_collision.scalar_one_or_none() is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"a document already exists at path '{virtual_path}' "
|
||||||
|
"(unique_identifier_hash collision)"
|
||||||
|
)
|
||||||
|
content_hash = generate_content_hash(content, search_space_id)
|
||||||
|
content_collision = await session.execute(
|
||||||
|
select(Document.id).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
Document.content_hash == content_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if content_collision.scalar_one_or_none() is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"a document with identical content already exists for path '{virtual_path}'"
|
||||||
|
)
|
||||||
|
doc = Document(
|
||||||
|
title=title,
|
||||||
|
document_type=DocumentType.NOTE,
|
||||||
|
document_metadata={"virtual_path": virtual_path},
|
||||||
|
content=content,
|
||||||
|
content_hash=content_hash,
|
||||||
|
unique_identifier_hash=unique_identifier_hash,
|
||||||
|
source_markdown=content,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
folder_id=folder_id,
|
||||||
|
created_by_id=created_by_id,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(doc)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
summary_embedding = embed_texts([content])[0]
|
||||||
|
doc.embedding = summary_embedding
|
||||||
|
chunks = chunk_text(content)
|
||||||
|
if chunks:
|
||||||
|
chunk_embeddings = embed_texts(chunks)
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||||
|
for text, embedding in zip(chunks, chunk_embeddings, strict=True)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_document(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
doc_id: int,
|
||||||
|
content: str,
|
||||||
|
virtual_path: str,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> Document | None:
|
||||||
|
"""Update an existing Document's content + chunks."""
|
||||||
|
result = await session.execute(
|
||||||
|
select(Document).where(
|
||||||
|
Document.id == doc_id,
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
document = result.scalar_one_or_none()
|
||||||
|
if document is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
document.content = content
|
||||||
|
document.source_markdown = content
|
||||||
|
document.content_hash = generate_content_hash(content, search_space_id)
|
||||||
|
document.updated_at = datetime.now(UTC)
|
||||||
|
metadata = dict(document.document_metadata or {})
|
||||||
|
metadata["virtual_path"] = virtual_path
|
||||||
|
document.document_metadata = metadata
|
||||||
|
document.unique_identifier_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
virtual_path,
|
||||||
|
search_space_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
summary_embedding = embed_texts([content])[0]
|
||||||
|
document.embedding = summary_embedding
|
||||||
|
|
||||||
|
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
|
||||||
|
chunks = chunk_text(content)
|
||||||
|
if chunks:
|
||||||
|
chunk_embeddings = embed_texts(chunks)
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
Chunk(document_id=document.id, content=text, embedding=embedding)
|
||||||
|
for text, embedding in zip(chunks, chunk_embeddings, strict=True)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return document
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Move helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _apply_move(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
created_by_id: str | None,
|
||||||
|
move: dict[str, Any],
|
||||||
|
doc_id_by_path: dict[str, int],
|
||||||
|
doc_id_path_tombstones: dict[str, int | None],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Apply a single staged move; updates the in-memory mapping for chain resolution."""
|
||||||
|
source = str(move.get("source") or "")
|
||||||
|
dest = str(move.get("dest") or "")
|
||||||
|
if not source or not dest or source == dest:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not source.startswith(DOCUMENTS_ROOT + "/") or not dest.startswith(
|
||||||
|
DOCUMENTS_ROOT + "/"
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
doc_id: int | None = doc_id_by_path.get(source)
|
||||||
|
document: Document | None = None
|
||||||
|
if doc_id is not None:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Document).where(
|
||||||
|
Document.id == doc_id,
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
document = result.scalar_one_or_none()
|
||||||
|
if document is None:
|
||||||
|
document = await virtual_path_to_doc(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
virtual_path=source,
|
||||||
|
)
|
||||||
|
if document is None:
|
||||||
|
logger.info(
|
||||||
|
"kb_persistence: skipping move %s -> %s (source not found)",
|
||||||
|
source,
|
||||||
|
dest,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
folder_parts, new_title = parse_documents_path(dest)
|
||||||
|
if not new_title:
|
||||||
|
return None
|
||||||
|
folder_id = await _ensure_folder_hierarchy(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=created_by_id,
|
||||||
|
folder_parts=folder_parts,
|
||||||
|
)
|
||||||
|
|
||||||
|
document.title = new_title
|
||||||
|
document.folder_id = folder_id
|
||||||
|
metadata = dict(document.document_metadata or {})
|
||||||
|
metadata["virtual_path"] = dest
|
||||||
|
document.document_metadata = metadata
|
||||||
|
document.unique_identifier_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
dest,
|
||||||
|
search_space_id,
|
||||||
|
)
|
||||||
|
document.updated_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
doc_id_by_path.pop(source, None)
|
||||||
|
doc_id_by_path[dest] = document.id
|
||||||
|
doc_id_path_tombstones[source] = None
|
||||||
|
doc_id_path_tombstones[dest] = document.id
|
||||||
|
return {"id": document.id, "source": source, "dest": dest, "title": new_title}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Commit body
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def commit_staged_filesystem_state(
|
||||||
|
state: dict[str, Any] | AgentState,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
created_by_id: str | None,
|
||||||
|
filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
|
dispatch_events: bool = True,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Commit all staged filesystem changes; return the state delta for reducers.
|
||||||
|
|
||||||
|
Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent`
|
||||||
|
and the optional stream-task fallback.
|
||||||
|
"""
|
||||||
|
if filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
return None
|
||||||
|
|
||||||
|
state_dict: dict[str, Any] = (
|
||||||
|
dict(state)
|
||||||
|
if isinstance(state, dict)
|
||||||
|
else dict(getattr(state, "values", {}) or {})
|
||||||
|
)
|
||||||
|
|
||||||
|
files: dict[str, Any] = state_dict.get("files") or {}
|
||||||
|
staged_dirs: list[str] = list(state_dict.get("staged_dirs") or [])
|
||||||
|
pending_moves: list[dict[str, Any]] = list(state_dict.get("pending_moves") or [])
|
||||||
|
dirty_paths: list[str] = list(state_dict.get("dirty_paths") or [])
|
||||||
|
doc_id_by_path: dict[str, int] = dict(state_dict.get("doc_id_by_path") or {})
|
||||||
|
kb_anon_doc = state_dict.get("kb_anon_doc")
|
||||||
|
|
||||||
|
if kb_anon_doc:
|
||||||
|
temp_paths = [
|
||||||
|
p
|
||||||
|
for p in files
|
||||||
|
if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
|
||||||
|
]
|
||||||
|
return {
|
||||||
|
"dirty_paths": [_CLEAR],
|
||||||
|
"staged_dirs": [_CLEAR],
|
||||||
|
"pending_moves": [_CLEAR],
|
||||||
|
"files": dict.fromkeys(temp_paths),
|
||||||
|
}
|
||||||
|
|
||||||
|
if not (staged_dirs or pending_moves or dirty_paths):
|
||||||
|
return None
|
||||||
|
|
||||||
|
committed_creates: list[dict[str, Any]] = []
|
||||||
|
committed_updates: list[dict[str, Any]] = []
|
||||||
|
discarded: list[str] = []
|
||||||
|
applied_moves: list[dict[str, Any]] = []
|
||||||
|
doc_id_path_tombstones: dict[str, int | None] = {}
|
||||||
|
tree_changed = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
for folder_path in staged_dirs:
|
||||||
|
if not isinstance(folder_path, str):
|
||||||
|
continue
|
||||||
|
if not folder_path.startswith(DOCUMENTS_ROOT):
|
||||||
|
continue
|
||||||
|
rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/")
|
||||||
|
folder_parts_full = [p for p in rel.split("/") if p]
|
||||||
|
if not folder_parts_full:
|
||||||
|
continue
|
||||||
|
await _ensure_folder_hierarchy(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=created_by_id,
|
||||||
|
folder_parts=folder_parts_full,
|
||||||
|
)
|
||||||
|
tree_changed = True
|
||||||
|
|
||||||
|
for move in pending_moves:
|
||||||
|
applied = await _apply_move(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=created_by_id,
|
||||||
|
move=move,
|
||||||
|
doc_id_by_path=doc_id_by_path,
|
||||||
|
doc_id_path_tombstones=doc_id_path_tombstones,
|
||||||
|
)
|
||||||
|
if applied:
|
||||||
|
applied_moves.append(applied)
|
||||||
|
tree_changed = True
|
||||||
|
|
||||||
|
move_alias = {
|
||||||
|
m["source"]: m["dest"] for m in pending_moves if m.get("source")
|
||||||
|
}
|
||||||
|
|
||||||
|
def _final_path(path: str) -> str:
|
||||||
|
seen: set[str] = set()
|
||||||
|
while path in move_alias and path not in seen:
|
||||||
|
seen.add(path)
|
||||||
|
path = move_alias[path]
|
||||||
|
return path
|
||||||
|
|
||||||
|
kb_dirty_seen: set[str] = set()
|
||||||
|
kb_dirty: list[str] = []
|
||||||
|
for raw in dirty_paths:
|
||||||
|
if not isinstance(raw, str):
|
||||||
|
continue
|
||||||
|
final = _final_path(raw)
|
||||||
|
if not final.startswith(DOCUMENTS_ROOT + "/"):
|
||||||
|
continue
|
||||||
|
if final in kb_dirty_seen:
|
||||||
|
continue
|
||||||
|
kb_dirty_seen.add(final)
|
||||||
|
kb_dirty.append(final)
|
||||||
|
|
||||||
|
for path in kb_dirty:
|
||||||
|
basename = _basename(path)
|
||||||
|
if basename.startswith(_TEMP_PREFIX):
|
||||||
|
discarded.append(path)
|
||||||
|
continue
|
||||||
|
file_data = files.get(path)
|
||||||
|
if not isinstance(file_data, dict):
|
||||||
|
continue
|
||||||
|
content = "\n".join(file_data.get("content") or [])
|
||||||
|
doc_id = doc_id_by_path.get(path)
|
||||||
|
if doc_id is None:
|
||||||
|
# The in-memory ``doc_id_by_path`` is per-thread and starts
|
||||||
|
# empty in every new chat. If the agent writes to a path
|
||||||
|
# that already exists in the DB (e.g. a previous chat's
|
||||||
|
# ``notes.md``), we must NOT try to INSERT — it would hit
|
||||||
|
# ``unique_identifier_hash`` (path-derived). Look up the
|
||||||
|
# existing doc and update it in place instead.
|
||||||
|
existing = await virtual_path_to_doc(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
virtual_path=path,
|
||||||
|
)
|
||||||
|
if existing is not None:
|
||||||
|
doc_id = existing.id
|
||||||
|
doc_id_by_path[path] = existing.id
|
||||||
|
if doc_id is not None:
|
||||||
|
updated = await _update_document(
|
||||||
|
session,
|
||||||
|
doc_id=doc_id,
|
||||||
|
content=content,
|
||||||
|
virtual_path=path,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
)
|
||||||
|
if updated is not None:
|
||||||
|
committed_updates.append(
|
||||||
|
{
|
||||||
|
"id": updated.id,
|
||||||
|
"title": updated.title,
|
||||||
|
"documentType": DocumentType.NOTE.value,
|
||||||
|
"searchSpaceId": search_space_id,
|
||||||
|
"folderId": updated.folder_id,
|
||||||
|
"createdById": str(created_by_id)
|
||||||
|
if created_by_id
|
||||||
|
else None,
|
||||||
|
"virtualPath": path,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
new_doc = await _create_document(
|
||||||
|
session,
|
||||||
|
virtual_path=path,
|
||||||
|
content=content,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=created_by_id,
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"kb_persistence: skipping %s create: %s", path, exc
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
doc_id_by_path[path] = new_doc.id
|
||||||
|
committed_creates.append(
|
||||||
|
{
|
||||||
|
"id": new_doc.id,
|
||||||
|
"title": new_doc.title,
|
||||||
|
"documentType": DocumentType.NOTE.value,
|
||||||
|
"searchSpaceId": search_space_id,
|
||||||
|
"folderId": new_doc.folder_id,
|
||||||
|
"createdById": str(created_by_id)
|
||||||
|
if created_by_id
|
||||||
|
else None,
|
||||||
|
"virtualPath": path,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tree_changed = True
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
except Exception: # pragma: no cover - rollback safety net
|
||||||
|
logger.exception(
|
||||||
|
"kb_persistence: commit failed (search_space=%s)", search_space_id
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if dispatch_events:
|
||||||
|
for payload in committed_creates:
|
||||||
|
try:
|
||||||
|
dispatch_custom_event("document_created", payload)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"kb_persistence: failed to dispatch document_created event"
|
||||||
|
)
|
||||||
|
for payload in committed_updates:
|
||||||
|
try:
|
||||||
|
dispatch_custom_event("document_updated", payload)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"kb_persistence: failed to dispatch document_updated event"
|
||||||
|
)
|
||||||
|
|
||||||
|
temp_paths = [
|
||||||
|
p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX)
|
||||||
|
]
|
||||||
|
|
||||||
|
doc_id_update: dict[str, int | None] = {**doc_id_path_tombstones}
|
||||||
|
for payload in committed_creates:
|
||||||
|
doc_id_update[str(payload.get("virtualPath") or "")] = int(payload["id"])
|
||||||
|
|
||||||
|
delta: dict[str, Any] = {
|
||||||
|
"dirty_paths": [_CLEAR],
|
||||||
|
"staged_dirs": [_CLEAR],
|
||||||
|
"pending_moves": [_CLEAR],
|
||||||
|
}
|
||||||
|
if temp_paths:
|
||||||
|
delta["files"] = dict.fromkeys(temp_paths)
|
||||||
|
if doc_id_update:
|
||||||
|
delta["doc_id_by_path"] = doc_id_update
|
||||||
|
if tree_changed:
|
||||||
|
delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"kb_persistence: commit (search_space=%s) creates=%d updates=%d "
|
||||||
|
"moves=%d staged_dirs=%d discarded=%d",
|
||||||
|
search_space_id,
|
||||||
|
len(committed_creates),
|
||||||
|
len(committed_updates),
|
||||||
|
len(applied_moves),
|
||||||
|
len(staged_dirs),
|
||||||
|
len(discarded),
|
||||||
|
)
|
||||||
|
return delta
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Middleware
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
"""End-of-turn cloud persistence for the SurfSense filesystem agent."""
|
||||||
|
|
||||||
|
tools = ()
|
||||||
|
state_schema = SurfSenseFilesystemState
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
created_by_id: str | None,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
) -> None:
|
||||||
|
self.search_space_id = search_space_id
|
||||||
|
self.created_by_id = created_by_id
|
||||||
|
self.filesystem_mode = filesystem_mode
|
||||||
|
|
||||||
|
async def aafter_agent( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
del runtime
|
||||||
|
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
return None
|
||||||
|
return await commit_staged_filesystem_state(
|
||||||
|
state,
|
||||||
|
search_space_id=self.search_space_id,
|
||||||
|
created_by_id=self.created_by_id,
|
||||||
|
filesystem_mode=self.filesystem_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"KnowledgeBasePersistenceMiddleware",
|
||||||
|
"commit_staged_filesystem_state",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,963 @@
|
||||||
|
"""Postgres-backed virtual filesystem for the SurfSense agent (cloud mode).
|
||||||
|
|
||||||
|
The backend is **strictly conforming** to deepagents'
|
||||||
|
:class:`BackendProtocol`. It returns ``WriteResult`` / ``EditResult`` / list
|
||||||
|
shapes exactly as upstream expects (no extra fields). All side-state
|
||||||
|
plumbing — ``dirty_paths``, ``doc_id_by_path``, ``staged_dirs``,
|
||||||
|
``pending_moves``, ``files`` cache — is appended by the overridden tool
|
||||||
|
wrappers in :class:`SurfSenseFilesystemMiddleware` via ``Command.update``.
|
||||||
|
|
||||||
|
The backend never writes to Postgres. End-of-turn persistence is handled by
|
||||||
|
:class:`KnowledgeBasePersistenceMiddleware`. This module is purely a
|
||||||
|
read-side and a state-merging helper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import fnmatch
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import UTC
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deepagents.backends.protocol import (
|
||||||
|
BackendProtocol,
|
||||||
|
EditResult,
|
||||||
|
FileDownloadResponse,
|
||||||
|
FileInfo,
|
||||||
|
FileUploadResponse,
|
||||||
|
GrepMatch,
|
||||||
|
WriteResult,
|
||||||
|
)
|
||||||
|
from deepagents.backends.utils import (
|
||||||
|
create_file_data,
|
||||||
|
file_data_to_string,
|
||||||
|
format_read_response,
|
||||||
|
perform_string_replacement,
|
||||||
|
update_file_data,
|
||||||
|
)
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.document_xml import build_document_xml
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
build_path_index,
|
||||||
|
doc_to_virtual_path,
|
||||||
|
virtual_path_to_doc,
|
||||||
|
)
|
||||||
|
from app.db import Chunk, Document, shielded_async_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TEMP_PREFIX = "temp_"
|
||||||
|
_GREP_MAX_TOTAL_MATCHES = 50
|
||||||
|
_GREP_MAX_PER_DOC = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _basename(path: str) -> str:
|
||||||
|
return path.rsplit("/", 1)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_under(child: str, parent: str) -> bool:
|
||||||
|
"""Return True iff ``child`` is at-or-under ``parent`` (directory semantics)."""
|
||||||
|
if parent == "/":
|
||||||
|
return child.startswith("/")
|
||||||
|
return child == parent or child.startswith(parent.rstrip("/") + "/")
|
||||||
|
|
||||||
|
|
||||||
|
def paginate_listing(
|
||||||
|
infos: list[FileInfo],
|
||||||
|
*,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> list[FileInfo]:
|
||||||
|
"""Paginate a listing produced by :meth:`KBPostgresBackend.als_info`."""
|
||||||
|
if offset < 0:
|
||||||
|
offset = 0
|
||||||
|
end: int | None
|
||||||
|
end = None if limit is None or limit < 0 else offset + limit
|
||||||
|
return list(infos[offset:end])
|
||||||
|
|
||||||
|
|
||||||
|
class KBPostgresBackend(BackendProtocol):
|
||||||
|
"""Lazy, read-only Postgres view for ``/documents/*`` virtual paths.
|
||||||
|
|
||||||
|
The backend exposes a virtual ``/documents/`` namespace mirroring the
|
||||||
|
``Folder``/``Document`` graph. Reads materialize XML on first access and
|
||||||
|
cache it via the overriding tool wrappers (NOT here). Writes never touch
|
||||||
|
the DB — they return ``files_update`` deltas that the wrappers turn into
|
||||||
|
Command updates, and the persistence middleware commits them at end of
|
||||||
|
turn.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp"})
|
||||||
|
|
||||||
|
def __init__(self, search_space_id: int, runtime: ToolRuntime) -> None:
|
||||||
|
self.search_space_id = search_space_id
|
||||||
|
self.runtime = runtime
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> dict[str, Any]:
|
||||||
|
return getattr(self.runtime, "state", {}) or {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ helpers
|
||||||
|
|
||||||
|
def _state_files(self) -> dict[str, Any]:
|
||||||
|
return dict(self.state.get("files") or {})
|
||||||
|
|
||||||
|
def _staged_dirs(self) -> list[str]:
|
||||||
|
return list(self.state.get("staged_dirs") or [])
|
||||||
|
|
||||||
|
def _pending_moves(self) -> list[dict[str, Any]]:
|
||||||
|
return list(self.state.get("pending_moves") or [])
|
||||||
|
|
||||||
|
def _kb_anon_doc(self) -> dict[str, Any] | None:
|
||||||
|
anon = self.state.get("kb_anon_doc")
|
||||||
|
return anon if isinstance(anon, dict) else None
|
||||||
|
|
||||||
|
def _matched_chunk_ids(self, doc_id: int) -> set[int]:
|
||||||
|
mapping = self.state.get("kb_matched_chunk_ids") or {}
|
||||||
|
try:
|
||||||
|
return set(mapping.get(doc_id, []) or [])
|
||||||
|
except TypeError:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _file_data_size(file_data: dict[str, Any]) -> int:
|
||||||
|
try:
|
||||||
|
return len("\n".join(file_data.get("content") or []))
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _normalize_listing_path(self, path: str) -> str:
|
||||||
|
if not path:
|
||||||
|
return DOCUMENTS_ROOT
|
||||||
|
if path == "/":
|
||||||
|
return path
|
||||||
|
return path.rstrip("/") if path != "/" else path
|
||||||
|
|
||||||
|
def _moved_view_paths(
|
||||||
|
self,
|
||||||
|
existing: dict[str, dict[str, Any]],
|
||||||
|
) -> tuple[set[str], dict[str, str]]:
|
||||||
|
"""Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
|
||||||
|
|
||||||
|
Removed paths should disappear from listings; ``alias[source] = dest``
|
||||||
|
means a virtual entry should appear at ``dest`` even if no DB row is
|
||||||
|
yet there.
|
||||||
|
"""
|
||||||
|
removed: set[str] = set()
|
||||||
|
alias: dict[str, str] = {}
|
||||||
|
for move in self._pending_moves():
|
||||||
|
src = move.get("source")
|
||||||
|
dst = move.get("dest")
|
||||||
|
if not src or not dst:
|
||||||
|
continue
|
||||||
|
removed.add(src)
|
||||||
|
alias[src] = dst
|
||||||
|
existing.pop(src, None)
|
||||||
|
return removed, alias
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ ls/read
|
||||||
|
|
||||||
|
async def als_info(self, path: str) -> list[FileInfo]: # type: ignore[override]
|
||||||
|
normalized = self._normalize_listing_path(path)
|
||||||
|
infos: list[FileInfo] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
anon = self._kb_anon_doc()
|
||||||
|
if anon:
|
||||||
|
anon_path = str(anon.get("path") or "")
|
||||||
|
if (
|
||||||
|
anon_path
|
||||||
|
and _is_under(anon_path, normalized)
|
||||||
|
and anon_path != normalized
|
||||||
|
and anon_path not in seen
|
||||||
|
):
|
||||||
|
infos.append(
|
||||||
|
FileInfo(
|
||||||
|
path=anon_path,
|
||||||
|
is_dir=False,
|
||||||
|
size=len(str(anon.get("content") or "")),
|
||||||
|
modified_at="",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen.add(anon_path)
|
||||||
|
|
||||||
|
files = self._state_files()
|
||||||
|
moved_removed, moved_alias = self._moved_view_paths(files)
|
||||||
|
|
||||||
|
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
db_infos, subdir_paths = await self._list_db_directory(
|
||||||
|
session, normalized
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logger.warning("KBPostgresBackend.als_info DB error: %s", exc)
|
||||||
|
db_infos, subdir_paths = [], set()
|
||||||
|
|
||||||
|
for info in db_infos:
|
||||||
|
p = info.get("path", "")
|
||||||
|
if not p or p in seen or p in moved_removed:
|
||||||
|
continue
|
||||||
|
infos.append(info)
|
||||||
|
seen.add(p)
|
||||||
|
|
||||||
|
for src, dst in moved_alias.items():
|
||||||
|
if src not in seen:
|
||||||
|
if not _is_under(dst, normalized):
|
||||||
|
continue
|
||||||
|
rel = (
|
||||||
|
dst[len(normalized) :].lstrip("/")
|
||||||
|
if normalized != "/"
|
||||||
|
else dst.lstrip("/")
|
||||||
|
)
|
||||||
|
if "/" in rel:
|
||||||
|
subdir_paths.add(
|
||||||
|
(normalized.rstrip("/") + "/" + rel.split("/", 1)[0])
|
||||||
|
if normalized != "/"
|
||||||
|
else "/" + rel.split("/", 1)[0]
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if dst in seen:
|
||||||
|
continue
|
||||||
|
fd = files.get(dst)
|
||||||
|
size = self._file_data_size(fd) if isinstance(fd, dict) else 0
|
||||||
|
infos.append(
|
||||||
|
FileInfo(
|
||||||
|
path=dst,
|
||||||
|
is_dir=False,
|
||||||
|
size=int(size),
|
||||||
|
modified_at=fd.get("modified_at", "")
|
||||||
|
if isinstance(fd, dict)
|
||||||
|
else "",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen.add(dst)
|
||||||
|
|
||||||
|
for staged in self._staged_dirs():
|
||||||
|
if not staged or not staged.startswith(DOCUMENTS_ROOT):
|
||||||
|
continue
|
||||||
|
if staged == normalized:
|
||||||
|
continue
|
||||||
|
if not _is_under(staged, normalized):
|
||||||
|
continue
|
||||||
|
rel = (
|
||||||
|
staged[len(normalized) :].lstrip("/")
|
||||||
|
if normalized != "/"
|
||||||
|
else staged.lstrip("/")
|
||||||
|
)
|
||||||
|
if not rel:
|
||||||
|
continue
|
||||||
|
first = rel.split("/", 1)[0]
|
||||||
|
immediate = (
|
||||||
|
normalized.rstrip("/") + "/" + first
|
||||||
|
if normalized != "/"
|
||||||
|
else "/" + first
|
||||||
|
)
|
||||||
|
subdir_paths.add(immediate)
|
||||||
|
|
||||||
|
for sub in sorted(subdir_paths):
|
||||||
|
if sub in seen:
|
||||||
|
continue
|
||||||
|
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
|
||||||
|
seen.add(sub)
|
||||||
|
|
||||||
|
for path_key, fd in files.items():
|
||||||
|
if not isinstance(path_key, str) or path_key in seen:
|
||||||
|
continue
|
||||||
|
if not _is_under(path_key, normalized) or path_key == normalized:
|
||||||
|
continue
|
||||||
|
if normalized == "/":
|
||||||
|
rel = path_key.lstrip("/")
|
||||||
|
else:
|
||||||
|
rel = path_key[len(normalized) :].lstrip("/")
|
||||||
|
if not rel:
|
||||||
|
continue
|
||||||
|
if "/" in rel:
|
||||||
|
first = rel.split("/", 1)[0]
|
||||||
|
immediate = (
|
||||||
|
normalized.rstrip("/") + "/" + first
|
||||||
|
if normalized != "/"
|
||||||
|
else "/" + first
|
||||||
|
)
|
||||||
|
if immediate not in seen:
|
||||||
|
infos.append(
|
||||||
|
FileInfo(path=immediate, is_dir=True, size=0, modified_at="")
|
||||||
|
)
|
||||||
|
seen.add(immediate)
|
||||||
|
continue
|
||||||
|
include = path_key.startswith(DOCUMENTS_ROOT) or _basename(
|
||||||
|
path_key
|
||||||
|
).startswith(_TEMP_PREFIX)
|
||||||
|
if not include:
|
||||||
|
continue
|
||||||
|
size = self._file_data_size(fd) if isinstance(fd, dict) else 0
|
||||||
|
infos.append(
|
||||||
|
FileInfo(
|
||||||
|
path=path_key,
|
||||||
|
is_dir=False,
|
||||||
|
size=int(size),
|
||||||
|
modified_at=fd.get("modified_at", "")
|
||||||
|
if isinstance(fd, dict)
|
||||||
|
else "",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen.add(path_key)
|
||||||
|
|
||||||
|
infos.sort(key=lambda fi: (not fi.get("is_dir", False), fi.get("path", "")))
|
||||||
|
return infos
|
||||||
|
|
||||||
|
def ls_info(self, path: str) -> list[FileInfo]: # type: ignore[override]
|
||||||
|
return asyncio.run(self.als_info(path))
|
||||||
|
|
||||||
|
async def _list_db_directory(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
normalized_path: str,
|
||||||
|
) -> tuple[list[FileInfo], set[str]]:
|
||||||
|
"""List immediate Folders + Documents at ``normalized_path``.
|
||||||
|
|
||||||
|
Returns ``(file_infos, subdirectory_paths)``. ``normalized_path`` may
|
||||||
|
be ``/`` (synthesizes ``/documents``) or a path under ``/documents``.
|
||||||
|
"""
|
||||||
|
if normalized_path == "/":
|
||||||
|
return (
|
||||||
|
[],
|
||||||
|
{DOCUMENTS_ROOT},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not normalized_path.startswith(DOCUMENTS_ROOT):
|
||||||
|
return [], set()
|
||||||
|
|
||||||
|
index = await build_path_index(session, self.search_space_id)
|
||||||
|
target_folder_id: int | None = None
|
||||||
|
if normalized_path != DOCUMENTS_ROOT:
|
||||||
|
target_path = normalized_path
|
||||||
|
matches = [
|
||||||
|
fid for fid, fpath in index.folder_paths.items() if fpath == target_path
|
||||||
|
]
|
||||||
|
if not matches:
|
||||||
|
return [], set()
|
||||||
|
target_folder_id = matches[0]
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(Document.id, Document.title, Document.folder_id, Document.updated_at)
|
||||||
|
.where(Document.search_space_id == self.search_space_id)
|
||||||
|
.where(
|
||||||
|
Document.folder_id == target_folder_id
|
||||||
|
if target_folder_id is not None
|
||||||
|
else Document.folder_id.is_(None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows = result.all()
|
||||||
|
|
||||||
|
file_infos: list[FileInfo] = []
|
||||||
|
for row in rows:
|
||||||
|
path = doc_to_virtual_path(
|
||||||
|
doc_id=row.id,
|
||||||
|
title=str(row.title or "untitled"),
|
||||||
|
folder_id=row.folder_id,
|
||||||
|
index=index,
|
||||||
|
)
|
||||||
|
modified = ""
|
||||||
|
if row.updated_at is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
modified = row.updated_at.astimezone(UTC).isoformat()
|
||||||
|
file_infos.append(
|
||||||
|
FileInfo(
|
||||||
|
path=path,
|
||||||
|
is_dir=False,
|
||||||
|
size=0,
|
||||||
|
modified_at=modified,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
subdirs: set[str] = set()
|
||||||
|
for _fid, fpath in index.folder_paths.items():
|
||||||
|
if fpath == normalized_path:
|
||||||
|
continue
|
||||||
|
base = normalized_path.rstrip("/")
|
||||||
|
if not fpath.startswith(base + "/"):
|
||||||
|
continue
|
||||||
|
rel = fpath[len(base) + 1 :]
|
||||||
|
if "/" in rel:
|
||||||
|
continue
|
||||||
|
subdirs.add(base + "/" + rel)
|
||||||
|
return file_infos, subdirs
|
||||||
|
|
||||||
|
async def aread( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 2000,
|
||||||
|
) -> str:
|
||||||
|
files = self._state_files()
|
||||||
|
file_data = files.get(file_path)
|
||||||
|
if file_data is not None:
|
||||||
|
return format_read_response(file_data, offset, limit)
|
||||||
|
|
||||||
|
loaded = await self._load_file_data(file_path)
|
||||||
|
if loaded is None:
|
||||||
|
return f"Error: File '{file_path}' not found"
|
||||||
|
file_data, _ = loaded
|
||||||
|
return format_read_response(file_data, offset, limit)
|
||||||
|
|
||||||
|
def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str: # type: ignore[override]
|
||||||
|
return asyncio.run(self.aread(file_path, offset, limit))
|
||||||
|
|
||||||
|
async def _load_file_data(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
) -> tuple[dict[str, Any], int | None] | None:
|
||||||
|
"""Lazy-load a virtual KB document into a deepagents ``FileData``.
|
||||||
|
|
||||||
|
Returns ``(file_data, doc_id)`` or ``None`` if the path doesn't map
|
||||||
|
to any known document. ``doc_id`` is ``None`` for the synthetic
|
||||||
|
anonymous document so the caller doesn't track it as a DB-backed file.
|
||||||
|
"""
|
||||||
|
anon = self._kb_anon_doc()
|
||||||
|
if anon and str(anon.get("path") or "") == path:
|
||||||
|
doc_payload = {
|
||||||
|
"document_id": -1,
|
||||||
|
"chunks": list(anon.get("chunks") or []),
|
||||||
|
"matched_chunk_ids": [],
|
||||||
|
"document": {
|
||||||
|
"id": -1,
|
||||||
|
"title": anon.get("title") or "uploaded_document",
|
||||||
|
"document_type": "FILE",
|
||||||
|
"metadata": {"source": "anonymous_upload"},
|
||||||
|
},
|
||||||
|
"source": "FILE",
|
||||||
|
}
|
||||||
|
xml = build_document_xml(doc_payload, matched_chunk_ids=set())
|
||||||
|
file_data = create_file_data(xml)
|
||||||
|
return file_data, None
|
||||||
|
|
||||||
|
if not path.startswith(DOCUMENTS_ROOT):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
document = await virtual_path_to_doc(
|
||||||
|
session,
|
||||||
|
search_space_id=self.search_space_id,
|
||||||
|
virtual_path=path,
|
||||||
|
)
|
||||||
|
if document is None:
|
||||||
|
return None
|
||||||
|
chunk_rows = await session.execute(
|
||||||
|
select(Chunk.id, Chunk.content)
|
||||||
|
.where(Chunk.document_id == document.id)
|
||||||
|
.order_by(Chunk.id)
|
||||||
|
)
|
||||||
|
chunks = [
|
||||||
|
{"chunk_id": row.id, "content": row.content} for row in chunk_rows.all()
|
||||||
|
]
|
||||||
|
|
||||||
|
doc_payload = {
|
||||||
|
"document_id": document.id,
|
||||||
|
"chunks": chunks,
|
||||||
|
"matched_chunk_ids": list(self._matched_chunk_ids(document.id)),
|
||||||
|
"document": {
|
||||||
|
"id": document.id,
|
||||||
|
"title": document.title,
|
||||||
|
"document_type": (
|
||||||
|
document.document_type.value
|
||||||
|
if getattr(document, "document_type", None) is not None
|
||||||
|
else "UNKNOWN"
|
||||||
|
),
|
||||||
|
"metadata": dict(document.document_metadata or {}),
|
||||||
|
},
|
||||||
|
"source": (
|
||||||
|
document.document_type.value
|
||||||
|
if getattr(document, "document_type", None) is not None
|
||||||
|
else "UNKNOWN"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
xml = build_document_xml(
|
||||||
|
doc_payload,
|
||||||
|
matched_chunk_ids=self._matched_chunk_ids(document.id),
|
||||||
|
)
|
||||||
|
file_data = create_file_data(xml)
|
||||||
|
return file_data, document.id
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ writes
|
||||||
|
|
||||||
|
async def awrite(self, file_path: str, content: str) -> WriteResult: # type: ignore[override]
|
||||||
|
files = self._state_files()
|
||||||
|
if file_path in files:
|
||||||
|
return WriteResult(
|
||||||
|
error=(
|
||||||
|
f"Cannot write to {file_path} because it already exists. "
|
||||||
|
"Read and then make an edit, or write to a new path."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
new_file_data = create_file_data(content)
|
||||||
|
return WriteResult(path=file_path, files_update={file_path: new_file_data})
|
||||||
|
|
||||||
|
def write(self, file_path: str, content: str) -> WriteResult: # type: ignore[override]
|
||||||
|
return asyncio.run(self.awrite(file_path, content))
|
||||||
|
|
||||||
|
async def aedit( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
old_string: str,
|
||||||
|
new_string: str,
|
||||||
|
replace_all: bool = False,
|
||||||
|
) -> EditResult:
|
||||||
|
files = self._state_files()
|
||||||
|
file_data = files.get(file_path)
|
||||||
|
if file_data is None:
|
||||||
|
loaded = await self._load_file_data(file_path)
|
||||||
|
if loaded is None:
|
||||||
|
return EditResult(error=f"Error: File '{file_path}' not found")
|
||||||
|
file_data, _ = loaded
|
||||||
|
|
||||||
|
content = file_data_to_string(file_data)
|
||||||
|
result = perform_string_replacement(
|
||||||
|
content, old_string, new_string, replace_all
|
||||||
|
)
|
||||||
|
if isinstance(result, str):
|
||||||
|
return EditResult(error=result)
|
||||||
|
|
||||||
|
new_content, occurrences = result
|
||||||
|
new_file_data = update_file_data(file_data, new_content)
|
||||||
|
return EditResult(
|
||||||
|
path=file_path,
|
||||||
|
files_update={file_path: new_file_data},
|
||||||
|
occurrences=int(occurrences),
|
||||||
|
)
|
||||||
|
|
||||||
|
def edit( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
old_string: str,
|
||||||
|
new_string: str,
|
||||||
|
replace_all: bool = False,
|
||||||
|
) -> EditResult:
|
||||||
|
return asyncio.run(self.aedit(file_path, old_string, new_string, replace_all))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ glob/grep
|
||||||
|
|
||||||
|
async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: # type: ignore[override]
|
||||||
|
normalized = self._normalize_listing_path(path)
|
||||||
|
results: list[FileInfo] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
files = self._state_files()
|
||||||
|
moved_removed, _ = self._moved_view_paths(files)
|
||||||
|
regex = re.compile(fnmatch.translate(pattern))
|
||||||
|
for path_key, fd in files.items():
|
||||||
|
if path_key in moved_removed:
|
||||||
|
continue
|
||||||
|
if not _is_under(path_key, normalized):
|
||||||
|
continue
|
||||||
|
rel = (
|
||||||
|
path_key[len(normalized) :].lstrip("/")
|
||||||
|
if normalized != "/"
|
||||||
|
else path_key.lstrip("/")
|
||||||
|
)
|
||||||
|
if not regex.match(rel) and not regex.match(path_key):
|
||||||
|
continue
|
||||||
|
if path_key in seen:
|
||||||
|
continue
|
||||||
|
size = self._file_data_size(fd) if isinstance(fd, dict) else 0
|
||||||
|
results.append(
|
||||||
|
FileInfo(
|
||||||
|
path=path_key,
|
||||||
|
is_dir=False,
|
||||||
|
size=int(size),
|
||||||
|
modified_at=fd.get("modified_at", "")
|
||||||
|
if isinstance(fd, dict)
|
||||||
|
else "",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen.add(path_key)
|
||||||
|
|
||||||
|
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
index = await build_path_index(session, self.search_space_id)
|
||||||
|
rows = await session.execute(
|
||||||
|
select(Document.id, Document.title, Document.folder_id).where(
|
||||||
|
Document.search_space_id == self.search_space_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for row in rows.all():
|
||||||
|
candidate = doc_to_virtual_path(
|
||||||
|
doc_id=row.id,
|
||||||
|
title=str(row.title or "untitled"),
|
||||||
|
folder_id=row.folder_id,
|
||||||
|
index=index,
|
||||||
|
)
|
||||||
|
if candidate in seen or candidate in moved_removed:
|
||||||
|
continue
|
||||||
|
if not _is_under(candidate, normalized):
|
||||||
|
continue
|
||||||
|
rel = (
|
||||||
|
candidate[len(normalized) :].lstrip("/")
|
||||||
|
if normalized != "/"
|
||||||
|
else candidate.lstrip("/")
|
||||||
|
)
|
||||||
|
if not regex.match(rel) and not regex.match(candidate):
|
||||||
|
continue
|
||||||
|
results.append(
|
||||||
|
FileInfo(
|
||||||
|
path=candidate, is_dir=False, size=0, modified_at=""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen.add(candidate)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logger.warning("KBPostgresBackend.aglob_info DB error: %s", exc)
|
||||||
|
|
||||||
|
results.sort(key=lambda fi: fi.get("path", ""))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: # type: ignore[override]
|
||||||
|
return asyncio.run(self.aglob_info(pattern, path))
|
||||||
|
|
||||||
|
async def agrep_raw( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
pattern: str,
|
||||||
|
path: str | None = None,
|
||||||
|
glob: str | None = None,
|
||||||
|
) -> list[GrepMatch] | str:
|
||||||
|
if not pattern:
|
||||||
|
return "Error: pattern cannot be empty"
|
||||||
|
|
||||||
|
normalized = self._normalize_listing_path(path or "/")
|
||||||
|
matches: list[GrepMatch] = []
|
||||||
|
|
||||||
|
files = self._state_files()
|
||||||
|
moved_removed, _ = self._moved_view_paths(files)
|
||||||
|
glob_re = re.compile(fnmatch.translate(glob)) if glob else None
|
||||||
|
for path_key, fd in files.items():
|
||||||
|
if path_key in moved_removed:
|
||||||
|
continue
|
||||||
|
if not _is_under(path_key, normalized):
|
||||||
|
continue
|
||||||
|
if glob_re is not None and not glob_re.match(_basename(path_key)):
|
||||||
|
continue
|
||||||
|
if not isinstance(fd, dict):
|
||||||
|
continue
|
||||||
|
for line_no, line in enumerate(fd.get("content") or [], 1):
|
||||||
|
if pattern in line:
|
||||||
|
matches.append(
|
||||||
|
GrepMatch(path=path_key, line=int(line_no), text=str(line))
|
||||||
|
)
|
||||||
|
if len(matches) >= _GREP_MAX_TOTAL_MATCHES:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
index = await build_path_index(session, self.search_space_id)
|
||||||
|
sub = (
|
||||||
|
select(Chunk.document_id, Chunk.id, Chunk.content)
|
||||||
|
.join(Document, Document.id == Chunk.document_id)
|
||||||
|
.where(Document.search_space_id == self.search_space_id)
|
||||||
|
.where(Chunk.content.ilike(f"%{pattern}%"))
|
||||||
|
.order_by(Chunk.document_id, Chunk.id)
|
||||||
|
)
|
||||||
|
chunk_rows = await session.execute(sub)
|
||||||
|
per_doc: dict[int, int] = {}
|
||||||
|
doc_id_to_path: dict[int, str] = {}
|
||||||
|
needed_doc_ids: set[int] = set()
|
||||||
|
chunk_buffer: list[tuple[int, int, str]] = []
|
||||||
|
for row in chunk_rows.all():
|
||||||
|
per_doc.setdefault(row.document_id, 0)
|
||||||
|
if per_doc[row.document_id] >= _GREP_MAX_PER_DOC:
|
||||||
|
continue
|
||||||
|
per_doc[row.document_id] += 1
|
||||||
|
chunk_buffer.append((row.document_id, row.id, row.content))
|
||||||
|
needed_doc_ids.add(row.document_id)
|
||||||
|
if sum(per_doc.values()) >= _GREP_MAX_TOTAL_MATCHES - len(
|
||||||
|
matches
|
||||||
|
):
|
||||||
|
break
|
||||||
|
if needed_doc_ids:
|
||||||
|
doc_rows = await session.execute(
|
||||||
|
select(
|
||||||
|
Document.id, Document.title, Document.folder_id
|
||||||
|
).where(Document.id.in_(list(needed_doc_ids)))
|
||||||
|
)
|
||||||
|
for row in doc_rows.all():
|
||||||
|
doc_id_to_path[row.id] = doc_to_virtual_path(
|
||||||
|
doc_id=row.id,
|
||||||
|
title=str(row.title or "untitled"),
|
||||||
|
folder_id=row.folder_id,
|
||||||
|
index=index,
|
||||||
|
)
|
||||||
|
for doc_id, chunk_id, content in chunk_buffer:
|
||||||
|
candidate = doc_id_to_path.get(doc_id)
|
||||||
|
if not candidate or candidate in moved_removed:
|
||||||
|
continue
|
||||||
|
if not _is_under(candidate, normalized):
|
||||||
|
continue
|
||||||
|
if glob_re is not None and not glob_re.match(
|
||||||
|
_basename(candidate)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
snippet = " ".join(str(content).split())[:240]
|
||||||
|
matches.append(
|
||||||
|
GrepMatch(
|
||||||
|
path=candidate,
|
||||||
|
line=0,
|
||||||
|
text=(
|
||||||
|
f"<chunk-match in {candidate} chunk_id={chunk_id}>: "
|
||||||
|
f"{snippet}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(matches) >= _GREP_MAX_TOTAL_MATCHES:
|
||||||
|
break
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logger.warning("KBPostgresBackend.agrep_raw DB error: %s", exc)
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def grep_raw( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
pattern: str,
|
||||||
|
path: str | None = None,
|
||||||
|
glob: str | None = None,
|
||||||
|
) -> list[GrepMatch] | str:
|
||||||
|
return asyncio.run(self.agrep_raw(pattern, path, glob))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ list_tree (helper)
|
||||||
|
|
||||||
|
async def alist_tree_listing(
|
||||||
|
self,
|
||||||
|
path: str = DOCUMENTS_ROOT,
|
||||||
|
*,
|
||||||
|
max_depth: int | None = 8,
|
||||||
|
page_size: int = 500,
|
||||||
|
include_files: bool = True,
|
||||||
|
include_dirs: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Recursive tree listing for cloud mode.
|
||||||
|
|
||||||
|
Mirrors the shape returned by :class:`MultiRootLocalFolderBackend.list_tree`:
|
||||||
|
``{"entries": [{path, is_dir, size, modified_at, depth}, ...], "truncated": bool}``.
|
||||||
|
"""
|
||||||
|
normalized = self._normalize_listing_path(path or DOCUMENTS_ROOT)
|
||||||
|
if not normalized.startswith(DOCUMENTS_ROOT) and normalized != "/":
|
||||||
|
return {"error": "Error: path must be under /documents/"}
|
||||||
|
|
||||||
|
entries: list[dict[str, Any]] = []
|
||||||
|
truncated = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
index = await build_path_index(session, self.search_space_id)
|
||||||
|
doc_rows_raw = await session.execute(
|
||||||
|
select(
|
||||||
|
Document.id,
|
||||||
|
Document.title,
|
||||||
|
Document.folder_id,
|
||||||
|
Document.updated_at,
|
||||||
|
).where(Document.search_space_id == self.search_space_id)
|
||||||
|
)
|
||||||
|
doc_rows = list(doc_rows_raw.all())
|
||||||
|
except Exception as exc: # pragma: no cover
|
||||||
|
logger.warning("KBPostgresBackend.alist_tree_listing DB error: %s", exc)
|
||||||
|
return {"entries": [], "truncated": False}
|
||||||
|
|
||||||
|
files = self._state_files()
|
||||||
|
moved_removed, _ = self._moved_view_paths(files)
|
||||||
|
anon = self._kb_anon_doc()
|
||||||
|
anon_path = str(anon.get("path") or "") if anon else ""
|
||||||
|
|
||||||
|
def _depth_of(p: str) -> int:
|
||||||
|
if p == DOCUMENTS_ROOT:
|
||||||
|
return 0
|
||||||
|
rel_root = (
|
||||||
|
p[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||||
|
if normalized.startswith(DOCUMENTS_ROOT)
|
||||||
|
else p.lstrip("/")
|
||||||
|
)
|
||||||
|
return len([part for part in rel_root.split("/") if part])
|
||||||
|
|
||||||
|
def _add_entry(entry: dict[str, Any]) -> bool:
|
||||||
|
nonlocal truncated
|
||||||
|
if len(entries) >= page_size:
|
||||||
|
truncated = True
|
||||||
|
return False
|
||||||
|
entries.append(entry)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if include_dirs:
|
||||||
|
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
|
||||||
|
if not _is_under(fpath, normalized):
|
||||||
|
continue
|
||||||
|
depth = _depth_of(fpath)
|
||||||
|
if max_depth is not None and depth > max_depth:
|
||||||
|
continue
|
||||||
|
if not _add_entry(
|
||||||
|
{
|
||||||
|
"path": fpath,
|
||||||
|
"is_dir": True,
|
||||||
|
"size": 0,
|
||||||
|
"modified_at": "",
|
||||||
|
"depth": depth,
|
||||||
|
}
|
||||||
|
):
|
||||||
|
return {"entries": entries, "truncated": True}
|
||||||
|
for staged in self._staged_dirs():
|
||||||
|
if not _is_under(staged, normalized):
|
||||||
|
continue
|
||||||
|
depth = _depth_of(staged)
|
||||||
|
if max_depth is not None and depth > max_depth:
|
||||||
|
continue
|
||||||
|
if any(e["path"] == staged for e in entries):
|
||||||
|
continue
|
||||||
|
if not _add_entry(
|
||||||
|
{
|
||||||
|
"path": staged,
|
||||||
|
"is_dir": True,
|
||||||
|
"size": 0,
|
||||||
|
"modified_at": "",
|
||||||
|
"depth": depth,
|
||||||
|
}
|
||||||
|
):
|
||||||
|
return {"entries": entries, "truncated": True}
|
||||||
|
|
||||||
|
if include_files:
|
||||||
|
for row in sorted(doc_rows, key=lambda r: str(r.title or "")):
|
||||||
|
candidate = doc_to_virtual_path(
|
||||||
|
doc_id=row.id,
|
||||||
|
title=str(row.title or "untitled"),
|
||||||
|
folder_id=row.folder_id,
|
||||||
|
index=index,
|
||||||
|
)
|
||||||
|
if candidate in moved_removed:
|
||||||
|
continue
|
||||||
|
if not _is_under(candidate, normalized):
|
||||||
|
continue
|
||||||
|
depth = _depth_of(candidate)
|
||||||
|
if max_depth is not None and depth > max_depth:
|
||||||
|
continue
|
||||||
|
modified = ""
|
||||||
|
if row.updated_at is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
modified = row.updated_at.astimezone(UTC).isoformat()
|
||||||
|
if not _add_entry(
|
||||||
|
{
|
||||||
|
"path": candidate,
|
||||||
|
"is_dir": False,
|
||||||
|
"size": 0,
|
||||||
|
"modified_at": modified,
|
||||||
|
"depth": depth,
|
||||||
|
}
|
||||||
|
):
|
||||||
|
return {"entries": entries, "truncated": True}
|
||||||
|
|
||||||
|
if anon_path and _is_under(anon_path, normalized):
|
||||||
|
depth = _depth_of(anon_path)
|
||||||
|
if (max_depth is None or depth <= max_depth) and not _add_entry(
|
||||||
|
{
|
||||||
|
"path": anon_path,
|
||||||
|
"is_dir": False,
|
||||||
|
"size": len(str(anon.get("content") or "")),
|
||||||
|
"modified_at": "",
|
||||||
|
"depth": depth,
|
||||||
|
}
|
||||||
|
):
|
||||||
|
return {"entries": entries, "truncated": True}
|
||||||
|
|
||||||
|
for path_key, fd in files.items():
|
||||||
|
if not isinstance(path_key, str):
|
||||||
|
continue
|
||||||
|
if not _is_under(path_key, normalized):
|
||||||
|
continue
|
||||||
|
if any(e["path"] == path_key for e in entries):
|
||||||
|
continue
|
||||||
|
if not (
|
||||||
|
path_key.startswith(DOCUMENTS_ROOT)
|
||||||
|
or _basename(path_key).startswith(_TEMP_PREFIX)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
depth = _depth_of(path_key)
|
||||||
|
if max_depth is not None and depth > max_depth:
|
||||||
|
continue
|
||||||
|
size = self._file_data_size(fd) if isinstance(fd, dict) else 0
|
||||||
|
if not _add_entry(
|
||||||
|
{
|
||||||
|
"path": path_key,
|
||||||
|
"is_dir": False,
|
||||||
|
"size": int(size),
|
||||||
|
"modified_at": fd.get("modified_at", "")
|
||||||
|
if isinstance(fd, dict)
|
||||||
|
else "",
|
||||||
|
"depth": depth,
|
||||||
|
}
|
||||||
|
):
|
||||||
|
return {"entries": entries, "truncated": True}
|
||||||
|
|
||||||
|
return {"entries": entries, "truncated": truncated}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ uploads (unsupported)
|
||||||
|
|
||||||
|
def upload_files( # type: ignore[override]
|
||||||
|
self, files: list[tuple[str, bytes]]
|
||||||
|
) -> list[FileUploadResponse]:
|
||||||
|
msg = "KBPostgresBackend does not support upload_files."
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
def download_files( # type: ignore[override]
|
||||||
|
self, paths: list[str]
|
||||||
|
) -> list[FileDownloadResponse]:
|
||||||
|
responses: list[FileDownloadResponse] = []
|
||||||
|
files = self._state_files()
|
||||||
|
for path in paths:
|
||||||
|
fd = files.get(path)
|
||||||
|
if fd is None:
|
||||||
|
responses.append(
|
||||||
|
FileDownloadResponse(
|
||||||
|
path=path, content=None, error="file_not_found"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
content_str = file_data_to_string(fd)
|
||||||
|
responses.append(
|
||||||
|
FileDownloadResponse(
|
||||||
|
path=path,
|
||||||
|
content=content_str.encode("utf-8"),
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return responses
|
||||||
|
|
||||||
|
|
||||||
|
# --- module-level small helpers ---------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def list_tree_listing(
|
||||||
|
backend: KBPostgresBackend,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
max_depth: int | None = 8,
|
||||||
|
page_size: int = 500,
|
||||||
|
include_files: bool = True,
|
||||||
|
include_dirs: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Async helper used by the overridden ``list_tree`` tool wrapper."""
|
||||||
|
return await backend.alist_tree_listing(
|
||||||
|
path,
|
||||||
|
max_depth=max_depth,
|
||||||
|
page_size=page_size,
|
||||||
|
include_files=include_files,
|
||||||
|
include_dirs=include_dirs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"KBPostgresBackend",
|
||||||
|
"list_tree_listing",
|
||||||
|
"paginate_listing",
|
||||||
|
]
|
||||||
|
|
@ -1,10 +1,24 @@
|
||||||
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
|
"""Hybrid-search priority middleware for the SurfSense new chat agent.
|
||||||
|
|
||||||
This middleware runs before the main agent loop and seeds a virtual filesystem
|
This middleware runs ``before_agent`` on every turn and writes:
|
||||||
(`files` state) with relevant documents retrieved via hybrid search. On each
|
|
||||||
turn the filesystem is *expanded* — new results merge with documents loaded
|
* ``state["kb_priority"]`` — the top-K most relevant documents for the
|
||||||
during prior turns — and a synthetic ``ls`` result is injected into the message
|
current user message, used to render a ``<priority_documents>`` system
|
||||||
history so the LLM is immediately aware of the current filesystem structure.
|
message immediately before the user turn.
|
||||||
|
* ``state["kb_matched_chunk_ids"]`` — internal hand-off mapping
|
||||||
|
(``Document.id`` → matched chunk IDs) consumed by
|
||||||
|
:class:`KBPostgresBackend._load_file_data` when the agent first reads each
|
||||||
|
document, so the XML wrapper can flag matched sections in
|
||||||
|
``<chunk_index>``.
|
||||||
|
|
||||||
|
The previous "scoped filesystem" behaviour (synthetic ``ls`` + state
|
||||||
|
``files`` seeding) is intentionally removed: documents are now lazy-loaded
|
||||||
|
from Postgres on demand, with the full workspace tree rendered separately
|
||||||
|
by :class:`KnowledgeTreeMiddleware`.
|
||||||
|
|
||||||
|
In anonymous mode the middleware skips hybrid search entirely and emits a
|
||||||
|
single-entry priority list pointing at the Redis-loaded document
|
||||||
|
(``state["kb_anon_doc"]``).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -13,27 +27,30 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
from litellm import token_counter
|
from litellm import token_counter
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
PathIndex,
|
||||||
|
build_path_index,
|
||||||
|
doc_to_virtual_path,
|
||||||
|
)
|
||||||
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
||||||
from app.db import (
|
from app.db import (
|
||||||
NATIVE_TO_LEGACY_DOCTYPE,
|
NATIVE_TO_LEGACY_DOCTYPE,
|
||||||
Chunk,
|
Chunk,
|
||||||
Document,
|
Document,
|
||||||
Folder,
|
|
||||||
shielded_async_session,
|
shielded_async_session,
|
||||||
)
|
)
|
||||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||||
|
|
@ -70,7 +87,6 @@ class KBSearchPlan(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_from_message(message: BaseMessage) -> str:
|
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||||
"""Extract plain text from a message content."""
|
|
||||||
content = getattr(message, "content", "")
|
content = getattr(message, "content", "")
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
|
|
@ -85,19 +101,6 @@ def _extract_text_from_message(message: BaseMessage) -> str:
|
||||||
return str(content)
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
|
||||||
"""Convert arbitrary text into a filesystem-safe filename."""
|
|
||||||
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
|
||||||
name = re.sub(r"\s+", " ", name)
|
|
||||||
if not name:
|
|
||||||
name = fallback
|
|
||||||
if len(name) > 180:
|
|
||||||
name = name[:180].rstrip()
|
|
||||||
if not name.lower().endswith(".xml"):
|
|
||||||
name = f"{name}.xml"
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _render_recent_conversation(
|
def _render_recent_conversation(
|
||||||
messages: Sequence[BaseMessage],
|
messages: Sequence[BaseMessage],
|
||||||
*,
|
*,
|
||||||
|
|
@ -107,10 +110,9 @@ def _render_recent_conversation(
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Render recent dialogue for internal planning under a token budget.
|
"""Render recent dialogue for internal planning under a token budget.
|
||||||
|
|
||||||
Prefers the latest messages and uses the project's existing model-aware
|
Filters to ``HumanMessage`` and ``AIMessage`` (without tool_calls) so that
|
||||||
token budgeting hooks when available on the LLM (`_count_tokens`,
|
injected ``SystemMessage`` artefacts (priority list, workspace tree,
|
||||||
`_get_max_input_tokens`). Falls back to the prior fixed-message heuristic
|
file-write contract) don't pollute the planner prompt.
|
||||||
if token counting is unavailable.
|
|
||||||
"""
|
"""
|
||||||
rendered: list[tuple[str, str]] = []
|
rendered: list[tuple[str, str]] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
|
@ -133,8 +135,6 @@ def _render_recent_conversation(
|
||||||
if not rendered:
|
if not rendered:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Exclude the latest user message from "recent conversation" because it is
|
|
||||||
# already passed separately as "Latest user message" in the planner prompt.
|
|
||||||
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
|
if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip():
|
||||||
rendered = rendered[:-1]
|
rendered = rendered[:-1]
|
||||||
|
|
||||||
|
|
@ -216,8 +216,6 @@ def _render_recent_conversation(
|
||||||
selected_lines = candidate_lines
|
selected_lines = candidate_lines
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If the full message does not fit, keep as much of this most-recent
|
|
||||||
# older message as possible via binary search.
|
|
||||||
lo, hi = 1, len(text)
|
lo, hi = 1, len(text)
|
||||||
best_line: str | None = None
|
best_line: str | None = None
|
||||||
while lo <= hi:
|
while lo <= hi:
|
||||||
|
|
@ -249,7 +247,6 @@ def _build_kb_planner_prompt(
|
||||||
recent_conversation: str,
|
recent_conversation: str,
|
||||||
user_text: str,
|
user_text: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a compact internal prompt for KB query rewriting and date scoping."""
|
|
||||||
today = datetime.now(UTC).date().isoformat()
|
today = datetime.now(UTC).date().isoformat()
|
||||||
return (
|
return (
|
||||||
"You optimize internal knowledge-base search inputs for document retrieval.\n"
|
"You optimize internal knowledge-base search inputs for document retrieval.\n"
|
||||||
|
|
@ -275,12 +272,10 @@ def _build_kb_planner_prompt(
|
||||||
|
|
||||||
|
|
||||||
def _extract_json_payload(text: str) -> str:
|
def _extract_json_payload(text: str) -> str:
|
||||||
"""Extract a JSON object from a raw LLM response."""
|
|
||||||
stripped = text.strip()
|
stripped = text.strip()
|
||||||
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL)
|
||||||
if fenced:
|
if fenced:
|
||||||
return fenced.group(1)
|
return fenced.group(1)
|
||||||
|
|
||||||
start = stripped.find("{")
|
start = stripped.find("{")
|
||||||
end = stripped.rfind("}")
|
end = stripped.rfind("}")
|
||||||
if start != -1 and end != -1 and end > start:
|
if start != -1 and end != -1 and end > start:
|
||||||
|
|
@ -289,7 +284,6 @@ def _extract_json_payload(text: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
|
def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan:
|
||||||
"""Parse and validate the planner's JSON response."""
|
|
||||||
payload = json.loads(_extract_json_payload(response_text))
|
payload = json.loads(_extract_json_payload(response_text))
|
||||||
return KBSearchPlan.model_validate(payload)
|
return KBSearchPlan.model_validate(payload)
|
||||||
|
|
||||||
|
|
@ -298,212 +292,19 @@ def _normalize_optional_date_range(
|
||||||
start_date: str | None,
|
start_date: str | None,
|
||||||
end_date: str | None,
|
end_date: str | None,
|
||||||
) -> tuple[datetime | None, datetime | None]:
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
"""Normalize optional planner dates into a UTC datetime range."""
|
|
||||||
parsed_start = parse_date_or_datetime(start_date) if start_date else None
|
parsed_start = parse_date_or_datetime(start_date) if start_date else None
|
||||||
parsed_end = parse_date_or_datetime(end_date) if end_date else None
|
parsed_end = parse_date_or_datetime(end_date) if end_date else None
|
||||||
|
|
||||||
if parsed_start is None and parsed_end is None:
|
if parsed_start is None and parsed_end is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end)
|
return resolve_date_range(parsed_start, parsed_end)
|
||||||
return resolved_start, resolved_end
|
|
||||||
|
|
||||||
|
|
||||||
def _build_document_xml(
|
|
||||||
document: dict[str, Any],
|
|
||||||
matched_chunk_ids: set[int] | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
|
|
||||||
|
|
||||||
The ``<chunk_index>`` at the top of each document lists every chunk with its
|
|
||||||
line range inside ``<document_content>`` and flags chunks that directly
|
|
||||||
matched the search query (``matched="true"``). This lets the LLM jump
|
|
||||||
straight to the most relevant section via ``read_file(offset=…, limit=…)``
|
|
||||||
instead of reading sequentially from the start.
|
|
||||||
"""
|
|
||||||
matched = matched_chunk_ids or set()
|
|
||||||
|
|
||||||
doc_meta = document.get("document") or {}
|
|
||||||
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
|
|
||||||
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
|
|
||||||
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
|
|
||||||
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
|
|
||||||
url = (
|
|
||||||
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
|
|
||||||
)
|
|
||||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
|
||||||
|
|
||||||
# --- 1. Metadata header (fixed structure) ---
|
|
||||||
metadata_lines: list[str] = [
|
|
||||||
"<document>",
|
|
||||||
"<document_metadata>",
|
|
||||||
f" <document_id>{document_id}</document_id>",
|
|
||||||
f" <document_type>{document_type}</document_type>",
|
|
||||||
f" <title><![CDATA[{title}]]></title>",
|
|
||||||
f" <url><![CDATA[{url}]]></url>",
|
|
||||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
|
||||||
"</document_metadata>",
|
|
||||||
"",
|
|
||||||
]
|
|
||||||
|
|
||||||
# --- 2. Pre-build chunk XML strings to compute line counts ---
|
|
||||||
chunks = document.get("chunks") or []
|
|
||||||
chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string)
|
|
||||||
if isinstance(chunks, list):
|
|
||||||
for chunk in chunks:
|
|
||||||
if not isinstance(chunk, dict):
|
|
||||||
continue
|
|
||||||
chunk_id = chunk.get("chunk_id") or chunk.get("id")
|
|
||||||
chunk_content = str(chunk.get("content", "")).strip()
|
|
||||||
if not chunk_content:
|
|
||||||
continue
|
|
||||||
if chunk_id is None:
|
|
||||||
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
|
|
||||||
else:
|
|
||||||
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
|
|
||||||
chunk_entries.append((chunk_id, xml))
|
|
||||||
|
|
||||||
# --- 3. Compute line numbers for every chunk ---
|
|
||||||
# Layout (1-indexed lines for read_file):
|
|
||||||
# metadata_lines -> len(metadata_lines) lines
|
|
||||||
# <chunk_index> -> 1 line
|
|
||||||
# index entries -> len(chunk_entries) lines
|
|
||||||
# </chunk_index> -> 1 line
|
|
||||||
# (empty line) -> 1 line
|
|
||||||
# <document_content> -> 1 line
|
|
||||||
# chunk xml lines…
|
|
||||||
# </document_content> -> 1 line
|
|
||||||
# </document> -> 1 line
|
|
||||||
index_overhead = (
|
|
||||||
1 + len(chunk_entries) + 1 + 1 + 1
|
|
||||||
) # tags + empty + <document_content>
|
|
||||||
first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed
|
|
||||||
|
|
||||||
current_line = first_chunk_line
|
|
||||||
index_entry_lines: list[str] = []
|
|
||||||
for cid, xml_str in chunk_entries:
|
|
||||||
num_lines = xml_str.count("\n") + 1
|
|
||||||
end_line = current_line + num_lines - 1
|
|
||||||
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
|
|
||||||
if cid is not None:
|
|
||||||
index_entry_lines.append(
|
|
||||||
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
index_entry_lines.append(
|
|
||||||
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
|
|
||||||
)
|
|
||||||
current_line = end_line + 1
|
|
||||||
|
|
||||||
# --- 4. Assemble final XML ---
|
|
||||||
lines = metadata_lines.copy()
|
|
||||||
lines.append("<chunk_index>")
|
|
||||||
lines.extend(index_entry_lines)
|
|
||||||
lines.append("</chunk_index>")
|
|
||||||
lines.append("")
|
|
||||||
lines.append("<document_content>")
|
|
||||||
for _, xml_str in chunk_entries:
|
|
||||||
lines.append(xml_str)
|
|
||||||
lines.extend(["</document_content>", "</document>"])
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_folder_paths(
|
|
||||||
session: AsyncSession, search_space_id: int
|
|
||||||
) -> dict[int, str]:
|
|
||||||
"""Return a map of folder_id -> virtual folder path under /documents."""
|
|
||||||
result = await session.execute(
|
|
||||||
select(Folder.id, Folder.name, Folder.parent_id).where(
|
|
||||||
Folder.search_space_id == search_space_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
rows = result.all()
|
|
||||||
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
|
|
||||||
|
|
||||||
cache: dict[int, str] = {}
|
|
||||||
|
|
||||||
def resolve_path(folder_id: int) -> str:
|
|
||||||
if folder_id in cache:
|
|
||||||
return cache[folder_id]
|
|
||||||
parts: list[str] = []
|
|
||||||
cursor: int | None = folder_id
|
|
||||||
visited: set[int] = set()
|
|
||||||
while cursor is not None and cursor in by_id and cursor not in visited:
|
|
||||||
visited.add(cursor)
|
|
||||||
entry = by_id[cursor]
|
|
||||||
parts.append(
|
|
||||||
_safe_filename(str(entry["name"]), fallback="folder").removesuffix(
|
|
||||||
".xml"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cursor = entry["parent_id"]
|
|
||||||
parts.reverse()
|
|
||||||
path = "/documents/" + "/".join(parts) if parts else "/documents"
|
|
||||||
cache[folder_id] = path
|
|
||||||
return path
|
|
||||||
|
|
||||||
for folder_id in by_id:
|
|
||||||
resolve_path(folder_id)
|
|
||||||
return cache
|
|
||||||
|
|
||||||
|
|
||||||
def _build_synthetic_ls(
|
|
||||||
existing_files: dict[str, Any] | None,
|
|
||||||
new_files: dict[str, Any],
|
|
||||||
*,
|
|
||||||
mentioned_paths: set[str] | None = None,
|
|
||||||
) -> tuple[AIMessage, ToolMessage]:
|
|
||||||
"""Build a synthetic ls("/documents") tool-call + result for the LLM context.
|
|
||||||
|
|
||||||
Mentioned files are listed first. A separate header tells the LLM which
|
|
||||||
files the user explicitly selected; the path list itself stays clean so
|
|
||||||
paths can be passed directly to ``read_file`` without stripping tags.
|
|
||||||
"""
|
|
||||||
_mentioned = mentioned_paths or set()
|
|
||||||
merged: dict[str, Any] = {**(existing_files or {}), **new_files}
|
|
||||||
doc_paths = [
|
|
||||||
p for p, v in merged.items() if p.startswith("/documents/") and v is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
new_set = set(new_files)
|
|
||||||
mentioned_list = [p for p in doc_paths if p in _mentioned]
|
|
||||||
new_non_mentioned = [p for p in doc_paths if p in new_set and p not in _mentioned]
|
|
||||||
old_paths = [p for p in doc_paths if p not in new_set]
|
|
||||||
ordered = mentioned_list + new_non_mentioned + old_paths
|
|
||||||
|
|
||||||
parts: list[str] = []
|
|
||||||
if mentioned_list:
|
|
||||||
parts.append(
|
|
||||||
"USER-MENTIONED documents (read these thoroughly before answering):"
|
|
||||||
)
|
|
||||||
for p in mentioned_list:
|
|
||||||
parts.append(f" {p}")
|
|
||||||
parts.append("")
|
|
||||||
parts.append(str(ordered) if ordered else "No documents found.")
|
|
||||||
|
|
||||||
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
|
|
||||||
ai_msg = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
|
|
||||||
)
|
|
||||||
tool_msg = ToolMessage(
|
|
||||||
content="\n".join(parts),
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
)
|
|
||||||
return ai_msg, tool_msg
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_search_types(
|
def _resolve_search_types(
|
||||||
available_connectors: list[str] | None,
|
available_connectors: list[str] | None,
|
||||||
available_document_types: list[str] | None,
|
available_document_types: list[str] | None,
|
||||||
) -> list[str] | None:
|
) -> list[str] | None:
|
||||||
"""Build a flat list of document-type strings for the chunk retriever.
|
|
||||||
|
|
||||||
Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
|
|
||||||
old documents indexed under Composio names are still found.
|
|
||||||
|
|
||||||
Returns ``None`` when no filtering is desired (search all types).
|
|
||||||
"""
|
|
||||||
types: set[str] = set()
|
types: set[str] = set()
|
||||||
if available_document_types:
|
if available_document_types:
|
||||||
types.update(available_document_types)
|
types.update(available_document_types)
|
||||||
|
|
@ -531,13 +332,8 @@ async def browse_recent_documents(
|
||||||
start_date: datetime | None = None,
|
start_date: datetime | None = None,
|
||||||
end_date: datetime | None = None,
|
end_date: datetime | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Return documents ordered by recency (newest first), no relevance ranking.
|
"""Return documents ordered by recency (newest first), no relevance ranking."""
|
||||||
|
from sqlalchemy import func
|
||||||
Used when the user's intent is temporal ("latest file", "most recent upload")
|
|
||||||
and hybrid search would produce poor results because the query has no
|
|
||||||
meaningful topical signal.
|
|
||||||
"""
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
|
|
||||||
from app.db import DocumentType
|
from app.db import DocumentType
|
||||||
|
|
||||||
|
|
@ -581,7 +377,6 @@ async def browse_recent_documents(
|
||||||
return []
|
return []
|
||||||
|
|
||||||
doc_ids = [d.id for d in documents]
|
doc_ids = [d.id for d in documents]
|
||||||
|
|
||||||
numbered = (
|
numbered = (
|
||||||
select(
|
select(
|
||||||
Chunk.id.label("chunk_id"),
|
Chunk.id.label("chunk_id"),
|
||||||
|
|
@ -632,6 +427,7 @@ async def browse_recent_documents(
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
|
"folder_id": getattr(doc, "folder_id", None),
|
||||||
},
|
},
|
||||||
"source": (
|
"source": (
|
||||||
doc.document_type.value
|
doc.document_type.value
|
||||||
|
|
@ -640,12 +436,6 @@ async def browse_recent_documents(
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"browse_recent_documents: %d docs returned for space=%d",
|
|
||||||
len(results),
|
|
||||||
search_space_id,
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -659,17 +449,11 @@ async def search_knowledge_base(
|
||||||
start_date: datetime | None = None,
|
start_date: datetime | None = None,
|
||||||
end_date: datetime | None = None,
|
end_date: datetime | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Run a single unified hybrid search against the knowledge base.
|
"""Run a single unified hybrid search against the knowledge base."""
|
||||||
|
|
||||||
Uses one ``ChucksHybridSearchRetriever`` call across all document types
|
|
||||||
instead of fanning out per-connector. This reduces the number of DB
|
|
||||||
queries from ~10 to 2 (one RRF query + one chunk fetch).
|
|
||||||
"""
|
|
||||||
if not query:
|
if not query:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
[embedding] = embed_texts([query])
|
[embedding] = embed_texts([query])
|
||||||
|
|
||||||
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
||||||
retriever_top_k = min(top_k * 3, 30)
|
retriever_top_k = min(top_k * 3, 30)
|
||||||
|
|
||||||
|
|
@ -693,14 +477,7 @@ async def fetch_mentioned_documents(
|
||||||
document_ids: list[int],
|
document_ids: list[int],
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Fetch explicitly mentioned documents with *all* their chunks.
|
"""Fetch explicitly mentioned documents."""
|
||||||
|
|
||||||
Returns the same dict structure as ``search_knowledge_base`` so results
|
|
||||||
can be merged directly into ``build_scoped_filesystem``. Unlike search
|
|
||||||
results, every chunk is included (no top-K limiting) and none are marked
|
|
||||||
as ``matched`` since the entire document is relevant by virtue of the
|
|
||||||
user's explicit mention.
|
|
||||||
"""
|
|
||||||
if not document_ids:
|
if not document_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -750,6 +527,7 @@ async def fetch_mentioned_documents(
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
|
"folder_id": getattr(doc, "folder_id", None),
|
||||||
},
|
},
|
||||||
"source": (
|
"source": (
|
||||||
doc.document_type.value
|
doc.document_type.value
|
||||||
|
|
@ -762,96 +540,36 @@ async def fetch_mentioned_documents(
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def build_scoped_filesystem(
|
def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage:
|
||||||
*,
|
"""Render the priority list as a single ``<priority_documents>`` system message."""
|
||||||
documents: Sequence[dict[str, Any]],
|
if not priority:
|
||||||
search_space_id: int,
|
body = "(no priority documents for this turn)"
|
||||||
) -> tuple[dict[str, dict[str, str]], dict[int, str]]:
|
else:
|
||||||
"""Build a StateBackend-compatible files dict from search results.
|
lines: list[str] = []
|
||||||
|
for entry in priority:
|
||||||
Returns ``(files, doc_id_to_path)`` so callers can reliably map a
|
score = entry.get("score")
|
||||||
document id back to its filesystem path without guessing by title.
|
mentioned = entry.get("mentioned")
|
||||||
Paths are collision-proof: when two documents resolve to the same
|
score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a"
|
||||||
path the doc-id is appended to disambiguate.
|
mark = " [USER-MENTIONED]" if mentioned else ""
|
||||||
"""
|
lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}")
|
||||||
async with shielded_async_session() as session:
|
body = "\n".join(lines)
|
||||||
folder_paths = await _get_folder_paths(session, search_space_id)
|
return SystemMessage(
|
||||||
doc_ids = [
|
content=(
|
||||||
(doc.get("document") or {}).get("id")
|
"<priority_documents>\n"
|
||||||
for doc in documents
|
"These documents are most relevant to the latest user message; "
|
||||||
if isinstance(doc, dict)
|
"read them first. Matched sections are flagged inside each "
|
||||||
]
|
"document's <chunk_index>.\n"
|
||||||
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
|
f"{body}\n"
|
||||||
folder_by_doc_id: dict[int, int | None] = {}
|
"</priority_documents>"
|
||||||
if doc_ids:
|
)
|
||||||
doc_rows = await session.execute(
|
)
|
||||||
select(Document.id, Document.folder_id).where(
|
|
||||||
Document.search_space_id == search_space_id,
|
|
||||||
Document.id.in_(doc_ids),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
folder_by_doc_id = {
|
|
||||||
row.id: row.folder_id for row in doc_rows.all() if row.id is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
files: dict[str, dict[str, str]] = {}
|
|
||||||
doc_id_to_path: dict[int, str] = {}
|
|
||||||
for document in documents:
|
|
||||||
doc_meta = document.get("document") or {}
|
|
||||||
title = str(doc_meta.get("title") or "untitled")
|
|
||||||
doc_id = doc_meta.get("id")
|
|
||||||
folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None
|
|
||||||
base_folder = folder_paths.get(folder_id, "/documents")
|
|
||||||
file_name = _safe_filename(title)
|
|
||||||
path = f"{base_folder}/{file_name}"
|
|
||||||
if path in files:
|
|
||||||
stem = file_name.removesuffix(".xml")
|
|
||||||
path = f"{base_folder}/{stem} ({doc_id}).xml"
|
|
||||||
matched_ids = set(document.get("matched_chunk_ids") or [])
|
|
||||||
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
|
|
||||||
files[path] = {
|
|
||||||
"content": xml_content.split("\n"),
|
|
||||||
"encoding": "utf-8",
|
|
||||||
"created_at": "",
|
|
||||||
"modified_at": "",
|
|
||||||
}
|
|
||||||
if isinstance(doc_id, int):
|
|
||||||
doc_id_to_path[doc_id] = path
|
|
||||||
return files, doc_id_to_path
|
|
||||||
|
|
||||||
|
|
||||||
def _build_anon_scoped_filesystem(
|
class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
documents: Sequence[dict[str, Any]],
|
"""Compute hybrid-search priority hints for the current turn."""
|
||||||
) -> dict[str, dict[str, str]]:
|
|
||||||
"""Build a scoped filesystem for anonymous documents without DB queries.
|
|
||||||
|
|
||||||
Anonymous uploads have no folders, so all files go under /documents.
|
|
||||||
"""
|
|
||||||
files: dict[str, dict[str, str]] = {}
|
|
||||||
for document in documents:
|
|
||||||
doc_meta = document.get("document") or {}
|
|
||||||
title = str(doc_meta.get("title") or "untitled")
|
|
||||||
file_name = _safe_filename(title)
|
|
||||||
path = f"/documents/{file_name}"
|
|
||||||
if path in files:
|
|
||||||
doc_id = doc_meta.get("id", "dup")
|
|
||||||
stem = file_name.removesuffix(".xml")
|
|
||||||
path = f"/documents/{stem} ({doc_id}).xml"
|
|
||||||
matched_ids = set(document.get("matched_chunk_ids") or [])
|
|
||||||
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
|
|
||||||
files[path] = {
|
|
||||||
"content": xml_content.split("\n"),
|
|
||||||
"encoding": "utf-8",
|
|
||||||
"created_at": "",
|
|
||||||
"modified_at": "",
|
|
||||||
}
|
|
||||||
return files
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|
||||||
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
|
|
||||||
|
|
||||||
tools = ()
|
tools = ()
|
||||||
|
state_schema = SurfSenseFilesystemState
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -863,7 +581,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
available_document_types: list[str] | None = None,
|
available_document_types: list[str] | None = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
anon_session_id: str | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.search_space_id = search_space_id
|
self.search_space_id = search_space_id
|
||||||
|
|
@ -872,7 +589,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
self.available_document_types = available_document_types
|
self.available_document_types = available_document_types
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.mentioned_document_ids = mentioned_document_ids or []
|
self.mentioned_document_ids = mentioned_document_ids or []
|
||||||
self.anon_session_id = anon_session_id
|
|
||||||
|
|
||||||
async def _plan_search_inputs(
|
async def _plan_search_inputs(
|
||||||
self,
|
self,
|
||||||
|
|
@ -880,10 +596,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
messages: Sequence[BaseMessage],
|
messages: Sequence[BaseMessage],
|
||||||
user_text: str,
|
user_text: str,
|
||||||
) -> tuple[str, datetime | None, datetime | None, bool]:
|
) -> tuple[str, datetime | None, datetime | None, bool]:
|
||||||
"""Rewrite the KB query and infer optional date filters with the LLM.
|
|
||||||
|
|
||||||
Returns (optimized_query, start_date, end_date, is_recency_query).
|
|
||||||
"""
|
|
||||||
if self.llm is None:
|
if self.llm is None:
|
||||||
return user_text, None, None, False
|
return user_text, None, None, False
|
||||||
|
|
||||||
|
|
@ -914,7 +626,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
)
|
)
|
||||||
is_recency = plan.is_recency_query
|
is_recency = plan.is_recency_query
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r "
|
"[kb_priority] planner in %.3fs query=%r optimized=%r "
|
||||||
"start=%s end=%s recency=%s",
|
"start=%s end=%s recency=%s",
|
||||||
loop.time() - t0,
|
loop.time() - t0,
|
||||||
user_text[:80],
|
user_text[:80],
|
||||||
|
|
@ -946,106 +658,68 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
pass
|
pass
|
||||||
return asyncio.run(self.abefore_agent(state, runtime))
|
return asyncio.run(self.abefore_agent(state, runtime))
|
||||||
|
|
||||||
async def _load_anon_document(self) -> dict[str, Any] | None:
|
|
||||||
"""Load the anonymous user's uploaded document from Redis."""
|
|
||||||
if not self.anon_session_id:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
|
|
||||||
from app.config import config
|
|
||||||
|
|
||||||
redis_client = aioredis.from_url(
|
|
||||||
config.REDIS_APP_URL, decode_responses=True
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
redis_key = f"anon:doc:{self.anon_session_id}"
|
|
||||||
data = await redis_client.get(redis_key)
|
|
||||||
if not data:
|
|
||||||
return None
|
|
||||||
doc = json.loads(data)
|
|
||||||
return {
|
|
||||||
"document_id": -1,
|
|
||||||
"content": doc.get("content", ""),
|
|
||||||
"score": 1.0,
|
|
||||||
"chunks": [
|
|
||||||
{
|
|
||||||
"chunk_id": -1,
|
|
||||||
"content": doc.get("content", ""),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"matched_chunk_ids": [-1],
|
|
||||||
"document": {
|
|
||||||
"id": -1,
|
|
||||||
"title": doc.get("filename", "uploaded_document"),
|
|
||||||
"document_type": "FILE",
|
|
||||||
"metadata": {"source": "anonymous_upload"},
|
|
||||||
},
|
|
||||||
"source": "FILE",
|
|
||||||
"_user_mentioned": True,
|
|
||||||
}
|
|
||||||
finally:
|
|
||||||
await redis_client.aclose()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Failed to load anonymous document from Redis: %s", exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def abefore_agent( # type: ignore[override]
|
async def abefore_agent( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
state: AgentState,
|
state: AgentState,
|
||||||
runtime: Runtime[Any],
|
runtime: Runtime[Any],
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
del runtime
|
del runtime
|
||||||
|
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
return None
|
||||||
|
|
||||||
messages = state.get("messages") or []
|
messages = state.get("messages") or []
|
||||||
if not messages:
|
if not messages:
|
||||||
return None
|
return None
|
||||||
if self.filesystem_mode != FilesystemMode.CLOUD:
|
|
||||||
# Local-folder mode should not seed cloud KB documents into filesystem.
|
|
||||||
return None
|
|
||||||
|
|
||||||
last_human = None
|
last_human: HumanMessage | None = None
|
||||||
for msg in reversed(messages):
|
for msg in reversed(messages):
|
||||||
if isinstance(msg, HumanMessage):
|
if isinstance(msg, HumanMessage):
|
||||||
last_human = msg
|
last_human = msg
|
||||||
break
|
break
|
||||||
if last_human is None:
|
if last_human is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
user_text = _extract_text_from_message(last_human).strip()
|
user_text = _extract_text_from_message(last_human).strip()
|
||||||
if not user_text:
|
if not user_text:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
t0 = _perf_log and asyncio.get_event_loop().time()
|
anon_doc = state.get("kb_anon_doc")
|
||||||
existing_files = state.get("files")
|
if anon_doc:
|
||||||
|
return self._anon_priority(state, anon_doc)
|
||||||
|
|
||||||
# --- Anonymous session: load Redis doc and skip DB queries ---
|
return await self._authenticated_priority(state, messages, user_text)
|
||||||
if self.anon_session_id:
|
|
||||||
merged: list[dict[str, Any]] = []
|
|
||||||
anon_doc = await self._load_anon_document()
|
|
||||||
if anon_doc:
|
|
||||||
merged.append(anon_doc)
|
|
||||||
|
|
||||||
if merged:
|
def _anon_priority(
|
||||||
new_files = _build_anon_scoped_filesystem(merged)
|
self,
|
||||||
mentioned_paths = set(new_files.keys())
|
state: AgentState,
|
||||||
else:
|
anon_doc: dict[str, Any],
|
||||||
new_files = {}
|
) -> dict[str, Any]:
|
||||||
mentioned_paths = set()
|
path = str(anon_doc.get("path") or "")
|
||||||
|
title = str(anon_doc.get("title") or "uploaded_document")
|
||||||
|
priority = [
|
||||||
|
{
|
||||||
|
"path": path,
|
||||||
|
"score": 1.0,
|
||||||
|
"document_id": None,
|
||||||
|
"title": title,
|
||||||
|
"mentioned": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
new_messages = list(state.get("messages") or [])
|
||||||
|
insert_at = max(len(new_messages) - 1, 0)
|
||||||
|
new_messages.insert(insert_at, _render_priority_message(priority))
|
||||||
|
return {
|
||||||
|
"kb_priority": priority,
|
||||||
|
"kb_matched_chunk_ids": {},
|
||||||
|
"messages": new_messages,
|
||||||
|
}
|
||||||
|
|
||||||
ai_msg, tool_msg = _build_synthetic_ls(
|
async def _authenticated_priority(
|
||||||
existing_files,
|
self,
|
||||||
new_files,
|
state: AgentState,
|
||||||
mentioned_paths=mentioned_paths,
|
messages: Sequence[BaseMessage],
|
||||||
)
|
user_text: str,
|
||||||
if t0 is not None:
|
) -> dict[str, Any]:
|
||||||
_perf_log.info(
|
t0 = asyncio.get_event_loop().time()
|
||||||
"[kb_fs_middleware] anon completed in %.3fs new_files=%d",
|
|
||||||
asyncio.get_event_loop().time() - t0,
|
|
||||||
len(new_files),
|
|
||||||
)
|
|
||||||
return {"files": new_files, "messages": [ai_msg, tool_msg]}
|
|
||||||
|
|
||||||
# --- Authenticated session: full KB search ---
|
|
||||||
(
|
(
|
||||||
planned_query,
|
planned_query,
|
||||||
start_date,
|
start_date,
|
||||||
|
|
@ -1056,7 +730,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
user_text=user_text,
|
user_text=user_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 1. Fetch mentioned documents (user-selected, all chunks) ---
|
|
||||||
mentioned_results: list[dict[str, Any]] = []
|
mentioned_results: list[dict[str, Any]] = []
|
||||||
if self.mentioned_document_ids:
|
if self.mentioned_document_ids:
|
||||||
mentioned_results = await fetch_mentioned_documents(
|
mentioned_results = await fetch_mentioned_documents(
|
||||||
|
|
@ -1065,7 +738,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
)
|
)
|
||||||
self.mentioned_document_ids = []
|
self.mentioned_document_ids = []
|
||||||
|
|
||||||
# --- 2. Run KB search (recency browse or hybrid) ---
|
|
||||||
if is_recency:
|
if is_recency:
|
||||||
doc_types = _resolve_search_types(
|
doc_types = _resolve_search_types(
|
||||||
self.available_connectors, self.available_document_types
|
self.available_connectors, self.available_document_types
|
||||||
|
|
@ -1088,48 +760,108 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
|
|
||||||
seen_doc_ids: set[int] = set()
|
seen_doc_ids: set[int] = set()
|
||||||
merged_auth: list[dict[str, Any]] = []
|
merged: list[dict[str, Any]] = []
|
||||||
for doc in mentioned_results:
|
for doc in mentioned_results:
|
||||||
doc_id = (doc.get("document") or {}).get("id")
|
doc_id = (doc.get("document") or {}).get("id")
|
||||||
if doc_id is not None:
|
if isinstance(doc_id, int):
|
||||||
seen_doc_ids.add(doc_id)
|
seen_doc_ids.add(doc_id)
|
||||||
merged_auth.append(doc)
|
merged.append(doc)
|
||||||
for doc in search_results:
|
for doc in search_results:
|
||||||
doc_id = (doc.get("document") or {}).get("id")
|
doc_id = (doc.get("document") or {}).get("id")
|
||||||
if doc_id is not None and doc_id in seen_doc_ids:
|
if isinstance(doc_id, int) and doc_id in seen_doc_ids:
|
||||||
continue
|
continue
|
||||||
merged_auth.append(doc)
|
merged.append(doc)
|
||||||
|
|
||||||
# --- 4. Build scoped filesystem ---
|
priority, matched_chunk_ids = await self._materialize_priority(merged)
|
||||||
new_files, doc_id_to_path = await build_scoped_filesystem(
|
|
||||||
documents=merged_auth,
|
new_messages = list(messages)
|
||||||
search_space_id=self.search_space_id,
|
insert_at = max(len(new_messages) - 1, 0)
|
||||||
|
new_messages.insert(insert_at, _render_priority_message(priority))
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d",
|
||||||
|
asyncio.get_event_loop().time() - t0,
|
||||||
|
user_text[:80],
|
||||||
|
len(priority),
|
||||||
|
len(mentioned_results),
|
||||||
)
|
)
|
||||||
|
|
||||||
mentioned_doc_ids = {
|
return {
|
||||||
(d.get("document") or {}).get("id") for d in mentioned_results
|
"kb_priority": priority,
|
||||||
}
|
"kb_matched_chunk_ids": matched_chunk_ids,
|
||||||
mentioned_paths = {
|
"messages": new_messages,
|
||||||
doc_id_to_path[did] for did in mentioned_doc_ids if did in doc_id_to_path
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ai_msg, tool_msg = _build_synthetic_ls(
|
async def _materialize_priority(
|
||||||
existing_files,
|
self, merged: list[dict[str, Any]]
|
||||||
new_files,
|
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:
|
||||||
mentioned_paths=mentioned_paths,
|
"""Resolve canonical paths and matched chunk ids for the priority list."""
|
||||||
)
|
priority: list[dict[str, Any]] = []
|
||||||
|
matched_chunk_ids: dict[int, list[int]] = {}
|
||||||
|
|
||||||
if t0 is not None:
|
if not merged:
|
||||||
_perf_log.info(
|
return priority, matched_chunk_ids
|
||||||
"[kb_fs_middleware] completed in %.3fs query=%r optimized=%r "
|
|
||||||
"mentioned=%d new_files=%d total=%d",
|
async with shielded_async_session() as session:
|
||||||
asyncio.get_event_loop().time() - t0,
|
index: PathIndex = await build_path_index(session, self.search_space_id)
|
||||||
user_text[:80],
|
doc_ids = [
|
||||||
planned_query[:120],
|
(doc.get("document") or {}).get("id")
|
||||||
len(mentioned_results),
|
for doc in merged
|
||||||
len(new_files),
|
if isinstance(doc, dict)
|
||||||
len(new_files) + len(existing_files or {}),
|
]
|
||||||
|
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
|
||||||
|
folder_by_doc_id: dict[int, int | None] = {}
|
||||||
|
if doc_ids:
|
||||||
|
folder_rows = await session.execute(
|
||||||
|
select(Document.id, Document.folder_id).where(
|
||||||
|
Document.search_space_id == self.search_space_id,
|
||||||
|
Document.id.in_(doc_ids),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()}
|
||||||
|
|
||||||
|
for doc in merged:
|
||||||
|
doc_meta = doc.get("document") or {}
|
||||||
|
doc_id = doc_meta.get("id")
|
||||||
|
title = doc_meta.get("title") or "untitled"
|
||||||
|
folder_id = (
|
||||||
|
folder_by_doc_id.get(doc_id)
|
||||||
|
if isinstance(doc_id, int)
|
||||||
|
else doc_meta.get("folder_id")
|
||||||
)
|
)
|
||||||
return {"files": new_files, "messages": [ai_msg, tool_msg]}
|
path = doc_to_virtual_path(
|
||||||
|
doc_id=doc_id if isinstance(doc_id, int) else None,
|
||||||
|
title=str(title),
|
||||||
|
folder_id=folder_id if isinstance(folder_id, int) else None,
|
||||||
|
index=index,
|
||||||
|
)
|
||||||
|
priority.append(
|
||||||
|
{
|
||||||
|
"path": path,
|
||||||
|
"score": float(doc.get("score") or 0.0),
|
||||||
|
"document_id": doc_id if isinstance(doc_id, int) else None,
|
||||||
|
"title": str(title),
|
||||||
|
"mentioned": bool(doc.get("_user_mentioned")),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if isinstance(doc_id, int):
|
||||||
|
chunk_ids = doc.get("matched_chunk_ids") or []
|
||||||
|
if chunk_ids:
|
||||||
|
matched_chunk_ids[doc_id] = [
|
||||||
|
int(cid) for cid in chunk_ids if isinstance(cid, (int, str))
|
||||||
|
]
|
||||||
|
return priority, matched_chunk_ids
|
||||||
|
|
||||||
|
|
||||||
|
# Backwards-compatible alias for any external imports.
|
||||||
|
KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"KnowledgeBaseSearchMiddleware",
|
||||||
|
"KnowledgePriorityMiddleware",
|
||||||
|
"browse_recent_documents",
|
||||||
|
"fetch_mentioned_documents",
|
||||||
|
"search_knowledge_base",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,272 @@
|
||||||
|
"""Workspace-tree middleware for the SurfSense agent.
|
||||||
|
|
||||||
|
Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per
|
||||||
|
turn (cloud only), caches it by ``(search_space_id, tree_version)``, and
|
||||||
|
injects the result as a ``<workspace_tree>`` system message immediately
|
||||||
|
before the latest human turn.
|
||||||
|
|
||||||
|
The render is bounded by two truncation layers:
|
||||||
|
|
||||||
|
1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is
|
||||||
|
replaced with a "use ls" hint.
|
||||||
|
2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's
|
||||||
|
token-count profile when available). If the entry-truncated tree still
|
||||||
|
exceeds the token cap we fall back to a root-only summary.
|
||||||
|
|
||||||
|
Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls).
|
||||||
|
|
||||||
|
This middleware also performs a one-time initialization of ``state['cwd']``
|
||||||
|
to ``"/documents"`` so subsequent middlewares and tools always see a valid
|
||||||
|
cwd in cloud mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
PathIndex,
|
||||||
|
build_path_index,
|
||||||
|
doc_to_virtual_path,
|
||||||
|
)
|
||||||
|
from app.db import Document, shielded_async_session
|
||||||
|
|
||||||
|
try:
|
||||||
|
from litellm import token_counter
|
||||||
|
except Exception: # pragma: no cover - optional dep
|
||||||
|
token_counter = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
MAX_TREE_ENTRIES = 500
|
||||||
|
MAX_TREE_TOKENS = 4000
|
||||||
|
|
||||||
|
|
||||||
|
def _approx_tokens(text: str) -> int:
|
||||||
|
"""Cheap fallback token estimate (1 token ~= 4 chars)."""
|
||||||
|
return max(1, (len(text) + 3) // 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int:
|
||||||
|
if llm is None:
|
||||||
|
return _approx_tokens(text)
|
||||||
|
count_fn = getattr(llm, "_count_tokens", None)
|
||||||
|
if callable(count_fn):
|
||||||
|
try:
|
||||||
|
return int(count_fn([{"role": "user", "content": text}]))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
profile = getattr(llm, "profile", None)
|
||||||
|
model_names: list[str] = []
|
||||||
|
if isinstance(profile, dict):
|
||||||
|
tcms = profile.get("token_count_models")
|
||||||
|
if isinstance(tcms, list):
|
||||||
|
model_names.extend(name for name in tcms if isinstance(name, str) and name)
|
||||||
|
tcm = profile.get("token_count_model")
|
||||||
|
if isinstance(tcm, str) and tcm and tcm not in model_names:
|
||||||
|
model_names.append(tcm)
|
||||||
|
model_name = model_names[0] if model_names else getattr(llm, "model", None)
|
||||||
|
if not isinstance(model_name, str) or not model_name or token_counter is None:
|
||||||
|
return _approx_tokens(text)
|
||||||
|
try:
|
||||||
|
return int(
|
||||||
|
token_counter(
|
||||||
|
messages=[{"role": "user", "content": text}],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return _approx_tokens(text)
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
"""Inject the workspace folder/document tree into the agent's context."""
|
||||||
|
|
||||||
|
tools = ()
|
||||||
|
state_schema = SurfSenseFilesystemState
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
filesystem_mode: FilesystemMode,
|
||||||
|
llm: BaseChatModel | None = None,
|
||||||
|
max_entries: int = MAX_TREE_ENTRIES,
|
||||||
|
max_tokens: int = MAX_TREE_TOKENS,
|
||||||
|
) -> None:
|
||||||
|
self.search_space_id = search_space_id
|
||||||
|
self.filesystem_mode = filesystem_mode
|
||||||
|
self.llm = llm
|
||||||
|
self.max_entries = max_entries
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self._cache: dict[tuple[int, int, bool], str] = {}
|
||||||
|
|
||||||
|
async def abefore_agent( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
del runtime
|
||||||
|
if self.filesystem_mode != FilesystemMode.CLOUD:
|
||||||
|
return None
|
||||||
|
|
||||||
|
update: dict[str, Any] = {}
|
||||||
|
if not state.get("cwd"):
|
||||||
|
update["cwd"] = DOCUMENTS_ROOT
|
||||||
|
|
||||||
|
anon_doc = state.get("kb_anon_doc")
|
||||||
|
if anon_doc:
|
||||||
|
tree_msg = self._render_anon_tree(anon_doc)
|
||||||
|
else:
|
||||||
|
tree_msg = await self._render_kb_tree(state)
|
||||||
|
|
||||||
|
messages = list(state.get("messages") or [])
|
||||||
|
insert_at = max(len(messages) - 1, 0)
|
||||||
|
messages.insert(insert_at, SystemMessage(content=tree_msg))
|
||||||
|
update["messages"] = messages
|
||||||
|
return update
|
||||||
|
|
||||||
|
def before_agent( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
return None
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
return asyncio.run(self.abefore_agent(state, runtime))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ render
|
||||||
|
|
||||||
|
def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str:
|
||||||
|
path = str(anon_doc.get("path") or "")
|
||||||
|
title = str(anon_doc.get("title") or "uploaded_document")
|
||||||
|
return (
|
||||||
|
"<workspace_tree>\n"
|
||||||
|
"Anonymous session — only one read-only document is available.\n"
|
||||||
|
f"{DOCUMENTS_ROOT}/\n"
|
||||||
|
f" {path} — {title}\n"
|
||||||
|
"</workspace_tree>"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _render_kb_tree(self, state: AgentState) -> str:
|
||||||
|
version = int(state.get("tree_version") or 0)
|
||||||
|
cache_key = (self.search_space_id, version, False)
|
||||||
|
cached = self._cache.get(cache_key)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
index = await build_path_index(session, self.search_space_id)
|
||||||
|
doc_rows = await session.execute(
|
||||||
|
select(Document.id, Document.title, Document.folder_id).where(
|
||||||
|
Document.search_space_id == self.search_space_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
docs = list(doc_rows.all())
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logger.warning("knowledge_tree: DB error %s", exc)
|
||||||
|
return "<workspace_tree>\n(unavailable)\n</workspace_tree>"
|
||||||
|
|
||||||
|
rendered = self._format_tree(index, docs)
|
||||||
|
self._cache[cache_key] = rendered
|
||||||
|
return rendered
|
||||||
|
|
||||||
|
def _format_tree(self, index: PathIndex, docs: list[Any]) -> str:
|
||||||
|
folder_paths = sorted(set(index.folder_paths.values()))
|
||||||
|
doc_paths = sorted(
|
||||||
|
doc_to_virtual_path(
|
||||||
|
doc_id=row.id,
|
||||||
|
title=str(row.title or "untitled"),
|
||||||
|
folder_id=row.folder_id,
|
||||||
|
index=index,
|
||||||
|
)
|
||||||
|
for row in docs
|
||||||
|
)
|
||||||
|
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
||||||
|
|
||||||
|
lines: list[str] = []
|
||||||
|
for path in all_paths:
|
||||||
|
depth = (
|
||||||
|
0
|
||||||
|
if path == DOCUMENTS_ROOT
|
||||||
|
else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p])
|
||||||
|
)
|
||||||
|
indent = " " * depth
|
||||||
|
is_dir = path == DOCUMENTS_ROOT or path in folder_paths
|
||||||
|
display = (
|
||||||
|
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
||||||
|
)
|
||||||
|
if is_dir:
|
||||||
|
lines.append(f"{indent}{display}/")
|
||||||
|
else:
|
||||||
|
lines.append(f"{indent}{display}")
|
||||||
|
if len(lines) >= self.max_entries:
|
||||||
|
remaining = len(all_paths) - len(lines)
|
||||||
|
if remaining > 0:
|
||||||
|
lines.append(
|
||||||
|
f"... {remaining} more entries — use "
|
||||||
|
"ls('/documents/<folder>', offset, limit) to expand"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
body = "\n".join(lines)
|
||||||
|
rendered = f"<workspace_tree>\n{body}\n</workspace_tree>"
|
||||||
|
|
||||||
|
token_count = _count_tokens(rendered, llm=self.llm)
|
||||||
|
if token_count <= self.max_tokens:
|
||||||
|
return rendered
|
||||||
|
|
||||||
|
return self._format_root_summary(folder_paths, doc_paths)
|
||||||
|
|
||||||
|
def _format_root_summary(
|
||||||
|
self, folder_paths: list[str], doc_paths: list[str]
|
||||||
|
) -> str:
|
||||||
|
top_level: dict[str, int] = {}
|
||||||
|
loose_docs = 0
|
||||||
|
for path in doc_paths:
|
||||||
|
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||||
|
if "/" in rel:
|
||||||
|
top = rel.split("/", 1)[0]
|
||||||
|
top_level[top] = top_level.get(top, 0) + 1
|
||||||
|
else:
|
||||||
|
loose_docs += 1
|
||||||
|
for path in folder_paths:
|
||||||
|
rel = path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||||
|
if not rel:
|
||||||
|
continue
|
||||||
|
top = rel.split("/", 1)[0]
|
||||||
|
top_level.setdefault(top, 0)
|
||||||
|
|
||||||
|
lines = [DOCUMENTS_ROOT + "/"]
|
||||||
|
for name in sorted(top_level):
|
||||||
|
count = top_level[name]
|
||||||
|
lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})")
|
||||||
|
if loose_docs:
|
||||||
|
lines.append(
|
||||||
|
f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})"
|
||||||
|
)
|
||||||
|
lines.append(
|
||||||
|
"Tree is large; use list_tree('/documents/<folder>') to drill in "
|
||||||
|
"or ls('/documents/<folder>', offset, limit) for paginated listings."
|
||||||
|
)
|
||||||
|
return "<workspace_tree>\n" + "\n".join(lines) + "\n</workspace_tree>"
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["KnowledgeTreeMiddleware"]
|
||||||
351
surfsense_backend/app/agents/new_chat/path_resolver.py
Normal file
351
surfsense_backend/app/agents/new_chat/path_resolver.py
Normal file
|
|
@ -0,0 +1,351 @@
|
||||||
|
"""Canonical virtual-path resolver for SurfSense knowledge-base documents.
|
||||||
|
|
||||||
|
This module is the single source of truth for mapping ``Document`` rows to
|
||||||
|
virtual paths under ``/documents/`` and back. It is used by:
|
||||||
|
|
||||||
|
* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree)
|
||||||
|
* :class:`KnowledgePriorityMiddleware` (computing priority paths)
|
||||||
|
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations)
|
||||||
|
* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates)
|
||||||
|
|
||||||
|
Centralising the logic ensures that title-collision suffixes, folder paths,
|
||||||
|
and ``unique_identifier_hash`` lookups never drift between renders and
|
||||||
|
commits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import Document, DocumentType, Folder
|
||||||
|
from app.utils.document_converters import generate_unique_identifier_hash
|
||||||
|
|
||||||
|
DOCUMENTS_ROOT = "/documents"
|
||||||
|
"""Root virtual folder for all KB documents."""
|
||||||
|
|
||||||
|
_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+")
|
||||||
|
_WHITESPACE_RUN = re.compile(r"\s+")
|
||||||
|
|
||||||
|
|
||||||
|
def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
||||||
|
"""Convert arbitrary text into a filesystem-safe ``.xml`` filename."""
|
||||||
|
name = _INVALID_FILENAME_CHARS.sub("_", value).strip()
|
||||||
|
name = _WHITESPACE_RUN.sub(" ", name)
|
||||||
|
if not name:
|
||||||
|
name = fallback
|
||||||
|
if len(name) > 180:
|
||||||
|
name = name[:180].rstrip()
|
||||||
|
if not name.lower().endswith(".xml"):
|
||||||
|
name = f"{name}.xml"
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def safe_folder_segment(value: str, *, fallback: str = "folder") -> str:
|
||||||
|
"""Sanitize a single folder name into a path-safe segment."""
|
||||||
|
name = _INVALID_FILENAME_CHARS.sub("_", value).strip()
|
||||||
|
name = _WHITESPACE_RUN.sub(" ", name)
|
||||||
|
if not name:
|
||||||
|
return fallback
|
||||||
|
if len(name) > 180:
|
||||||
|
name = name[:180].rstrip()
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str:
|
||||||
|
if doc_id is None:
|
||||||
|
return filename
|
||||||
|
if not filename.lower().endswith(".xml"):
|
||||||
|
return f"{filename} ({doc_id}).xml"
|
||||||
|
stem = filename[:-4]
|
||||||
|
return f"{stem} ({doc_id}).xml"
|
||||||
|
|
||||||
|
|
||||||
|
_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]:
|
||||||
|
"""Strip a trailing ``" (<doc_id>).xml"`` suffix; return ``(stem, doc_id)``.
|
||||||
|
|
||||||
|
If no suffix is present, returns ``(stem_without_xml_extension, None)``.
|
||||||
|
"""
|
||||||
|
match = _SUFFIX_PATTERN.search(filename)
|
||||||
|
if match:
|
||||||
|
doc_id = int(match.group(1))
|
||||||
|
stem = filename[: match.start()]
|
||||||
|
return stem, doc_id
|
||||||
|
if filename.lower().endswith(".xml"):
|
||||||
|
return filename[:-4], None
|
||||||
|
return filename, None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PathIndex:
|
||||||
|
"""In-memory occupancy snapshot used by :func:`doc_to_virtual_path`.
|
||||||
|
|
||||||
|
Built once per call site so collision handling is deterministic and so
|
||||||
|
we don't perform N folder lookups per render.
|
||||||
|
"""
|
||||||
|
|
||||||
|
folder_paths: dict[int, str] = field(default_factory=dict)
|
||||||
|
"""``Folder.id`` -> absolute virtual folder path under ``/documents``."""
|
||||||
|
|
||||||
|
occupants: dict[str, int] = field(default_factory=dict)
|
||||||
|
"""virtual path -> ``Document.id`` already occupying that path (this render)."""
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_folder_paths(
|
||||||
|
session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
) -> dict[int, str]:
|
||||||
|
"""Compute ``Folder.id`` -> absolute virtual path under ``/documents``."""
|
||||||
|
result = await session.execute(
|
||||||
|
select(Folder.id, Folder.name, Folder.parent_id).where(
|
||||||
|
Folder.search_space_id == search_space_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows = result.all()
|
||||||
|
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
|
||||||
|
cache: dict[int, str] = {}
|
||||||
|
|
||||||
|
def resolve(folder_id: int) -> str:
|
||||||
|
if folder_id in cache:
|
||||||
|
return cache[folder_id]
|
||||||
|
parts: list[str] = []
|
||||||
|
cursor: int | None = folder_id
|
||||||
|
visited: set[int] = set()
|
||||||
|
while cursor is not None and cursor in by_id and cursor not in visited:
|
||||||
|
visited.add(cursor)
|
||||||
|
entry = by_id[cursor]
|
||||||
|
parts.append(safe_folder_segment(str(entry["name"])))
|
||||||
|
cursor = entry["parent_id"]
|
||||||
|
parts.reverse()
|
||||||
|
path = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT
|
||||||
|
cache[folder_id] = path
|
||||||
|
return path
|
||||||
|
|
||||||
|
for folder_id in by_id:
|
||||||
|
resolve(folder_id)
|
||||||
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
async def build_path_index(
|
||||||
|
session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
*,
|
||||||
|
populate_occupants: bool = True,
|
||||||
|
) -> PathIndex:
|
||||||
|
"""Build a :class:`PathIndex` for a search space.
|
||||||
|
|
||||||
|
``populate_occupants`` controls whether the occupancy map is pre-seeded
|
||||||
|
from existing ``Document`` rows. Most callers want this so that
|
||||||
|
:func:`doc_to_virtual_path` can detect collisions across the whole space;
|
||||||
|
the persistence middleware sets this to ``False`` when it is iterating to
|
||||||
|
decide where to place fresh documents.
|
||||||
|
"""
|
||||||
|
folder_paths = await _build_folder_paths(session, search_space_id)
|
||||||
|
occupants: dict[str, int] = {}
|
||||||
|
if populate_occupants:
|
||||||
|
rows = await session.execute(
|
||||||
|
select(Document.id, Document.title, Document.folder_id).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for row in rows.all():
|
||||||
|
base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT)
|
||||||
|
filename = safe_filename(str(row.title or "untitled"))
|
||||||
|
path = f"{base}/{filename}"
|
||||||
|
if path in occupants and occupants[path] != row.id:
|
||||||
|
path = f"{base}/{_suffix_with_doc_id(filename, row.id)}"
|
||||||
|
occupants[path] = row.id
|
||||||
|
return PathIndex(folder_paths=folder_paths, occupants=occupants)
|
||||||
|
|
||||||
|
|
||||||
|
def doc_to_virtual_path(
|
||||||
|
*,
|
||||||
|
doc_id: int | None,
|
||||||
|
title: str,
|
||||||
|
folder_id: int | None,
|
||||||
|
index: PathIndex,
|
||||||
|
) -> str:
|
||||||
|
"""Return the canonical virtual path for a document.
|
||||||
|
|
||||||
|
Mutates ``index.occupants`` so subsequent calls see this assignment and
|
||||||
|
deterministically pick a different suffix for the next colliding doc.
|
||||||
|
"""
|
||||||
|
base = index.folder_paths.get(folder_id, DOCUMENTS_ROOT)
|
||||||
|
filename = safe_filename(str(title or "untitled"))
|
||||||
|
path = f"{base}/{filename}"
|
||||||
|
occupant = index.occupants.get(path)
|
||||||
|
if occupant is not None and occupant != doc_id:
|
||||||
|
path = f"{base}/{_suffix_with_doc_id(filename, doc_id)}"
|
||||||
|
if doc_id is not None:
|
||||||
|
index.occupants[path] = doc_id
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
async def virtual_path_to_doc(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
virtual_path: str,
|
||||||
|
) -> Document | None:
|
||||||
|
"""Resolve a virtual path back to a ``Document`` row.
|
||||||
|
|
||||||
|
Resolution order:
|
||||||
|
1. ``Document.unique_identifier_hash`` lookup (fast path for paths created
|
||||||
|
by SurfSense itself — every NOTE write goes through this hash).
|
||||||
|
2. If the basename carries a ``" (<doc_id>).xml"`` disambiguation suffix,
|
||||||
|
try a direct id lookup constrained to the search space.
|
||||||
|
3. Title-from-basename + folder-resolution lookup as a last resort.
|
||||||
|
"""
|
||||||
|
if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
|
||||||
|
return None
|
||||||
|
|
||||||
|
unique_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
virtual_path,
|
||||||
|
search_space_id,
|
||||||
|
)
|
||||||
|
result = await session.execute(
|
||||||
|
select(Document).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
Document.unique_identifier_hash == unique_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
document = result.scalar_one_or_none()
|
||||||
|
if document is not None:
|
||||||
|
return document
|
||||||
|
|
||||||
|
rel = virtual_path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||||
|
if not rel:
|
||||||
|
return None
|
||||||
|
parts = [p for p in rel.split("/") if p]
|
||||||
|
if not parts:
|
||||||
|
return None
|
||||||
|
basename = parts[-1]
|
||||||
|
folder_parts = parts[:-1]
|
||||||
|
|
||||||
|
stem, suffix_doc_id = parse_doc_id_suffix(basename)
|
||||||
|
if suffix_doc_id is not None:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Document).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
Document.id == suffix_doc_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
document = result.scalar_one_or_none()
|
||||||
|
if document is not None:
|
||||||
|
return document
|
||||||
|
|
||||||
|
folder_id = await _resolve_folder_id(
|
||||||
|
session, search_space_id=search_space_id, folder_parts=folder_parts
|
||||||
|
)
|
||||||
|
title_candidates: list[str] = []
|
||||||
|
raw_title = stem
|
||||||
|
title_candidates.append(raw_title)
|
||||||
|
if raw_title.endswith(".xml"):
|
||||||
|
title_candidates.append(raw_title[:-4])
|
||||||
|
|
||||||
|
for candidate in dict.fromkeys(title_candidates):
|
||||||
|
if not candidate:
|
||||||
|
continue
|
||||||
|
query = select(Document).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
Document.title == candidate,
|
||||||
|
)
|
||||||
|
if folder_id is None:
|
||||||
|
query = query.where(Document.folder_id.is_(None))
|
||||||
|
else:
|
||||||
|
query = query.where(Document.folder_id == folder_id)
|
||||||
|
result = await session.execute(query)
|
||||||
|
document = result.scalars().first()
|
||||||
|
if document is not None:
|
||||||
|
return document
|
||||||
|
|
||||||
|
# Fallback: title-as-string lookup misses when the real DB title contains
|
||||||
|
# characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``,
|
||||||
|
# etc.) — common for connector-imported docs (Google Calendar/Drive etc.).
|
||||||
|
# The workspace tree shows the lossy filename, so the agent passes that
|
||||||
|
# filename back here. Scan all documents in the resolved folder and match
|
||||||
|
# by ``safe_filename(title)`` to recover the original document.
|
||||||
|
folder_scan = select(Document).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
)
|
||||||
|
if folder_id is None:
|
||||||
|
folder_scan = folder_scan.where(Document.folder_id.is_(None))
|
||||||
|
else:
|
||||||
|
folder_scan = folder_scan.where(Document.folder_id == folder_id)
|
||||||
|
result = await session.execute(folder_scan)
|
||||||
|
for candidate_doc in result.scalars().all():
|
||||||
|
encoded = safe_filename(str(candidate_doc.title or "untitled"))
|
||||||
|
if encoded == basename:
|
||||||
|
return candidate_doc
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_folder_id(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
folder_parts: list[str],
|
||||||
|
) -> int | None:
|
||||||
|
"""Look up the leaf folder id for a chain of folder names; return ``None`` if missing."""
|
||||||
|
if not folder_parts:
|
||||||
|
return None
|
||||||
|
parent_id: int | None = None
|
||||||
|
for raw in folder_parts:
|
||||||
|
name = safe_folder_segment(raw)
|
||||||
|
query = select(Folder.id).where(
|
||||||
|
Folder.search_space_id == search_space_id,
|
||||||
|
Folder.name == name,
|
||||||
|
)
|
||||||
|
if parent_id is None:
|
||||||
|
query = query.where(Folder.parent_id.is_(None))
|
||||||
|
else:
|
||||||
|
query = query.where(Folder.parent_id == parent_id)
|
||||||
|
result = await session.execute(query)
|
||||||
|
row = result.first()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
parent_id = row[0]
|
||||||
|
return parent_id
|
||||||
|
|
||||||
|
|
||||||
|
def parse_documents_path(virtual_path: str) -> tuple[list[str], str]:
|
||||||
|
"""Parse a ``/documents/...`` path into ``(folder_parts, document_title)``.
|
||||||
|
|
||||||
|
The title has any ``.xml`` extension and trailing ``" (<doc_id>)"``
|
||||||
|
disambiguation suffix stripped.
|
||||||
|
"""
|
||||||
|
if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
|
||||||
|
return [], ""
|
||||||
|
rel = virtual_path[len(DOCUMENTS_ROOT) :].strip("/")
|
||||||
|
if not rel:
|
||||||
|
return [], ""
|
||||||
|
parts = [p for p in rel.split("/") if p]
|
||||||
|
if not parts:
|
||||||
|
return [], ""
|
||||||
|
folder_parts = parts[:-1]
|
||||||
|
basename = parts[-1]
|
||||||
|
stem, _ = parse_doc_id_suffix(basename)
|
||||||
|
title = stem
|
||||||
|
if title.endswith(".xml"):
|
||||||
|
title = title[:-4]
|
||||||
|
return folder_parts, title
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DOCUMENTS_ROOT",
|
||||||
|
"PathIndex",
|
||||||
|
"build_path_index",
|
||||||
|
"doc_to_virtual_path",
|
||||||
|
"parse_doc_id_suffix",
|
||||||
|
"parse_documents_path",
|
||||||
|
"safe_filename",
|
||||||
|
"safe_folder_segment",
|
||||||
|
"virtual_path_to_doc",
|
||||||
|
]
|
||||||
201
surfsense_backend/app/agents/new_chat/state_reducers.py
Normal file
201
surfsense_backend/app/agents/new_chat/state_reducers.py
Normal file
|
|
@ -0,0 +1,201 @@
|
||||||
|
"""Reducers and sentinels for SurfSense filesystem state.
|
||||||
|
|
||||||
|
These reducers back the extra state fields used by the cloud-mode filesystem
|
||||||
|
agent (`cwd`, `staged_dirs`, `pending_moves`, `dirty_paths`, `doc_id_by_path`,
|
||||||
|
`kb_priority`, `kb_matched_chunk_ids`, `kb_anon_doc`, `tree_version`).
|
||||||
|
|
||||||
|
Tools mutate these fields ONLY via `Command(update={...})` returns; the
|
||||||
|
reducers are responsible for merging successive updates atomically and for
|
||||||
|
honouring an explicit reset sentinel (`_CLEAR`) so that a single update can
|
||||||
|
both reset and reseed a list (used by `move_file` / `aafter_agent`).
|
||||||
|
|
||||||
|
The sentinel is intentionally a plain string constant rather than a custom
|
||||||
|
object so that LangGraph's checkpointer (which serializes raw `Command.update`
|
||||||
|
deltas via ``ormsgpack`` BEFORE reducers are applied) can round-trip writes
|
||||||
|
that contain it. The token uses a NUL-bracketed form that cannot collide with
|
||||||
|
any real virtual path, document title, or dict key produced by the agent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Final, TypeVar
|
||||||
|
|
||||||
|
_CLEAR: Final[str] = "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
|
||||||
|
"""Reset sentinel; pass it inside a list/dict update to request a reset.
|
||||||
|
|
||||||
|
For list reducers: ``[_CLEAR, *items]`` resets the field then appends ``items``.
|
||||||
|
For dict reducers: ``{_CLEAR: True, **items}`` resets the field then merges ``items``.
|
||||||
|
|
||||||
|
Because the value is a plain string with embedded NUL bytes, it is natively
|
||||||
|
serializable by ``ormsgpack`` (used by LangGraph's PostgreSQL checkpointer)
|
||||||
|
yet still distinct from any real path / key produced by application code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_reducer[T](left: T | None, right: T | None) -> T | None:
|
||||||
|
"""Replace `left` outright with `right`. ``None`` on the right is honored as a reset."""
|
||||||
|
return right
|
||||||
|
|
||||||
|
|
||||||
|
def _is_clear(value: Any) -> bool:
|
||||||
|
return isinstance(value, str) and value == _CLEAR
|
||||||
|
|
||||||
|
|
||||||
|
def _add_unique_reducer(
|
||||||
|
left: list[Any] | None,
|
||||||
|
right: list[Any] | None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Append items from ``right`` to ``left`` while preserving uniqueness.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
- If ``right`` is ``None`` or empty, return ``left`` unchanged.
|
||||||
|
- If ``right`` contains the ``_CLEAR`` sentinel anywhere, the result is
|
||||||
|
reseeded with only the items that appear AFTER the LAST occurrence of
|
||||||
|
``_CLEAR`` (deduplicated, preserving first-seen order). This gives a
|
||||||
|
single-update "reset and reseed" capability.
|
||||||
|
- Otherwise, items from ``right`` are appended to ``left`` (order preserved
|
||||||
|
from first seen) while skipping values that are already present.
|
||||||
|
"""
|
||||||
|
if right is None:
|
||||||
|
return list(left or [])
|
||||||
|
if not right:
|
||||||
|
return list(left or [])
|
||||||
|
|
||||||
|
last_clear = -1
|
||||||
|
for index, item in enumerate(right):
|
||||||
|
if _is_clear(item):
|
||||||
|
last_clear = index
|
||||||
|
|
||||||
|
if last_clear >= 0:
|
||||||
|
seed: list[Any] = []
|
||||||
|
seen: set[Any] = set()
|
||||||
|
for item in right[last_clear + 1 :]:
|
||||||
|
if _is_clear(item):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if item in seen:
|
||||||
|
continue
|
||||||
|
seen.add(item)
|
||||||
|
except TypeError:
|
||||||
|
if item in seed:
|
||||||
|
continue
|
||||||
|
seed.append(item)
|
||||||
|
return seed
|
||||||
|
|
||||||
|
base = list(left or [])
|
||||||
|
try:
|
||||||
|
seen: set[Any] = set(base)
|
||||||
|
except TypeError:
|
||||||
|
seen = set()
|
||||||
|
for item in right:
|
||||||
|
if _is_clear(item):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if item in seen:
|
||||||
|
continue
|
||||||
|
seen.add(item)
|
||||||
|
except TypeError:
|
||||||
|
if item in base:
|
||||||
|
continue
|
||||||
|
base.append(item)
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
def _list_append_reducer(
|
||||||
|
left: list[Any] | None,
|
||||||
|
right: list[Any] | None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Append items from ``right`` to ``left`` preserving order and duplicates.
|
||||||
|
|
||||||
|
Honours the ``_CLEAR`` sentinel exactly like :func:`_add_unique_reducer`,
|
||||||
|
but does NOT deduplicate. Used for queues whose ordering and duplicate
|
||||||
|
occurrences matter (e.g. ``pending_moves``).
|
||||||
|
"""
|
||||||
|
if right is None:
|
||||||
|
return list(left or [])
|
||||||
|
if not right:
|
||||||
|
return list(left or [])
|
||||||
|
|
||||||
|
last_clear = -1
|
||||||
|
for index, item in enumerate(right):
|
||||||
|
if _is_clear(item):
|
||||||
|
last_clear = index
|
||||||
|
|
||||||
|
if last_clear >= 0:
|
||||||
|
return [item for item in right[last_clear + 1 :] if not _is_clear(item)]
|
||||||
|
|
||||||
|
base = list(left or [])
|
||||||
|
base.extend(item for item in right if not _is_clear(item))
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
def _dict_merge_with_tombstones_reducer(
|
||||||
|
left: dict[Any, Any] | None,
|
||||||
|
right: dict[Any, Any] | None,
|
||||||
|
) -> dict[Any, Any]:
|
||||||
|
"""Merge ``right`` into ``left`` with two extra capabilities:
|
||||||
|
|
||||||
|
* Keys whose value is ``None`` are removed from the merged result
|
||||||
|
(tombstone semantics, matching the deepagents file-data reducer).
|
||||||
|
* The special key ``_CLEAR`` (with any truthy value) resets ``left`` to
|
||||||
|
``{}`` before merging the remaining keys from ``right``. This makes it
|
||||||
|
possible to atomically clear and reseed the dictionary in a single
|
||||||
|
update.
|
||||||
|
"""
|
||||||
|
if right is None:
|
||||||
|
return dict(left or {})
|
||||||
|
|
||||||
|
if _CLEAR in right or any(_is_clear(k) for k in right):
|
||||||
|
result: dict[Any, Any] = {}
|
||||||
|
for key, value in right.items():
|
||||||
|
if _is_clear(key):
|
||||||
|
continue
|
||||||
|
if value is None:
|
||||||
|
result.pop(key, None)
|
||||||
|
continue
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
if left is None:
|
||||||
|
return {key: value for key, value in right.items() if value is not None}
|
||||||
|
|
||||||
|
result = dict(left)
|
||||||
|
for key, value in right.items():
|
||||||
|
if value is None:
|
||||||
|
result.pop(key, None)
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _initial_filesystem_state() -> dict[str, Any]:
|
||||||
|
"""Default empty values for SurfSense filesystem state fields.
|
||||||
|
|
||||||
|
Consumers should always treat these fields as ``state.get(key) or
|
||||||
|
DEFAULT`` so that fresh threads (without checkpointed state) work
|
||||||
|
correctly.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"cwd": "/documents",
|
||||||
|
"staged_dirs": [],
|
||||||
|
"pending_moves": [],
|
||||||
|
"doc_id_by_path": {},
|
||||||
|
"dirty_paths": [],
|
||||||
|
"kb_priority": [],
|
||||||
|
"kb_matched_chunk_ids": {},
|
||||||
|
"kb_anon_doc": None,
|
||||||
|
"tree_version": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"_CLEAR",
|
||||||
|
"_add_unique_reducer",
|
||||||
|
"_dict_merge_with_tombstones_reducer",
|
||||||
|
"_initial_filesystem_state",
|
||||||
|
"_list_append_reducer",
|
||||||
|
"_replace_reducer",
|
||||||
|
]
|
||||||
|
|
@ -332,7 +332,7 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
|
||||||
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
|
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
|
||||||
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
|
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
|
||||||
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
|
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
|
||||||
* When preloaded `/documents/` data is insufficient and the user wants more
|
* When `/documents/` knowledge-base data is insufficient and the user wants more
|
||||||
- Trigger scenarios:
|
- Trigger scenarios:
|
||||||
* "Read this article and summarize it"
|
* "Read this article and summarize it"
|
||||||
* "What does this page say about X?"
|
* "What does this page say about X?"
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,12 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace
|
from app.db import (
|
||||||
|
ImageGeneration,
|
||||||
|
ImageGenerationConfig,
|
||||||
|
SearchSpace,
|
||||||
|
shielded_async_session,
|
||||||
|
)
|
||||||
from app.services.image_gen_router_service import (
|
from app.services.image_gen_router_service import (
|
||||||
IMAGE_GEN_AUTO_MODE_ID,
|
IMAGE_GEN_AUTO_MODE_ID,
|
||||||
ImageGenRouterService,
|
ImageGenRouterService,
|
||||||
|
|
@ -70,8 +75,13 @@ def create_generate_image_tool(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
search_space_id: The search space ID (for config resolution)
|
search_space_id: The search space ID (for config resolution)
|
||||||
db_session: Async database session
|
db_session: Reserved for compatibility with the tool registry.
|
||||||
|
The streaming task's ``AsyncSession`` is shared by every tool;
|
||||||
|
because AsyncSession is not concurrency-safe, parallel tool calls
|
||||||
|
would interleave flushes (e.g. podcast + image in the same step)
|
||||||
|
and poison the transaction. This tool opens its own session.
|
||||||
"""
|
"""
|
||||||
|
del db_session # use a fresh per-call session, see below
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def generate_image(
|
async def generate_image(
|
||||||
|
|
@ -93,110 +103,119 @@ def create_generate_image_tool(
|
||||||
A dictionary containing the generated image(s) for display in the chat.
|
A dictionary containing the generated image(s) for display in the chat.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Resolve the image generation config from the search space preference
|
# Use a per-call session so concurrent tool calls don't share an
|
||||||
result = await db_session.execute(
|
# AsyncSession (which is not concurrency-safe). The streaming
|
||||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
# task's session is shared across every tool; without isolation,
|
||||||
)
|
# autoflushes from a concurrent writer poison this tool too.
|
||||||
search_space = result.scalars().first()
|
async with shielded_async_session() as session:
|
||||||
if not search_space:
|
result = await session.execute(
|
||||||
return {"error": "Search space not found"}
|
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||||
|
|
||||||
config_id = (
|
|
||||||
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build generation kwargs
|
|
||||||
# NOTE: size, quality, and style are intentionally NOT passed.
|
|
||||||
# Different models support different values for these params
|
|
||||||
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
|
|
||||||
# gpt-image-1 wants "high"/"medium"/"low"; size options also
|
|
||||||
# differ). Letting the model use its own defaults avoids errors.
|
|
||||||
gen_kwargs: dict[str, Any] = {}
|
|
||||||
if n is not None and n > 1:
|
|
||||||
gen_kwargs["n"] = n
|
|
||||||
|
|
||||||
# Call litellm based on config type
|
|
||||||
if is_image_gen_auto_mode(config_id):
|
|
||||||
if not ImageGenRouterService.is_initialized():
|
|
||||||
return {
|
|
||||||
"error": "No image generation models configured. "
|
|
||||||
"Please add an image model in Settings > Image Models."
|
|
||||||
}
|
|
||||||
response = await ImageGenRouterService.aimage_generation(
|
|
||||||
prompt=prompt, model="auto", **gen_kwargs
|
|
||||||
)
|
)
|
||||||
elif config_id < 0:
|
search_space = result.scalars().first()
|
||||||
cfg = _get_global_image_gen_config(config_id)
|
if not search_space:
|
||||||
if not cfg:
|
return {"error": "Search space not found"}
|
||||||
return {"error": f"Image generation config {config_id} not found"}
|
|
||||||
|
|
||||||
model_string = _build_model_string(
|
config_id = (
|
||||||
cfg.get("provider", ""),
|
search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||||
cfg["model_name"],
|
|
||||||
cfg.get("custom_provider"),
|
|
||||||
)
|
)
|
||||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
|
||||||
if cfg.get("api_base"):
|
|
||||||
gen_kwargs["api_base"] = cfg["api_base"]
|
|
||||||
if cfg.get("api_version"):
|
|
||||||
gen_kwargs["api_version"] = cfg["api_version"]
|
|
||||||
if cfg.get("litellm_params"):
|
|
||||||
gen_kwargs.update(cfg["litellm_params"])
|
|
||||||
|
|
||||||
response = await aimage_generation(
|
# Build generation kwargs
|
||||||
prompt=prompt, model=model_string, **gen_kwargs
|
# NOTE: size, quality, and style are intentionally NOT passed.
|
||||||
)
|
# Different models support different values for these params
|
||||||
else:
|
# (e.g. DALL-E 3 wants "hd"/"standard" for quality while
|
||||||
# Positive ID = user-created ImageGenerationConfig
|
# gpt-image-1 wants "high"/"medium"/"low"; size options also
|
||||||
cfg_result = await db_session.execute(
|
# differ). Letting the model use its own defaults avoids errors.
|
||||||
select(ImageGenerationConfig).filter(
|
gen_kwargs: dict[str, Any] = {}
|
||||||
ImageGenerationConfig.id == config_id
|
if n is not None and n > 1:
|
||||||
|
gen_kwargs["n"] = n
|
||||||
|
|
||||||
|
# Call litellm based on config type
|
||||||
|
if is_image_gen_auto_mode(config_id):
|
||||||
|
if not ImageGenRouterService.is_initialized():
|
||||||
|
return {
|
||||||
|
"error": "No image generation models configured. "
|
||||||
|
"Please add an image model in Settings > Image Models."
|
||||||
|
}
|
||||||
|
response = await ImageGenRouterService.aimage_generation(
|
||||||
|
prompt=prompt, model="auto", **gen_kwargs
|
||||||
)
|
)
|
||||||
)
|
elif config_id < 0:
|
||||||
db_cfg = cfg_result.scalars().first()
|
cfg = _get_global_image_gen_config(config_id)
|
||||||
if not db_cfg:
|
if not cfg:
|
||||||
return {"error": f"Image generation config {config_id} not found"}
|
return {
|
||||||
|
"error": f"Image generation config {config_id} not found"
|
||||||
|
}
|
||||||
|
|
||||||
model_string = _build_model_string(
|
model_string = _build_model_string(
|
||||||
db_cfg.provider.value,
|
cfg.get("provider", ""),
|
||||||
db_cfg.model_name,
|
cfg["model_name"],
|
||||||
db_cfg.custom_provider,
|
cfg.get("custom_provider"),
|
||||||
)
|
)
|
||||||
gen_kwargs["api_key"] = db_cfg.api_key
|
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||||
if db_cfg.api_base:
|
if cfg.get("api_base"):
|
||||||
gen_kwargs["api_base"] = db_cfg.api_base
|
gen_kwargs["api_base"] = cfg["api_base"]
|
||||||
if db_cfg.api_version:
|
if cfg.get("api_version"):
|
||||||
gen_kwargs["api_version"] = db_cfg.api_version
|
gen_kwargs["api_version"] = cfg["api_version"]
|
||||||
if db_cfg.litellm_params:
|
if cfg.get("litellm_params"):
|
||||||
gen_kwargs.update(db_cfg.litellm_params)
|
gen_kwargs.update(cfg["litellm_params"])
|
||||||
|
|
||||||
response = await aimage_generation(
|
response = await aimage_generation(
|
||||||
prompt=prompt, model=model_string, **gen_kwargs
|
prompt=prompt, model=model_string, **gen_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Positive ID = user-created ImageGenerationConfig
|
||||||
|
cfg_result = await session.execute(
|
||||||
|
select(ImageGenerationConfig).filter(
|
||||||
|
ImageGenerationConfig.id == config_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
db_cfg = cfg_result.scalars().first()
|
||||||
|
if not db_cfg:
|
||||||
|
return {
|
||||||
|
"error": f"Image generation config {config_id} not found"
|
||||||
|
}
|
||||||
|
|
||||||
|
model_string = _build_model_string(
|
||||||
|
db_cfg.provider.value,
|
||||||
|
db_cfg.model_name,
|
||||||
|
db_cfg.custom_provider,
|
||||||
|
)
|
||||||
|
gen_kwargs["api_key"] = db_cfg.api_key
|
||||||
|
if db_cfg.api_base:
|
||||||
|
gen_kwargs["api_base"] = db_cfg.api_base
|
||||||
|
if db_cfg.api_version:
|
||||||
|
gen_kwargs["api_version"] = db_cfg.api_version
|
||||||
|
if db_cfg.litellm_params:
|
||||||
|
gen_kwargs.update(db_cfg.litellm_params)
|
||||||
|
|
||||||
|
response = await aimage_generation(
|
||||||
|
prompt=prompt, model=model_string, **gen_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the response and store in DB
|
||||||
|
response_dict = (
|
||||||
|
response.model_dump()
|
||||||
|
if hasattr(response, "model_dump")
|
||||||
|
else dict(response)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse the response and store in DB
|
# Generate a random access token for this image
|
||||||
response_dict = (
|
access_token = generate_image_token()
|
||||||
response.model_dump()
|
|
||||||
if hasattr(response, "model_dump")
|
|
||||||
else dict(response)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate a random access token for this image
|
# Save to image_generations table for history
|
||||||
access_token = generate_image_token()
|
db_image_gen = ImageGeneration(
|
||||||
|
prompt=prompt,
|
||||||
# Save to image_generations table for history
|
model=getattr(response, "_hidden_params", {}).get("model"),
|
||||||
db_image_gen = ImageGeneration(
|
n=n,
|
||||||
prompt=prompt,
|
image_generation_config_id=config_id,
|
||||||
model=getattr(response, "_hidden_params", {}).get("model"),
|
response_data=response_dict,
|
||||||
n=n,
|
search_space_id=search_space_id,
|
||||||
image_generation_config_id=config_id,
|
access_token=access_token,
|
||||||
response_data=response_dict,
|
)
|
||||||
search_space_id=search_space_id,
|
session.add(db_image_gen)
|
||||||
access_token=access_token,
|
await session.commit()
|
||||||
)
|
await session.refresh(db_image_gen)
|
||||||
db_session.add(db_image_gen)
|
db_image_gen_id = db_image_gen.id
|
||||||
await db_session.commit()
|
|
||||||
await db_session.refresh(db_image_gen)
|
|
||||||
|
|
||||||
# Extract image URLs from response
|
# Extract image URLs from response
|
||||||
images = response_dict.get("data", [])
|
images = response_dict.get("data", [])
|
||||||
|
|
@ -217,7 +236,7 @@ def create_generate_image_tool(
|
||||||
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
backend_url = config.BACKEND_URL or "http://localhost:8000"
|
||||||
image_url = (
|
image_url = (
|
||||||
f"{backend_url}/api/v1/image-generations/"
|
f"{backend_url}/api/v1/image-generations/"
|
||||||
f"{db_image_gen.id}/image?token={access_token}"
|
f"{db_image_gen_id}/image?token={access_token}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return {"error": "No displayable image data in the response"}
|
return {"error": "No displayable image data in the response"}
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from typing import Any
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import Podcast, PodcastStatus
|
from app.db import Podcast, PodcastStatus, shielded_async_session
|
||||||
|
|
||||||
|
|
||||||
def create_generate_podcast_tool(
|
def create_generate_podcast_tool(
|
||||||
|
|
@ -27,12 +27,16 @@ def create_generate_podcast_tool(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
search_space_id: The user's search space ID
|
search_space_id: The user's search space ID
|
||||||
db_session: Database session for creating the podcast record
|
db_session: Reserved for future read-side use; the row is written via a
|
||||||
|
fresh, tool-local session so parallel tool calls (e.g. podcast +
|
||||||
|
video presentation in the same agent step) don't share an
|
||||||
|
``AsyncSession`` (which is not concurrency-safe).
|
||||||
thread_id: The chat thread ID for associating the podcast
|
thread_id: The chat thread ID for associating the podcast
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A configured tool function for generating podcasts
|
A configured tool function for generating podcasts
|
||||||
"""
|
"""
|
||||||
|
del db_session # writes use a fresh tool-local session, see below
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def generate_podcast(
|
async def generate_podcast(
|
||||||
|
|
@ -64,32 +68,40 @@ def create_generate_podcast_tool(
|
||||||
- message: Status message (or "error" field if status is failed)
|
- message: Status message (or "error" field if status is failed)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
podcast = Podcast(
|
# Open a fresh session per call. The streaming task's session is
|
||||||
title=podcast_title,
|
# shared between every tool, and ``AsyncSession`` is NOT safe for
|
||||||
status=PodcastStatus.PENDING,
|
# concurrent use: when the LLM emits parallel tool calls, two
|
||||||
search_space_id=search_space_id,
|
# concurrent ``add()`` / ``commit()`` paths interleave and the
|
||||||
thread_id=thread_id,
|
# second one hits "Session.add() during flush" → the transaction
|
||||||
)
|
# is poisoned for both tools.
|
||||||
db_session.add(podcast)
|
async with shielded_async_session() as session:
|
||||||
await db_session.commit()
|
podcast = Podcast(
|
||||||
await db_session.refresh(podcast)
|
title=podcast_title,
|
||||||
|
status=PodcastStatus.PENDING,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
session.add(podcast)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(podcast)
|
||||||
|
podcast_id = podcast.id
|
||||||
|
|
||||||
from app.tasks.celery_tasks.podcast_tasks import (
|
from app.tasks.celery_tasks.podcast_tasks import (
|
||||||
generate_content_podcast_task,
|
generate_content_podcast_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
task = generate_content_podcast_task.delay(
|
task = generate_content_podcast_task.delay(
|
||||||
podcast_id=podcast.id,
|
podcast_id=podcast_id,
|
||||||
source_content=source_content,
|
source_content=source_content,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_prompt=user_prompt,
|
user_prompt=user_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")
|
print(f"[generate_podcast] Created podcast {podcast_id}, task: {task.id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": PodcastStatus.PENDING.value,
|
"status": PodcastStatus.PENDING.value,
|
||||||
"podcast_id": podcast.id,
|
"podcast_id": podcast_id,
|
||||||
"title": podcast_title,
|
"title": podcast_title,
|
||||||
"message": "Podcast generation started. This may take a few minutes.",
|
"message": "Podcast generation started. This may take a few minutes.",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from typing import Any
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import VideoPresentation, VideoPresentationStatus
|
from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session
|
||||||
|
|
||||||
|
|
||||||
def create_generate_video_presentation_tool(
|
def create_generate_video_presentation_tool(
|
||||||
|
|
@ -23,8 +23,11 @@ def create_generate_video_presentation_tool(
|
||||||
Factory function to create the generate_video_presentation tool with injected dependencies.
|
Factory function to create the generate_video_presentation tool with injected dependencies.
|
||||||
|
|
||||||
Pre-creates video presentation record with pending status so the ID is available
|
Pre-creates video presentation record with pending status so the ID is available
|
||||||
immediately for frontend polling.
|
immediately for frontend polling. The row is written via a fresh, tool-local
|
||||||
|
session so parallel tool calls (e.g. video + podcast in the same agent step)
|
||||||
|
don't share an ``AsyncSession`` (which is not concurrency-safe).
|
||||||
"""
|
"""
|
||||||
|
del db_session # writes use a fresh tool-local session, see below
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def generate_video_presentation(
|
async def generate_video_presentation(
|
||||||
|
|
@ -42,34 +45,40 @@ def create_generate_video_presentation_tool(
|
||||||
user_prompt: Optional style/tone instructions.
|
user_prompt: Optional style/tone instructions.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
video_pres = VideoPresentation(
|
# See podcast.py for the rationale: parallel tool calls share the
|
||||||
title=video_title,
|
# streaming session, and AsyncSession is not concurrency-safe —
|
||||||
status=VideoPresentationStatus.PENDING,
|
# interleaved flushes produce "Session.add() during flush" and
|
||||||
search_space_id=search_space_id,
|
# poison the transaction for every concurrent tool.
|
||||||
thread_id=thread_id,
|
async with shielded_async_session() as session:
|
||||||
)
|
video_pres = VideoPresentation(
|
||||||
db_session.add(video_pres)
|
title=video_title,
|
||||||
await db_session.commit()
|
status=VideoPresentationStatus.PENDING,
|
||||||
await db_session.refresh(video_pres)
|
search_space_id=search_space_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
session.add(video_pres)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(video_pres)
|
||||||
|
video_pres_id = video_pres.id
|
||||||
|
|
||||||
from app.tasks.celery_tasks.video_presentation_tasks import (
|
from app.tasks.celery_tasks.video_presentation_tasks import (
|
||||||
generate_video_presentation_task,
|
generate_video_presentation_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
task = generate_video_presentation_task.delay(
|
task = generate_video_presentation_task.delay(
|
||||||
video_presentation_id=video_pres.id,
|
video_presentation_id=video_pres_id,
|
||||||
source_content=source_content,
|
source_content=source_content,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
user_prompt=user_prompt,
|
user_prompt=user_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}"
|
f"[generate_video_presentation] Created video presentation {video_pres_id}, task: {task.id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": VideoPresentationStatus.PENDING.value,
|
"status": VideoPresentationStatus.PENDING.value,
|
||||||
"video_presentation_id": video_pres.id,
|
"video_presentation_id": video_pres_id,
|
||||||
"title": video_title,
|
"title": video_title,
|
||||||
"message": "Video presentation generation started. This may take a few minutes.",
|
"message": "Video presentation generation started. This may take a few minutes.",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import (
|
from app.agents.new_chat.llm_config import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
create_chat_litellm_from_agent_config,
|
create_chat_litellm_from_agent_config,
|
||||||
|
|
@ -42,6 +42,9 @@ from app.agents.new_chat.memory_extraction import (
|
||||||
extract_and_save_memory,
|
extract_and_save_memory,
|
||||||
extract_and_save_team_memory,
|
extract_and_save_team_memory,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.kb_persistence import (
|
||||||
|
commit_staged_filesystem_state,
|
||||||
|
)
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatVisibility,
|
ChatVisibility,
|
||||||
NewChatMessage,
|
NewChatMessage,
|
||||||
|
|
@ -258,6 +261,10 @@ async def _stream_agent_events(
|
||||||
initial_step_id: str | None = None,
|
initial_step_id: str | None = None,
|
||||||
initial_step_title: str = "",
|
initial_step_title: str = "",
|
||||||
initial_step_items: list[str] | None = None,
|
initial_step_items: list[str] | None = None,
|
||||||
|
*,
|
||||||
|
fallback_commit_search_space_id: int | None = None,
|
||||||
|
fallback_commit_created_by_id: str | None = None,
|
||||||
|
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Shared async generator that streams and formats astream_events from the agent.
|
"""Shared async generator that streams and formats astream_events from the agent.
|
||||||
|
|
||||||
|
|
@ -1280,6 +1287,40 @@ async def _stream_agent_events(
|
||||||
|
|
||||||
state = await agent.aget_state(config)
|
state = await agent.aget_state(config)
|
||||||
state_values = getattr(state, "values", {}) or {}
|
state_values = getattr(state, "values", {}) or {}
|
||||||
|
|
||||||
|
# Safety net: if astream_events was cancelled before
|
||||||
|
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||||
|
# (dirty_paths / staged_dirs / pending_moves) will still be in the
|
||||||
|
# checkpointed state. Run the SAME shared commit helper here so the
|
||||||
|
# turn's writes don't get lost on client disconnect, then push the
|
||||||
|
# delta back into the graph using `as_node=...` so reducers fire as if
|
||||||
|
# the after_agent hook produced it.
|
||||||
|
if (
|
||||||
|
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
|
||||||
|
and fallback_commit_search_space_id is not None
|
||||||
|
and (
|
||||||
|
(state_values.get("dirty_paths") or [])
|
||||||
|
or (state_values.get("staged_dirs") or [])
|
||||||
|
or (state_values.get("pending_moves") or [])
|
||||||
|
)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
delta = await commit_staged_filesystem_state(
|
||||||
|
state_values,
|
||||||
|
search_space_id=fallback_commit_search_space_id,
|
||||||
|
created_by_id=fallback_commit_created_by_id,
|
||||||
|
filesystem_mode=fallback_commit_filesystem_mode,
|
||||||
|
dispatch_events=False,
|
||||||
|
)
|
||||||
|
if delta:
|
||||||
|
await agent.aupdate_state(
|
||||||
|
config,
|
||||||
|
delta,
|
||||||
|
as_node="KnowledgeBasePersistenceMiddleware.after_agent",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
_perf_log.warning("[stream_new_chat] safety-net commit failed: %s", exc)
|
||||||
|
|
||||||
contract_state = state_values.get("file_operation_contract") or {}
|
contract_state = state_values.get("file_operation_contract") or {}
|
||||||
contract_turn_id = contract_state.get("turn_id")
|
contract_turn_id = contract_state.get("turn_id")
|
||||||
current_turn_id = config.get("configurable", {}).get("turn_id", "")
|
current_turn_id = config.get("configurable", {}).get("turn_id", "")
|
||||||
|
|
@ -1814,6 +1855,13 @@ async def stream_new_chat(
|
||||||
initial_step_id=initial_step_id,
|
initial_step_id=initial_step_id,
|
||||||
initial_step_title=initial_title,
|
initial_step_title=initial_title,
|
||||||
initial_step_items=initial_items,
|
initial_step_items=initial_items,
|
||||||
|
fallback_commit_search_space_id=search_space_id,
|
||||||
|
fallback_commit_created_by_id=user_id,
|
||||||
|
fallback_commit_filesystem_mode=(
|
||||||
|
filesystem_selection.mode
|
||||||
|
if filesystem_selection
|
||||||
|
else FilesystemMode.CLOUD
|
||||||
|
),
|
||||||
):
|
):
|
||||||
if not _first_event_logged:
|
if not _first_event_logged:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
@ -2251,6 +2299,13 @@ async def stream_resume_chat(
|
||||||
streaming_service=streaming_service,
|
streaming_service=streaming_service,
|
||||||
result=stream_result,
|
result=stream_result,
|
||||||
step_prefix="thinking-resume",
|
step_prefix="thinking-resume",
|
||||||
|
fallback_commit_search_space_id=search_space_id,
|
||||||
|
fallback_commit_created_by_id=user_id,
|
||||||
|
fallback_commit_filesystem_mode=(
|
||||||
|
filesystem_selection.mode
|
||||||
|
if filesystem_selection
|
||||||
|
else FilesystemMode.CLOUD
|
||||||
|
),
|
||||||
):
|
):
|
||||||
if not _first_event_logged:
|
if not _first_event_logged:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,198 @@
|
||||||
|
"""Tests for canonical virtual-path resolver helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
PathIndex,
|
||||||
|
doc_to_virtual_path,
|
||||||
|
parse_doc_id_suffix,
|
||||||
|
parse_documents_path,
|
||||||
|
safe_filename,
|
||||||
|
safe_folder_segment,
|
||||||
|
virtual_path_to_doc,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeFilename:
|
||||||
|
def test_appends_xml_extension(self):
|
||||||
|
assert safe_filename("notes").endswith(".xml")
|
||||||
|
|
||||||
|
def test_strips_invalid_chars(self):
|
||||||
|
assert "/" not in safe_filename("a/b\\c.xml")
|
||||||
|
|
||||||
|
def test_falls_back_when_empty(self):
|
||||||
|
assert safe_filename("").endswith(".xml")
|
||||||
|
assert safe_filename("///") == "untitled.xml" or safe_filename("///").endswith(
|
||||||
|
".xml"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeFolderSegment:
|
||||||
|
def test_strips_path_separators(self):
|
||||||
|
assert "/" not in safe_folder_segment("a/b")
|
||||||
|
|
||||||
|
def test_falls_back(self):
|
||||||
|
assert safe_folder_segment("") == "folder"
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseDocIdSuffix:
|
||||||
|
def test_parses_suffix(self):
|
||||||
|
stem, doc_id = parse_doc_id_suffix("My Doc (42).xml")
|
||||||
|
assert stem == "My Doc"
|
||||||
|
assert doc_id == 42
|
||||||
|
|
||||||
|
def test_no_suffix_returns_none(self):
|
||||||
|
stem, doc_id = parse_doc_id_suffix("My Doc.xml")
|
||||||
|
assert stem == "My Doc"
|
||||||
|
assert doc_id is None
|
||||||
|
|
||||||
|
def test_no_xml_extension(self):
|
||||||
|
stem, doc_id = parse_doc_id_suffix("plain")
|
||||||
|
assert stem == "plain"
|
||||||
|
assert doc_id is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocToVirtualPath:
|
||||||
|
def test_root_when_no_folder(self):
|
||||||
|
index = PathIndex()
|
||||||
|
path = doc_to_virtual_path(doc_id=1, title="Hello", folder_id=None, index=index)
|
||||||
|
assert path == f"{DOCUMENTS_ROOT}/Hello.xml"
|
||||||
|
assert index.occupants[path] == 1
|
||||||
|
|
||||||
|
def test_collision_picks_doc_id_suffix(self):
|
||||||
|
index = PathIndex(occupants={f"{DOCUMENTS_ROOT}/Hello.xml": 7})
|
||||||
|
path = doc_to_virtual_path(doc_id=8, title="Hello", folder_id=None, index=index)
|
||||||
|
assert path == f"{DOCUMENTS_ROOT}/Hello (8).xml"
|
||||||
|
assert index.occupants[path] == 8
|
||||||
|
|
||||||
|
def test_uses_folder_path_when_known(self):
|
||||||
|
index = PathIndex(folder_paths={5: f"{DOCUMENTS_ROOT}/notes"})
|
||||||
|
path = doc_to_virtual_path(doc_id=2, title="A", folder_id=5, index=index)
|
||||||
|
assert path == f"{DOCUMENTS_ROOT}/notes/A.xml"
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseDocumentsPath:
|
||||||
|
def test_extracts_folder_parts_and_title(self):
|
||||||
|
parts, title = parse_documents_path(f"{DOCUMENTS_ROOT}/foo/bar/baz.xml")
|
||||||
|
assert parts == ["foo", "bar"]
|
||||||
|
assert title == "baz"
|
||||||
|
|
||||||
|
def test_strips_doc_id_suffix(self):
|
||||||
|
parts, title = parse_documents_path(f"{DOCUMENTS_ROOT}/foo/My Doc (12).xml")
|
||||||
|
assert parts == ["foo"]
|
||||||
|
assert title == "My Doc"
|
||||||
|
|
||||||
|
def test_non_documents_returns_empty(self):
|
||||||
|
assert parse_documents_path("/other/x.xml") == ([], "")
|
||||||
|
|
||||||
|
|
||||||
|
def _result_from_scalars(rows: list):
|
||||||
|
"""Build a fake SQLAlchemy ``Result`` whose ``.scalars().all()`` and
|
||||||
|
``.scalars().first()`` yield ``rows``."""
|
||||||
|
scalars = MagicMock()
|
||||||
|
scalars.all.return_value = list(rows)
|
||||||
|
scalars.first.return_value = rows[0] if rows else None
|
||||||
|
result = MagicMock()
|
||||||
|
result.scalars.return_value = scalars
|
||||||
|
result.scalar_one_or_none.return_value = None
|
||||||
|
result.first.return_value = None
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _result_from_one(value):
|
||||||
|
result = MagicMock()
|
||||||
|
result.scalar_one_or_none.return_value = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class TestVirtualPathToDoc:
|
||||||
|
"""Lookup must round-trip through ``safe_filename``'s lossy encoding.
|
||||||
|
|
||||||
|
The workspace tree displays ``safe_filename(title)`` as the basename, so
|
||||||
|
when the agent passes that basename back to a tool (move/edit/read) the
|
||||||
|
resolver must find the original document even though characters like
|
||||||
|
``:`` were replaced with ``_``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_falls_back_to_safe_filename_match_when_title_lossy(self):
|
||||||
|
# A Google Calendar-style title that contains a colon — safe_filename
|
||||||
|
# rewrites the colon to ``_``, so the literal title-equality lookup
|
||||||
|
# would miss this row.
|
||||||
|
original_title = "Calendar: Happy birthday!"
|
||||||
|
encoded_basename = safe_filename(original_title)
|
||||||
|
assert encoded_basename == "Calendar_ Happy birthday!.xml"
|
||||||
|
|
||||||
|
target_doc = SimpleNamespace(id=42, title=original_title, folder_id=None)
|
||||||
|
|
||||||
|
session = MagicMock()
|
||||||
|
# Each ``await session.execute(...)`` returns a fresh canned result.
|
||||||
|
# Order matches the resolver's lookup steps:
|
||||||
|
# 1) unique_identifier_hash → no match
|
||||||
|
# 2) literal title match → no match (lossy encoding)
|
||||||
|
# 3) folder scan → returns the row whose title encodes to basename
|
||||||
|
session.execute = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
_result_from_one(None),
|
||||||
|
_result_from_scalars([]),
|
||||||
|
_result_from_scalars([target_doc]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
document = await virtual_path_to_doc(
|
||||||
|
session,
|
||||||
|
search_space_id=5,
|
||||||
|
virtual_path=f"{DOCUMENTS_ROOT}/{encoded_basename}",
|
||||||
|
)
|
||||||
|
assert document is target_doc
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_no_doc_matches_safe_filename(self):
|
||||||
|
session = MagicMock()
|
||||||
|
session.execute = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
_result_from_one(None),
|
||||||
|
_result_from_scalars([]),
|
||||||
|
_result_from_scalars(
|
||||||
|
[SimpleNamespace(id=1, title="Something else", folder_id=None)]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
document = await virtual_path_to_doc(
|
||||||
|
session,
|
||||||
|
search_space_id=5,
|
||||||
|
virtual_path=f"{DOCUMENTS_ROOT}/Calendar_ Happy birthday!.xml",
|
||||||
|
)
|
||||||
|
assert document is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_literal_title_match_short_circuits_fallback(self):
|
||||||
|
# When the literal title query hits, the folder-scan fallback must
|
||||||
|
# NOT run (saves a query and avoids picking the wrong doc when two
|
||||||
|
# rows share a lossy encoding).
|
||||||
|
target_doc = SimpleNamespace(id=7, title="Plain Note", folder_id=None)
|
||||||
|
|
||||||
|
session = MagicMock()
|
||||||
|
session.execute = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
_result_from_one(None),
|
||||||
|
_result_from_scalars([target_doc]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
document = await virtual_path_to_doc(
|
||||||
|
session,
|
||||||
|
search_space_id=5,
|
||||||
|
virtual_path=f"{DOCUMENTS_ROOT}/Plain Note.xml",
|
||||||
|
)
|
||||||
|
assert document is target_doc
|
||||||
|
assert session.execute.await_count == 2
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
"""Tests for SurfSense filesystem state reducers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.state_reducers import (
|
||||||
|
_CLEAR,
|
||||||
|
_add_unique_reducer,
|
||||||
|
_dict_merge_with_tombstones_reducer,
|
||||||
|
_initial_filesystem_state,
|
||||||
|
_list_append_reducer,
|
||||||
|
_replace_reducer,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplaceReducer:
|
||||||
|
def test_right_wins_outright(self):
|
||||||
|
assert _replace_reducer("a", "b") == "b"
|
||||||
|
|
||||||
|
def test_none_right_returns_none(self):
|
||||||
|
assert _replace_reducer("a", None) is None
|
||||||
|
|
||||||
|
def test_none_left_returns_right(self):
|
||||||
|
assert _replace_reducer(None, "b") == "b"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddUniqueReducer:
|
||||||
|
def test_appends_unique_items(self):
|
||||||
|
assert _add_unique_reducer(["a"], ["b", "c"]) == ["a", "b", "c"]
|
||||||
|
|
||||||
|
def test_dedupes_against_left(self):
|
||||||
|
assert _add_unique_reducer(["a", "b"], ["b", "c"]) == ["a", "b", "c"]
|
||||||
|
|
||||||
|
def test_dedupes_within_right(self):
|
||||||
|
assert _add_unique_reducer([], ["a", "a", "b"]) == ["a", "b"]
|
||||||
|
|
||||||
|
def test_clear_anywhere_resets_and_reseeds_with_after_items(self):
|
||||||
|
# _CLEAR semantics: only items AFTER the LAST _CLEAR are kept.
|
||||||
|
result = _add_unique_reducer(["x", "y"], ["a", _CLEAR, "b", "c"])
|
||||||
|
assert result == ["b", "c"]
|
||||||
|
|
||||||
|
def test_multiple_clears_use_last(self):
|
||||||
|
result = _add_unique_reducer(["x"], [_CLEAR, "a", _CLEAR, "b"])
|
||||||
|
assert result == ["b"]
|
||||||
|
|
||||||
|
def test_clear_only_resets_to_empty(self):
|
||||||
|
assert _add_unique_reducer(["x", "y"], [_CLEAR]) == []
|
||||||
|
|
||||||
|
def test_empty_right_keeps_left(self):
|
||||||
|
assert _add_unique_reducer(["a"], []) == ["a"]
|
||||||
|
assert _add_unique_reducer(["a"], None) == ["a"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestListAppendReducer:
|
||||||
|
def test_preserves_order_and_duplicates(self):
|
||||||
|
result = _list_append_reducer([{"a": 1}], [{"b": 2}, {"a": 1}])
|
||||||
|
assert result == [{"a": 1}, {"b": 2}, {"a": 1}]
|
||||||
|
|
||||||
|
def test_clear_resets_keeping_after_items(self):
|
||||||
|
result = _list_append_reducer([{"a": 1}], [{"old": 1}, _CLEAR, {"new": 2}])
|
||||||
|
assert result == [{"new": 2}]
|
||||||
|
|
||||||
|
|
||||||
|
class TestDictMergeWithTombstones:
|
||||||
|
def test_merges_keys(self):
|
||||||
|
assert _dict_merge_with_tombstones_reducer({"a": 1}, {"b": 2}) == {
|
||||||
|
"a": 1,
|
||||||
|
"b": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_none_value_deletes_key(self):
|
||||||
|
result = _dict_merge_with_tombstones_reducer({"a": 1, "b": 2}, {"a": None})
|
||||||
|
assert result == {"b": 2}
|
||||||
|
|
||||||
|
def test_clear_resets_then_merges(self):
|
||||||
|
result = _dict_merge_with_tombstones_reducer(
|
||||||
|
{"a": 1, "b": 2}, {_CLEAR: True, "c": 3}
|
||||||
|
)
|
||||||
|
assert result == {"c": 3}
|
||||||
|
|
||||||
|
def test_clear_keeps_only_post_clear_non_none(self):
|
||||||
|
result = _dict_merge_with_tombstones_reducer(
|
||||||
|
{"a": 1}, {_CLEAR: True, "b": 2, "c": None}
|
||||||
|
)
|
||||||
|
assert result == {"b": 2}
|
||||||
|
|
||||||
|
def test_none_left_handled(self):
|
||||||
|
assert _dict_merge_with_tombstones_reducer(None, {"a": 1, "b": None}) == {
|
||||||
|
"a": 1
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitialFilesystemState:
|
||||||
|
def test_default_shape(self):
|
||||||
|
state = _initial_filesystem_state()
|
||||||
|
assert state["cwd"] == "/documents"
|
||||||
|
assert state["staged_dirs"] == []
|
||||||
|
assert state["pending_moves"] == []
|
||||||
|
assert state["doc_id_by_path"] == {}
|
||||||
|
assert state["dirty_paths"] == []
|
||||||
|
assert state["kb_priority"] == []
|
||||||
|
assert state["kb_matched_chunk_ids"] == {}
|
||||||
|
assert state["kb_anon_doc"] is None
|
||||||
|
assert state["tree_version"] == 0
|
||||||
|
|
@ -36,11 +36,18 @@ def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: P
|
||||||
def test_backend_resolver_uses_cloud_mode_by_default():
|
def test_backend_resolver_uses_cloud_mode_by_default():
|
||||||
resolver = build_backend_resolver(FilesystemSelection())
|
resolver = build_backend_resolver(FilesystemSelection())
|
||||||
backend = resolver(_RuntimeStub())
|
backend = resolver(_RuntimeStub())
|
||||||
# StateBackend class name check keeps this test decoupled
|
# When no search_space_id is provided we fall back to StateBackend so
|
||||||
# from internal deepagents runtime class identity.
|
# sub-agents / tests without DB access still work.
|
||||||
assert backend.__class__.__name__ == "StateBackend"
|
assert backend.__class__.__name__ == "StateBackend"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backend_resolver_uses_kb_postgres_in_cloud_with_search_space():
|
||||||
|
resolver = build_backend_resolver(FilesystemSelection(), search_space_id=42)
|
||||||
|
backend = resolver(_RuntimeStub())
|
||||||
|
assert backend.__class__.__name__ == "KBPostgresBackend"
|
||||||
|
assert backend.search_space_id == 42
|
||||||
|
|
||||||
|
|
||||||
def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path):
|
def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path):
|
||||||
root_one = tmp_path / "resume"
|
root_one = tmp_path / "resume"
|
||||||
root_two = tmp_path / "notes"
|
root_two = tmp_path / "notes"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,204 @@
|
||||||
|
"""Unit tests for the SurfSense filesystem middleware new behaviors.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
* cloud cwd defaults to ``/documents`` and relative paths resolve under it
|
||||||
|
* cloud writes outside ``/documents/`` are rejected unless basename starts
|
||||||
|
with ``temp_``
|
||||||
|
* cloud writes/edits to the anonymous document are rejected (read-only)
|
||||||
|
* helper methods on the middleware (``_resolve_relative``,
|
||||||
|
``_check_cloud_write_namespace``, ``_default_cwd``)
|
||||||
|
|
||||||
|
These tests use ``__new__`` to bypass the heavy ``__init__`` and exercise
|
||||||
|
the helper methods directly so the test surface stays narrow and fast.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware.filesystem import (
|
||||||
|
SurfSenseFilesystemMiddleware,
|
||||||
|
_build_filesystem_system_prompt,
|
||||||
|
_build_tool_descriptions,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._filesystem_mode = mode
|
||||||
|
return middleware
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime(state: dict | None = None) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(state=state or {})
|
||||||
|
|
||||||
|
|
||||||
|
class TestCloudCwdDefaults:
|
||||||
|
def test_default_cwd_in_cloud_is_documents_root(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
assert m._default_cwd() == "/documents"
|
||||||
|
|
||||||
|
def test_default_cwd_in_desktop_is_root(self):
|
||||||
|
m = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER)
|
||||||
|
assert m._default_cwd() == "/"
|
||||||
|
|
||||||
|
def test_current_cwd_uses_state_when_set(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({"cwd": "/documents/notes"})
|
||||||
|
assert m._current_cwd(runtime) == "/documents/notes"
|
||||||
|
|
||||||
|
def test_current_cwd_falls_back_to_default(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({})
|
||||||
|
assert m._current_cwd(runtime) == "/documents"
|
||||||
|
|
||||||
|
def test_current_cwd_ignores_invalid(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({"cwd": "not-absolute"})
|
||||||
|
assert m._current_cwd(runtime) == "/documents"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRelativePathResolution:
|
||||||
|
def test_relative_path_resolves_against_cwd(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({"cwd": "/documents/projects"})
|
||||||
|
assert (
|
||||||
|
m._resolve_relative("notes.md", runtime) == "/documents/projects/notes.md"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_relative_path_with_dotdot(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({"cwd": "/documents/a/b"})
|
||||||
|
assert m._resolve_relative("../c.md", runtime) == "/documents/a/c.md"
|
||||||
|
|
||||||
|
def test_absolute_path_is_kept(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({"cwd": "/documents"})
|
||||||
|
assert m._resolve_relative("/other/x.md", runtime) == "/other/x.md"
|
||||||
|
|
||||||
|
def test_empty_path_returns_cwd(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({"cwd": "/documents/projects"})
|
||||||
|
assert m._resolve_relative("", runtime) == "/documents/projects"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCloudWriteNamespacePolicy:
|
||||||
|
def test_documents_path_allowed(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime()
|
||||||
|
assert m._check_cloud_write_namespace("/documents/foo.md", runtime) is None
|
||||||
|
|
||||||
|
def test_documents_root_allowed(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime()
|
||||||
|
assert m._check_cloud_write_namespace("/documents", runtime) is None
|
||||||
|
|
||||||
|
def test_temp_basename_anywhere_allowed(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime()
|
||||||
|
assert m._check_cloud_write_namespace("/temp_scratch.md", runtime) is None
|
||||||
|
assert m._check_cloud_write_namespace("/foo/temp_x.md", runtime) is None
|
||||||
|
assert m._check_cloud_write_namespace("/documents/temp_x.md", runtime) is None
|
||||||
|
|
||||||
|
def test_other_paths_rejected(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime()
|
||||||
|
err = m._check_cloud_write_namespace("/foo/bar.md", runtime)
|
||||||
|
assert err is not None
|
||||||
|
assert "must target /documents" in err
|
||||||
|
|
||||||
|
def test_anon_doc_path_is_read_only(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime(
|
||||||
|
{
|
||||||
|
"kb_anon_doc": {
|
||||||
|
"path": "/documents/uploaded.xml",
|
||||||
|
"title": "uploaded",
|
||||||
|
"content": "",
|
||||||
|
"chunks": [],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
err = m._check_cloud_write_namespace("/documents/uploaded.xml", runtime)
|
||||||
|
assert err is not None
|
||||||
|
assert "read-only" in err
|
||||||
|
|
||||||
|
def test_desktop_mode_skips_namespace_policy(self):
|
||||||
|
m = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER)
|
||||||
|
runtime = _runtime()
|
||||||
|
assert m._check_cloud_write_namespace("/random/path.md", runtime) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestModeSpecificPrompts:
|
||||||
|
"""The prompt and tool descriptions must only describe the active mode.
|
||||||
|
|
||||||
|
Cross-mode noise wastes tokens and confuses the model with rules it
|
||||||
|
cannot use this session.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_cloud_prompt_omits_desktop_section(self):
|
||||||
|
prompt = _build_filesystem_system_prompt(
|
||||||
|
FilesystemMode.CLOUD, sandbox_available=False
|
||||||
|
)
|
||||||
|
assert "Local Folder Mode" not in prompt
|
||||||
|
assert "mount-prefixed" not in prompt
|
||||||
|
assert "Persistence Rules" in prompt
|
||||||
|
assert "/documents" in prompt
|
||||||
|
assert "temp_" in prompt
|
||||||
|
|
||||||
|
def test_desktop_prompt_omits_cloud_persistence_rules(self):
|
||||||
|
prompt = _build_filesystem_system_prompt(
|
||||||
|
FilesystemMode.DESKTOP_LOCAL_FOLDER, sandbox_available=False
|
||||||
|
)
|
||||||
|
assert "Persistence Rules" not in prompt
|
||||||
|
assert "Workspace Tree" not in prompt
|
||||||
|
assert "<priority_documents>" not in prompt
|
||||||
|
assert "Local Folder Mode" in prompt
|
||||||
|
assert "mount-prefixed" in prompt
|
||||||
|
|
||||||
|
def test_cloud_tool_descs_omit_desktop_phrases(self):
|
||||||
|
descs = _build_tool_descriptions(FilesystemMode.CLOUD)
|
||||||
|
for name in (
|
||||||
|
"write_file",
|
||||||
|
"edit_file",
|
||||||
|
"move_file",
|
||||||
|
"mkdir",
|
||||||
|
"list_tree",
|
||||||
|
"grep",
|
||||||
|
):
|
||||||
|
text = descs[name]
|
||||||
|
assert "Desktop" not in text, f"{name} leaks desktop hints"
|
||||||
|
assert "Cloud mode:" not in text, f"{name} qualifies a cloud-only desc"
|
||||||
|
|
||||||
|
def test_desktop_tool_descs_omit_cloud_phrases(self):
|
||||||
|
descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
|
||||||
|
for name in (
|
||||||
|
"write_file",
|
||||||
|
"edit_file",
|
||||||
|
"move_file",
|
||||||
|
"mkdir",
|
||||||
|
"list_tree",
|
||||||
|
"grep",
|
||||||
|
):
|
||||||
|
text = descs[name]
|
||||||
|
assert "Cloud" not in text, f"{name} leaks cloud hints"
|
||||||
|
assert "/documents/" not in text, f"{name} mentions cloud namespace"
|
||||||
|
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
|
||||||
|
|
||||||
|
def test_sandbox_addendum_appended_when_available(self):
|
||||||
|
prompt = _build_filesystem_system_prompt(
|
||||||
|
FilesystemMode.CLOUD, sandbox_available=True
|
||||||
|
)
|
||||||
|
assert "execute_code" in prompt
|
||||||
|
assert "Code Execution" in prompt
|
||||||
|
|
||||||
|
def test_sandbox_addendum_absent_when_unavailable(self):
|
||||||
|
prompt = _build_filesystem_system_prompt(
|
||||||
|
FilesystemMode.CLOUD, sandbox_available=False
|
||||||
|
)
|
||||||
|
assert "execute_code" not in prompt
|
||||||
|
|
@ -11,25 +11,6 @@ from app.agents.new_chat.middleware.multi_root_local_folder_backend import (
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
class _BackendWithRawRead:
|
|
||||||
def __init__(self, content: str) -> None:
|
|
||||||
self._content = content
|
|
||||||
|
|
||||||
def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
|
|
||||||
del file_path, offset, limit
|
|
||||||
return " 1\tline1\n 2\tline2"
|
|
||||||
|
|
||||||
async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str:
|
|
||||||
return self.read(file_path, offset, limit)
|
|
||||||
|
|
||||||
def read_raw(self, file_path: str) -> str:
|
|
||||||
del file_path
|
|
||||||
return self._content
|
|
||||||
|
|
||||||
async def aread_raw(self, file_path: str) -> str:
|
|
||||||
return self.read_raw(file_path)
|
|
||||||
|
|
||||||
|
|
||||||
class _RuntimeNoSuggestedPath:
|
class _RuntimeNoSuggestedPath:
|
||||||
state = {"file_operation_contract": {}}
|
state = {"file_operation_contract": {}}
|
||||||
|
|
||||||
|
|
@ -39,40 +20,19 @@ class _RuntimeWithSuggestedPath:
|
||||||
self.state = {"file_operation_contract": {"suggested_path": suggested_path}}
|
self.state = {"file_operation_contract": {"suggested_path": suggested_path}}
|
||||||
|
|
||||||
|
|
||||||
def test_verify_written_content_prefers_raw_sync() -> None:
|
def test_contract_suggested_path_falls_back_to_documents_notes_md() -> None:
|
||||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
|
||||||
expected = "line1\nline2"
|
|
||||||
backend = _BackendWithRawRead(expected)
|
|
||||||
|
|
||||||
verify_error = middleware._verify_written_content_sync(
|
|
||||||
backend=backend,
|
|
||||||
path="/note.md",
|
|
||||||
expected_content=expected,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert verify_error is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_contract_suggested_path_falls_back_to_notes_md() -> None:
|
|
||||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
middleware._filesystem_mode = FilesystemMode.CLOUD
|
middleware._filesystem_mode = FilesystemMode.CLOUD
|
||||||
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
|
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
|
||||||
assert suggested == "/notes.md"
|
# Cloud default cwd is /documents so the fallback lands in the KB.
|
||||||
|
assert suggested == "/documents/notes.md"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_contract_suggested_path_falls_back_to_root_notes_md_in_desktop() -> None:
|
||||||
async def test_verify_written_content_prefers_raw_async() -> None:
|
|
||||||
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
expected = "line1\nline2"
|
middleware._filesystem_mode = FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||||
backend = _BackendWithRawRead(expected)
|
suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type]
|
||||||
|
assert suggested == "/notes.md"
|
||||||
verify_error = await middleware._verify_written_content_async(
|
|
||||||
backend=backend,
|
|
||||||
path="/note.md",
|
|
||||||
expected_content=expected,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert verify_error is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None:
|
def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None:
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,10 @@ import json
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from app.agents.new_chat.document_xml import build_document_xml as _build_document_xml
|
||||||
from app.agents.new_chat.middleware.knowledge_search import (
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
KBSearchPlan,
|
KBSearchPlan,
|
||||||
KnowledgeBaseSearchMiddleware,
|
KnowledgeBaseSearchMiddleware,
|
||||||
_build_document_xml,
|
|
||||||
_normalize_optional_date_range,
|
_normalize_optional_date_range,
|
||||||
_parse_kb_search_plan_response,
|
_parse_kb_search_plan_response,
|
||||||
_render_recent_conversation,
|
_render_recent_conversation,
|
||||||
|
|
@ -248,17 +248,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
captured.update(kwargs)
|
captured.update(kwargs)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def fake_build_scoped_filesystem(**kwargs):
|
|
||||||
return {}, {}
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
|
||||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
|
||||||
fake_build_scoped_filesystem,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = FakeLLM(
|
llm = FakeLLM(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
|
|
@ -298,17 +291,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
captured.update(kwargs)
|
captured.update(kwargs)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def fake_build_scoped_filesystem(**kwargs):
|
|
||||||
return {}, {}
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
|
||||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
|
||||||
fake_build_scoped_filesystem,
|
|
||||||
)
|
|
||||||
|
|
||||||
middleware = KnowledgeBaseSearchMiddleware(
|
middleware = KnowledgeBaseSearchMiddleware(
|
||||||
llm=FakeLLM("not json"),
|
llm=FakeLLM("not json"),
|
||||||
|
|
@ -334,17 +320,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
captured.update(kwargs)
|
captured.update(kwargs)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def fake_build_scoped_filesystem(**kwargs):
|
|
||||||
return {}, {}
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
|
||||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
|
||||||
fake_build_scoped_filesystem,
|
|
||||||
)
|
|
||||||
|
|
||||||
middleware = KnowledgeBaseSearchMiddleware(
|
middleware = KnowledgeBaseSearchMiddleware(
|
||||||
llm=FakeLLM(
|
llm=FakeLLM(
|
||||||
|
|
@ -386,9 +365,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
search_called = True
|
search_called = True
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def fake_build_scoped_filesystem(**kwargs):
|
|
||||||
return {}, {}
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
|
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
|
||||||
fake_browse_recent_documents,
|
fake_browse_recent_documents,
|
||||||
|
|
@ -397,10 +373,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
|
||||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
|
||||||
fake_build_scoped_filesystem,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = FakeLLM(
|
llm = FakeLLM(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
|
|
@ -440,9 +412,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
search_captured.update(kwargs)
|
search_captured.update(kwargs)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def fake_build_scoped_filesystem(**kwargs):
|
|
||||||
return {}, {}
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
|
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
|
||||||
fake_browse_recent_documents,
|
fake_browse_recent_documents,
|
||||||
|
|
@ -451,10 +420,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
||||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||||
fake_search_knowledge_base,
|
fake_search_knowledge_base,
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
|
||||||
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
|
|
||||||
fake_build_scoped_filesystem,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = FakeLLM(
|
llm = FakeLLM(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue